1import os
  2import hashlib
  3import inspect
  4import zipfile
  5import requests
  6from tqdm import tqdm
  7from warnings import warn
  8from subprocess import run
  9from xml.dom import minidom
 10from packaging import version
 11from shutil import copyfileobj, which
 13from typing import Optional, Tuple, Literal
 15import numpy as np
 16from skimage.draw import polygon
 18import torch
 20import torch_em
 21from torch_em.transform import get_raw_transform
 22from torch_em.transform.generic import ResizeLongestSideInputs, Compose
 25    import gdown
 26except ImportError:
 27    gdown = None
 30    from tcia_utils import nbia
 31except ModuleNotFoundError:
 32    nbia = None
 35    from cryoet_data_portal import Client, Dataset
 36except ImportError:
 37    Client, Dataset = None, None
 40    import synapseclient
 41    import synapseutils
 42except ImportError:
 43    synapseclient, synapseutils = None, None
 47    "covid_if": "ilastik/covid_if_training_data",
 48    "cremi": "ilastik/cremi_training_data",
 49    "dsb": "ilastik/stardist_dsb_training_data",
 50    "hpa": "",  # not on bioimageio yet
 51    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
 52    "kasthuri": "",  # not on bioimageio yet:
 53    "livecell": "ilastik/livecell_dataset",
 54    "lucchi": "",  # not on bioimageio yet:
 55    "mitoem": "ilastik/mitoem_segmentation_challenge",
 56    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
 57    "ovules": "",  # not on bioimageio yet
 58    "plantseg_root": "ilastik/plantseg_root",
 59    "plantseg_ovules": "ilastik/plantseg_ovules",
 60    "platynereis": "ilastik/platynereis_em_training_data",
 61    "snemi": "",  # not on bioimagegio yet
 62    "uro_cell": "",  # not on bioimageio yet:
 63    "vnc": "ilastik/vnc",
 69def get_bioimageio_dataset_id(dataset_name):
 70    """@private
 71    """
 72    assert dataset_name in BIOIMAGEIO_IDS
 73    return BIOIMAGEIO_IDS[dataset_name]
 76def get_checksum(filename: str) -> str:
 77    """Get the SHA256 checksum of a file.
 79    Args:
 80        filename: The filepath.
 82    Returns:
 83        The checksum.
 84    """
 85    with open(filename, "rb") as f:
 86        file_ =
 87        checksum = hashlib.sha256(file_).hexdigest()
 88    return checksum
 91def _check_checksum(path, checksum):
 92    if checksum is not None:
 93        this_checksum = get_checksum(path)
 94        if this_checksum != checksum:
 95            raise RuntimeError(
 96                "The checksum of the download does not match the expected checksum."
 97                f"Expected: {checksum}, got: {this_checksum}"
 98            )
 99        print("Download successful and checksums agree.")
100    else:
101        warn("The file was downloaded, but no checksum was provided, so the file may be corrupted.")
# This needs to be extended to support downloads from S3 via boto,
# if we get a resource that is only available via S3 and not via HTTP.
def download_source(path: str, url: str, download: bool, checksum: Optional[str] = None, verify: bool = True) -> None:
    """Download data via https.

    Nothing is done if `path` exists already.

    Args:
        path: The path for saving the data.
        url: The url of the data.
        download: Whether to download the data if it is not saved at `path` yet.
        checksum: The expected checksum of the data.
        verify: Whether to verify the https address.

    Raises:
        RuntimeError: If the data is not at `path` and `download` is False,
            or if the checksum of the downloaded file does not match `checksum`.
    """
    if os.path.exists(path):
        return
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    try:
        with requests.get(url, stream=True, allow_redirects=True, verify=verify) as r:
            r.raise_for_status()  # check for error
            file_size = int(r.headers.get("Content-Length", 0))
            desc = f"Download {url} to {path}"
            if file_size == 0:
                desc += " (unknown file size)"
            with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
                copyfileobj(r_raw, f)
    except BaseException:
        # An interrupted or failed transfer would leave a partial file at `path`.
        # The early return above would then treat it as a finished download, so remove it.
        # `path` did not exist when we started, hence the file can only stem from this call.
        if os.path.isfile(path):
            os.remove(path)
        raise

    _check_checksum(path, checksum)
def download_source_gdrive(
    path: str,
    url: str,
    download: bool,
    checksum: Optional[str] = None,
    download_type: Literal["zip", "folder"] = "zip",
    expected_samples: int = 10000,
    quiet: bool = True,
) -> None:
    """Download data from google drive.

    Requires the gdown python library. Nothing is done if `path` exists already.

    Args:
        path: The path for saving the data.
        url: The url of the data.
        download: Whether to download the data if it is not saved at `path` yet.
        checksum: The expected checksum of the data. It is only checked for the download type 'zip'.
        download_type: The download type, either 'zip' or 'folder'.
        expected_samples: The maximal number of samples in the folder.
        quiet: Whether to download quietly.

    Raises:
        RuntimeError: If the data is not at `path` and `download` is False, or if gdown is not installed.
        ValueError: If `download_type` is neither 'zip' nor 'folder'.
    """
    if os.path.exists(path):
        return

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if gdown is None:
        raise RuntimeError(
            "Need gdown library to download data from google drive. "
            "Please install gdown: 'conda install -c conda-forge gdown==4.6.3'."
        )

    print("Downloading the files. Might take a few minutes...")

    if download_type == "zip":
        gdown.download(url, path, quiet=quiet)
        _check_checksum(path, checksum)
    elif download_type == "folder":
        # The folder download overrides a module global of gdown, so it is pinned to a known version.
        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install 'gdown==4.6.3'."
        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
        gdown.download_folder(url=url, output=path, quiet=quiet, remaining_ok=True)
    else:
        raise ValueError("`download_type` argument expects either `zip`/`folder`")

    print("Download completed.")
def download_source_empiar(path: str, access_id: str, download: bool) -> str:
    """Download data from EMPIAR.

    Requires the ascp command from the aspera CLI.

    Args:
        path: The path for saving the data.
        access_id: The EMPIAR accession id of the data to download.
        download: Whether to download the data if it is not saved at `path` yet.

    Returns:
        The path to the downloaded data.

    Raises:
        RuntimeError: If the data is not available and `download` is False, if the ascp command
            or the aspera ssh keyfile cannot be found, or if the ascp command fails.
    """
    download_path = os.path.join(path, access_id)

    if os.path.exists(download_path):
        return download_path
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if which("ascp") is None:
        raise RuntimeError(
            "Need aspera-cli to download data from empiar. You can install it via 'conda install -c hcc aspera-cli'."
        )

    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
    if not os.path.exists(key_file):
        # Fall back to the keyfile inside the active conda environment, if there is one.
        # Outside of conda 'CONDA_PREFIX' is not set; we then fall through to the error below
        # instead of failing with an unrelated KeyError.
        conda_root = os.environ.get("CONDA_PREFIX")
        if conda_root is not None:
            key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh")

    if not os.path.exists(key_file):
        raise RuntimeError("Could not find the aspera ssh keyfile")

    cmd = ["ascp", "-QT", "-l", "200M", "-P33001", "-i", key_file, f"{access_id}", path]
    result = run(cmd)
    # Do not report a download path for a transfer that did not succeed.
    if result.returncode != 0:
        raise RuntimeError(f"The download of {access_id} via ascp failed with exit code {result.returncode}.")

    return download_path
def download_source_kaggle(path: str, dataset_name: str, download: bool, competition: bool = False) -> None:
    """Download data from Kaggle.

    Requires the Kaggle API.

    Args:
        path: The path for saving the data.
        dataset_name: The name of the dataset to download.
        download: Whether to download the data if it is not saved at `path` yet.
        competition: Whether this data is from a competition and requires the kaggle.competition API.

    Raises:
        RuntimeError: If `download` is False.
        ModuleNotFoundError: If the Kaggle API is not installed.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
    except ModuleNotFoundError as err:
        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
        msg += "After you have installed kaggle, you would need an API token. "
        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
        # Chain the original error, so that the name of the missing module is not lost.
        raise ModuleNotFoundError(msg) from err

    api = KaggleApi()
    api.authenticate()

    if competition:
        api.competition_download_files(competition=dataset_name, path=path, quiet=False)
    else:
        api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)
def download_source_tcia(path: str, url: str, dst: str, csv_filename: str, download: bool) -> None:
    """Download data from TCIA.

    Requires the tcia_utils python package.

    Args:
        path: The path for saving the manifest file that is downloaded from `url`.
        url: The URL to the TCIA dataset. It must point to a '.tcia' manifest.
        dst: The destination that is passed to `nbia.downloadSeries` as `path`.
        csv_filename: The csv filename that is passed on to `nbia.downloadSeries`.
        download: Whether to download the data if it is not saved at `path` yet.

    Raises:
        RuntimeError: If tcia_utils is not installed or if `download` is False.
    """
    if nbia is None:
        raise RuntimeError("Requires the tcia_utils python package.")
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
    assert url.endswith(".tcia"), f"{url} is not a TCIA Manifest."

    # Downloads the manifest file from the collection page.
    manifest = requests.get(url=url)
    # Fail here on an HTTP error, instead of writing the error page to disk as the manifest.
    manifest.raise_for_status()
    with open(path, "wb") as f:
        f.write(manifest.content)

    # This part extracts the UIDs from the manifests and downloads them.
    nbia.downloadSeries(series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename)
def download_source_synapse(path: str, entity: str, download: bool) -> None:
    """Download data from synapse.

    Requires the synapseclient python library. The login does not pass credentials,
    they are read from the '~/.synapseConfig' file.

    Args:
        path: The path for saving the data.
        entity: The name of the data to download from synapse. It has to start with 'syn'.
        download: Whether to download the data if it is not saved at `path` yet.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    if synapseclient is None:
        install_hint = (
            "You must install 'synapseclient' to download files from 'synapse'. "
            "Remember to create an account and generate an authentication code for your account. "
            "Please follow the documentation for details on creating the '~/.synapseConfig' file here: "
        )
        raise RuntimeError(install_hint)

    has_synapse_prefix = entity.startswith("syn")
    assert has_synapse_prefix, "The entity name does not look as expected. It should be something like 'syn123'."

    # Open a session and sync everything that belongs to the entity into the target folder.
    session = synapseclient.Synapse()
    session.login()
    synapseutils.syncFromSynapse(syn=session, entity=entity, path=path)
def update_kwargs(kwargs, key, value, msg=None):
    """@private

    Set `kwargs[key]` to `value` and return `kwargs`.
    A warning is emitted if `key` was set before; `msg` replaces the default warning text.
    """
    is_override = key in kwargs
    if is_override:
        warning_text = msg if msg is not None else f"{key} will be over-ridden in loader kwargs."
        warn(warning_text)

    kwargs[key] = value
    return kwargs
def unzip_tarfile(tar_path: str, dst: str, remove: bool = True) -> None:
    """Unpack a tar archive.

    Supports uncompressed ('.tar') and gzip-compressed ('.tar.gz', '.tgz') archives.

    Args:
        tar_path: Path to the tar file.
        dst: Where to unpack the archive.
        remove: Whether to remove the tar file after unpacking.

    Raises:
        ValueError: If the file extension does not belong to a supported archive type.
    """
    import tarfile

    if tar_path.endswith((".tar.gz", ".tgz")):
        access_mode = "r:gz"
    elif tar_path.endswith(".tar"):
        access_mode = "r:"
    else:
        raise ValueError(f"The provided file isn't a supported archive to unpack. Please check the file: {tar_path}.")

    # The context manager closes the archive also if the extraction fails.
    with tarfile.open(tar_path, access_mode) as tar:
        tar.extractall(dst)

    if remove:
        os.remove(tar_path)
def unzip_rarfile(rar_path: str, dst: str, remove: bool = True, use_rarfile: bool = True) -> None:
    """Unpack a rar archive.

    Requires either the rarfile or the aspose.zip python library, depending on `use_rarfile`.

    Args:
        rar_path: Path to the rar file.
        dst: Where to unpack the archive.
        remove: Whether to remove the rar file after unpacking.
        use_rarfile: Whether to use the rarfile library (True) or aspose.zip (False).
    """
    # The libraries are imported lazily, so that only the one that is used has to be installed.
    if use_rarfile:
        import rarfile
        with rarfile.RarFile(rar_path) as f:
            f.extractall(path=dst)
    else:
        import aspose.zip as az
        with az.rar.RarArchive(rar_path) as archive:
            archive.extract_to_directory(dst)

    if remove:
        os.remove(rar_path)
def unzip(zip_path: str, dst: str, remove: bool = True) -> None:
    """Extract all members of a zip archive.

    Args:
        zip_path: Path to the zip file.
        dst: Where to unpack the archive.
        remove: Whether to remove the zip file after unpacking.
    """
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(path=dst)

    if not remove:
        return
    os.remove(zip_path)
def split_kwargs(function, **kwargs):
    """@private

    Partition `kwargs` by the signature of `function`.
    Returns the arguments whose names are parameters of `function`, followed by all other arguments.
    """
    accepted_names = inspect.signature(function).parameters

    matching, remaining = {}, {}
    for name, value in kwargs.items():
        target = matching if name in accepted_names else remaining
        target[name] = value
    return matching, remaining
387# this adds the default transforms for 'raw_transform' and 'transform'
388# in case these were not specified in the kwargs
389# this is NOT necessary if 'default_segmentation_dataset' is used, only if a dataset class
390# is used directly, e.g. in the LiveCell Loader
def ensure_transforms(ndim, **kwargs):
    """@private

    Fill in the torch_em defaults for 'raw_transform' and 'transform' if they are not in `kwargs`.
    Entries that are already present are left untouched.
    """
    # The keys are absent when they are set here, so no override warning can occur.
    if "raw_transform" not in kwargs:
        kwargs["raw_transform"] = torch_em.transform.get_raw_transform()

    if "transform" not in kwargs:
        kwargs["transform"] = torch_em.transform.get_augmentations(ndim=ndim)

    return kwargs
def add_instance_label_transform(
    kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True,
):
    """@private

    Add the label transform that is selected via `offsets`, `boundaries` or `binary` to `kwargs`.

    Offsets take precedence over boundaries, which take precedence over binary.
    An affinity transform (offsets) is stored as 'label_transform2', the other two as 'label_transform'.
    If one of them is selected, the returned label dtype is torch.float32,
    otherwise `kwargs` and `label_dtype` are returned unchanged.
    """
    # At most one of the options may be selected.
    # 'binary' is only part of this check if it is declared as exclusive.
    if binary_is_exclusive:
        assert sum((offsets is not None, boundaries, binary)) <= 1
    else:
        assert sum((offsets is not None, boundaries)) <= 1
    if offsets is not None:
        label_transform2 = torch_em.transform.label.AffinityTransform(offsets=offsets,
                                                                      add_binary_target=add_binary_target,
                                                                      add_mask=True)
        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform2", label_transform2, msg=msg)
        label_dtype = torch.float32
    elif boundaries:
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=add_binary_target)
        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    elif binary:
        label_transform = torch_em.transform.label.labels_to_binary
        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    return kwargs, label_dtype
def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None, ensure_rgb=None):
    """@private

    Prepend the resize transforms (and optionally `ensure_rgb`) to the raw and label transforms in `kwargs`.

    If `resize_inputs` is set, `resize_kwargs` must be a dict with the entries 'patch_shape' and 'is_rgb'.
    The returned patch shape is then None, otherwise `patch_shape` is passed through.
    Without resizing and without `ensure_rgb` the `kwargs` are returned unchanged.
    """
    raw_trafos = []
    if ensure_rgb is not None:
        assert not isinstance(ensure_rgb, bool), "'ensure_rgb' is expected to be a function."
        raw_trafos.append(ensure_rgb)

    label_trafo = None
    if resize_inputs:
        assert isinstance(resize_kwargs, dict)

        target_shape = resize_kwargs.get("patch_shape")
        if len(resize_kwargs["patch_shape"]) == 3:
            # we only need the XY dimensions to reshape the inputs along them.
            target_shape = target_shape[1:]
            # we provide the Z dimension value to return the desired number of slices and not the whole volume
            kwargs["z_ext"] = resize_kwargs["patch_shape"][0]

        raw_trafos.append(ResizeLongestSideInputs(target_shape=target_shape, is_rgb=resize_kwargs["is_rgb"]))
        label_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_label=True)

        # The patch shape provided to the dataset. Here, "None" means that the entire volume will be loaded.
        patch_shape = None

    # Only touch the raw transform if there is something to apply before it.
    # The resize transforms exist only if resizing was requested; referencing them
    # unconditionally fails for 'resize_inputs=False'.
    if raw_trafos:
        raw_trafos.append(kwargs["raw_transform"] if "raw_transform" in kwargs else get_raw_transform())
        kwargs["raw_transform"] = Compose(*raw_trafos, is_multi_tensor=False)

    if label_trafo is not None:
        # An existing label transform is applied after the resizing.
        if "label_transform" in kwargs:
            trafo = Compose(label_trafo, kwargs["label_transform"], is_multi_tensor=False)
            kwargs["label_transform"] = trafo
        else:
            kwargs["label_transform"] = label_trafo

    return kwargs, patch_shape
def generate_labeled_array_from_xml(shape: Tuple[int, ...], xml_file: str) -> np.ndarray:
    """Generate a label mask from a contour defined in a xml annotation file.

    Each 'Region' element of the file is rasterized as the polygon that is spanned by its
    'Vertex' elements. The regions are labeled consecutively in document order, starting at 1.
    The value 0 is the background. Regions that are drawn later overwrite earlier ones where they overlap.

    Args:
        shape: The image shape.
        xml_file: The path to the xml file with contour annotations.

    Returns:
        The label mask, with data type float32.
    """
    # All 'Region' tags of the document; each of them describes one contour.
    regions = minidom.parse(xml_file).getElementsByTagName("Region")

    contours = []
    for region in regions:
        vertices = region.getElementsByTagName("Vertex")
        # One (x, y) row per vertex. The reshape keeps the array 2d also for a region without vertices.
        coordinates = np.array(
            [(float(vertex.getAttribute("X")), float(vertex.getAttribute("Y"))) for vertex in vertices],
            dtype=np.float64,
        ).reshape(-1, 2)
        contours.append(coordinates.astype(np.int32))

    mask = np.zeros(shape, np.float32)

    # Start the ids at 1: with id 0 the first region would be indistinguishable from the background.
    for label_id, contour in enumerate(contours, start=1):
        # 'polygon' expects row (y) coordinates first, then column (x) coordinates.
        r, c = polygon(contour[:, 1], contour[:, 0], shape=shape)
        mask[r, c] = label_id
    return mask
517# This function could be extended to convert WSIs (or modalities with multiple resolutions).
def convert_svs_to_array(
    path: str, location: Tuple[int, int] = (0, 0), level: int = 0, img_size: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
    """Convert a .svs file for WSI imaging to a numpy array.

    Requires the tiffslide python library.
    The function can load multi-resolution images. You can specify the resolution level via `level`.

    Args:
        path: File path to the svs file.
        location: Pixel location (x, y) in level 0 of the image.
        level: Target level used to read the image.
        img_size: Size of the image. If None, the shape of the image at `level` is used.

    Returns:
        The image as numpy array.
    """
    from tiffslide import TiffSlide

    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"
    _slide = TiffSlide(path)
    if img_size is None:
        # Use the dimensions of the requested level: the size of the region refers to the level that is read.
        img_size = _slide.level_dimensions[level]
    return _slide.read_region(location=location, level=level, size=img_size, as_array=True)
def download_from_cryo_et_portal(path: str, dataset_id: int, download: bool) -> str:
    """Fetch a dataset from the CryoET Data Portal.

    Requires the cryoet-data-portal python library.
    The data is looked up in the sub-folder of `path` that is named after the dataset id.

    Args:
        path: The path for saving the data.
        dataset_id: The id of the data to download from the portal.
        download: Whether to download the data if it is not saved at `path` yet.

    Returns:
        The file path to the downloaded data.
    """
    portal_is_missing = Client is None or Dataset is None
    if portal_is_missing:
        raise RuntimeError("Please install CryoETDataPortal via 'pip install cryoet-data-portal'")

    output_path = os.path.join(path, str(dataset_id))
    if not os.path.exists(output_path):
        if not download:
            raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

        portal_client = Client()
        portal_dataset = Dataset.get_by_id(portal_client, dataset_id)
        portal_dataset.download_everything(dest_path=path)

    return output_path
