torch_em.data.datasets.util

  1import os
  2import hashlib
  3import inspect
  4import zipfile
  5import requests
  6from tqdm import tqdm
  7from warnings import warn
  8from subprocess import run
  9from xml.dom import minidom
 10from packaging import version
 11from shutil import copyfileobj, which
 12
 13from typing import Optional, Tuple, Literal
 14
 15import numpy as np
 16from skimage.draw import polygon
 17
 18import torch
 19
 20import torch_em
 21from torch_em.transform import get_raw_transform
 22from torch_em.transform.generic import ResizeLongestSideInputs, Compose
 23
 24try:
 25    import gdown
 26except ImportError:
 27    gdown = None
 28
 29try:
 30    from tcia_utils import nbia
 31except ModuleNotFoundError:
 32    nbia = None
 33
 34try:
 35    from cryoet_data_portal import Client, Dataset
 36except ImportError:
 37    Client, Dataset = None, None
 38
 39try:
 40    import synapseclient
 41    import synapseutils
 42except ImportError:
 43    synapseclient, synapseutils = None, None
 44
 45
 46BIOIMAGEIO_IDS = {
 47    "covid_if": "ilastik/covid_if_training_data",
 48    "cremi": "ilastik/cremi_training_data",
 49    "dsb": "ilastik/stardist_dsb_training_data",
 50    "hpa": "",  # not on bioimageio yet
 51    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
 52    "kasthuri": "",  # not on bioimageio yet:
 53    "livecell": "ilastik/livecell_dataset",
 54    "lucchi": "",  # not on bioimageio yet:
 55    "mitoem": "ilastik/mitoem_segmentation_challenge",
 56    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
 57    "ovules": "",  # not on bioimageio yet
 58    "plantseg_root": "ilastik/plantseg_root",
 59    "plantseg_ovules": "ilastik/plantseg_ovules",
 60    "platynereis": "ilastik/platynereis_em_training_data",
 61    "snemi": "",  # not on bioimagegio yet
 62    "uro_cell": "",  # not on bioimageio yet: https://doi.org/10.1016/j.compbiomed.2020.103693
 63    "vnc": "ilastik/vnc",
 64}
 65"""@private
 66"""
 67
 68
 69def get_bioimageio_dataset_id(dataset_name):
 70    """@private
 71    """
 72    assert dataset_name in BIOIMAGEIO_IDS
 73    return BIOIMAGEIO_IDS[dataset_name]
 74
 75
 76def get_checksum(filename: str) -> str:
 77    """Get the SHA256 checksum of a file.
 78
 79    Args:
 80        filename: The filepath.
 81
 82    Returns:
 83        The checksum.
 84    """
 85    with open(filename, "rb") as f:
 86        file_ = f.read()
 87        checksum = hashlib.sha256(file_).hexdigest()
 88    return checksum
 89
 90
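# Usage sketch (editor's addition, not part of the module): verify a downloaded file
# against a known SHA256 checksum. The file path and checksum value below are
# hypothetical placeholders.

from torch_em.data.datasets.util import get_checksum

expected_checksum = "0000000000000000000000000000000000000000000000000000000000000000"
actual_checksum = get_checksum("/path/to/downloaded_data.zip")
if actual_checksum != expected_checksum:
    raise RuntimeError(f"Checksum mismatch: expected {expected_checksum}, got {actual_checksum}.")
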
 91def _check_checksum(path, checksum):
 92    if checksum is not None:
 93        this_checksum = get_checksum(path)
 94        if this_checksum != checksum:
 95            raise RuntimeError(
 96                "The checksum of the download does not match the expected checksum. "
 97                f"Expected: {checksum}, got: {this_checksum}"
 98            )
 99        print("Download successful and checksums agree.")
100    else:
101        warn("The file was downloaded, but no checksum was provided, so the file may be corrupted.")
102
103
104# this needs to be extended to support download from s3 via boto,
105# if we get a resource that is available via s3 without support for http
106def download_source(path: str, url: str, download: bool, checksum: Optional[str] = None, verify: bool = True) -> None:
107    """Download data via https.
108
109    Args:
110        path: The path for saving the data.
111        url: The url of the data.
112        download: Whether to download the data if it is not saved at `path` yet.
113        checksum: The expected checksum of the data.
114        verify: Whether to verify the https address.
115    """
116    if os.path.exists(path):
117        return
118    if not download:
119        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")
120
121    with requests.get(url, stream=True, allow_redirects=True, verify=verify) as r:
122        r.raise_for_status()  # check for error
123        file_size = int(r.headers.get("Content-Length", 0))
124        desc = f"Download {url} to {path}"
125        if file_size == 0:
126            desc += " (unknown file size)"
127        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
128            copyfileobj(r_raw, f)
129
130    _check_checksum(path, checksum)
131
132
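# Usage sketch (editor's addition, not part of the module): how a dataset function
# typically uses download_source. The URL and target path are hypothetical placeholders;
# pass the known SHA256 as 'checksum' to enable verification of the download.

import os
from torch_em.data.datasets.util import download_source, get_checksum

url = "https://example.org/data/training_data.zip"       # hypothetical URL
zip_path = os.path.join("./data", "training_data.zip")   # hypothetical target path

os.makedirs("./data", exist_ok=True)
download_source(zip_path, url, download=True, checksum=None)
print("SHA256 of the download:", get_checksum(zip_path))
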
133def download_source_gdrive(
134    path: str,
135    url: str,
136    download: bool,
137    checksum: Optional[str] = None,
138    download_type: Literal["zip", "folder"] = "zip",
139    expected_samples: int = 10000,
140    quiet: bool = True,
141) -> None:
142    """Download data from google drive.
143
144    Args:
145        path: The path for saving the data.
146        url: The url of the data.
147        download: Whether to download the data if it is not saved at `path` yet.
148        checksum: The expected checksum of the data.
149        download_type: The download type, either 'zip' or 'folder'.
 150        expected_samples: The maximal number of files expected in the folder; used as gdown's file limit for folder downloads.
151        quiet: Whether to download quietly.
152    """
153    if os.path.exists(path):
154        return
155
156    if not download:
157        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")
158
159    if gdown is None:
160        raise RuntimeError(
161            "Need gdown library to download data from google drive. "
162            "Please install gdown: 'conda install -c conda-forge gdown==4.6.3'."
163        )
164
165    print("Downloading the files. Might take a few minutes...")
166
167    if download_type == "zip":
168        gdown.download(url, path, quiet=quiet)
169        _check_checksum(path, checksum)
170    elif download_type == "folder":
171        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install 'gdown==4.6.3'."
172        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
173        gdown.download_folder(url=url, output=path, quiet=quiet, remaining_ok=True)
174    else:
 175        raise ValueError("The 'download_type' argument expects either 'zip' or 'folder'.")
176
177    print("Download completed.")
178
179
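# Usage sketch (editor's addition, not part of the module): downloading a single zip
# file and a shared folder from google drive. Both links are hypothetical placeholders;
# folder downloads require gdown==4.6.3, as asserted above.

from torch_em.data.datasets.util import download_source_gdrive

# A single file shared via a drive link.
download_source_gdrive(
    path="./data/archive.zip", url="https://drive.google.com/uc?id=FILE_ID", download=True,
)

# A shared folder that is expected to contain at most 500 files.
download_source_gdrive(
    path="./data/drive_folder", url="https://drive.google.com/drive/folders/FOLDER_ID", download=True,
    download_type="folder", expected_samples=500,
)
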
180def download_source_empiar(path: str, access_id: str, download: bool) -> str:
181    """Download data from EMPIAR.
182
183    Requires the ascp command from the aspera CLI.
184
185    Args:
186        path: The path for saving the data.
187        access_id: The EMPIAR accession id of the data to download.
188        download: Whether to download the data if it is not saved at `path` yet.
189
190    Returns:
191        The path to the downloaded data.
192    """
193    download_path = os.path.join(path, access_id)
194
195    if os.path.exists(download_path):
196        return download_path
197    if not download:
198        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")
199
200    if which("ascp") is None:
201        raise RuntimeError(
202            "Need aspera-cli to download data from empiar. You can install it via 'conda install -c hcc aspera-cli'."
203        )
204
205    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
206    if not os.path.exists(key_file):
207        conda_root = os.environ["CONDA_PREFIX"]
208        key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh")
209
210    if not os.path.exists(key_file):
211        raise RuntimeError("Could not find the aspera ssh keyfile")
212
213    cmd = ["ascp", "-QT", "-l", "200M", "-P33001", "-i", key_file, f"emp_ext2@fasp.ebi.ac.uk:/{access_id}", path]
214    run(cmd)
215
216    return download_path
217
218
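# Usage sketch (editor's addition, not part of the module): downloading an EMPIAR entry
# with the aspera CLI. The accession id is a hypothetical placeholder; the data is
# written to '<path>/<access_id>' and that directory is returned.

from torch_em.data.datasets.util import download_source_empiar

download_path = download_source_empiar("./data", access_id="10000", download=True)
print("EMPIAR data downloaded to:", download_path)
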
219def download_source_kaggle(path: str, dataset_name: str, download: bool, competition: bool = False):
220    """Download data from Kaggle.
221
222    Requires the Kaggle API.
223
224    Args:
225        path: The path for saving the data.
226        dataset_name: The name of the dataset to download.
227        download: Whether to download the data if it is not saved at `path` yet.
228        competition: Whether this data is from a competition and requires the kaggle.competition API.
229    """
230    if not download:
231        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
232
233    try:
234        from kaggle.api.kaggle_api_extended import KaggleApi
235    except ModuleNotFoundError:
236        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
237        msg += "After you have installed kaggle, you would need an API token. "
238        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
239        raise ModuleNotFoundError(msg)
240
241    api = KaggleApi()
242    api.authenticate()
243
244    if competition:
245        api.competition_download_files(competition=dataset_name, path=path, quiet=False)
246    else:
247        api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)
248
249
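# Usage sketch (editor's addition, not part of the module): downloading a Kaggle dataset
# and unpacking the zip archive that the Kaggle API writes into 'path'. The dataset name
# and the resulting zip filename are hypothetical placeholders; a configured
# '~/.kaggle/kaggle.json' API token is required.

import os
from torch_em.data.datasets.util import download_source_kaggle, unzip

download_source_kaggle(path="./data", dataset_name="some-user/some-dataset", download=True)
unzip(os.path.join("./data", "some-dataset.zip"), "./data", remove=True)
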
250def download_source_tcia(path, url, dst, csv_filename, download):
251    """Download data from TCIA.
252
253    Requires the tcia_utils python package.
254
255    Args:
 256        path: The path for saving the TCIA manifest file.
 257        url: The URL of the TCIA manifest to download.
 258        dst: The directory where the DICOM series from the manifest are saved.
 259        csv_filename: The path for saving a csv file with metadata of the downloaded series.
260        download: Whether to download the data if it is not saved at `path` yet.
261    """
262    if nbia is None:
263        raise RuntimeError("Requires the tcia_utils python package.")
264    if not download:
265        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
266    assert url.endswith(".tcia"), f"{url} is not a TCIA Manifest."
267
268    # Downloads the manifest file from the collection page.
269    manifest = requests.get(url=url)
270    with open(path, "wb") as f:
271        f.write(manifest.content)
272
273    # This part extracts the UIDs from the manifests and downloads them.
274    nbia.downloadSeries(series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename)
275
276
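# Usage sketch (editor's addition, not part of the module): downloading a TCIA collection
# from its '.tcia' manifest. The manifest URL is a hypothetical placeholder; the DICOM
# series are written to 'dst' and their metadata to 'csv_filename'.

from torch_em.data.datasets.util import download_source_tcia

download_source_tcia(
    path="./data/collection.tcia",                                    # where the manifest is saved
    url="https://www.cancerimagingarchive.net/some-collection.tcia",  # hypothetical manifest URL
    dst="./data/dicom",
    csv_filename="./data/series_metadata.csv",
    download=True,
)
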
277def download_source_synapse(path: str, entity: str, download: bool) -> None:
278    """Download data from synapse.
279
280    Requires the synapseclient python library.
281
282    Args:
283        path: The path for saving the data.
284        entity: The name of the data to download from synapse.
285        download: Whether to download the data if it is not saved at `path` yet.
286    """
287    if not download:
288        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
289
290    if synapseclient is None:
291        raise RuntimeError(
292            "You must install 'synapseclient' to download files from 'synapse'. "
293            "Remember to create an account and generate an authentication code for your account. "
294            "Please follow the documentation for details on creating the '~/.synapseConfig' file here: "
295            "https://python-docs.synapse.org/tutorials/authentication/."
296        )
297
298    assert entity.startswith("syn"), "The entity name does not look as expected. It should be something like 'syn123'."
299
300    # Download all files in the folder.
301    syn = synapseclient.Synapse()
302    syn.login()  # Since we do not pass any credentials here, it fetches all details from '~/.synapseConfig'.
303    synapseutils.syncFromSynapse(syn=syn, entity=entity, path=path)
304
305
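# Usage sketch (editor's addition, not part of the module): downloading a synapse entity.
# 'syn123456' is a hypothetical placeholder and a valid '~/.synapseConfig' with an
# authentication token is required.

from torch_em.data.datasets.util import download_source_synapse

download_source_synapse(path="./data/synapse_data", entity="syn123456", download=True)
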
306def update_kwargs(kwargs, key, value, msg=None):
307    """@private
308    """
309    if key in kwargs:
310        msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg
311        warn(msg)
312    kwargs[key] = value
313    return kwargs
314
315
316def unzip_tarfile(tar_path: str, dst: str, remove: bool = True) -> None:
317    """Unpack a tar archive.
318
319    Args:
320        tar_path: Path to the tar file.
321        dst: Where to unpack the archive.
322        remove: Whether to remove the tar file after unpacking.
323    """
324    import tarfile
325
326    if tar_path.endswith(".tar.gz") or tar_path.endswith(".tgz"):
327        access_mode = "r:gz"
328    elif tar_path.endswith(".tar"):
329        access_mode = "r:"
330    else:
331        raise ValueError(f"The provided file isn't a supported archive to unpack. Please check the file: {tar_path}.")
332
333    tar = tarfile.open(tar_path, access_mode)
334    tar.extractall(dst)
335    tar.close()
336
337    if remove:
338        os.remove(tar_path)
339
340
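# Usage sketch (editor's addition, not part of the module): downloading a tarball and
# unpacking it next to the archive. The URL and paths are hypothetical placeholders;
# with remove=True the tar file is deleted after extraction.

import os
from torch_em.data.datasets.util import download_source, unzip_tarfile

os.makedirs("./data", exist_ok=True)
tar_path = "./data/images.tar.gz"
download_source(tar_path, "https://example.org/images.tar.gz", download=True, checksum=None)
unzip_tarfile(tar_path, "./data", remove=True)
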
341def unzip_rarfile(rar_path: str, dst: str, remove: bool = True, use_rarfile: bool = True) -> None:
342    """Unpack a rar archive.
343
344    Args:
345        rar_path: Path to the rar file.
346        dst: Where to unpack the archive.
 347        remove: Whether to remove the rar file after unpacking.
348        use_rarfile: Whether to use the rarfile library or aspose.zip.
349    """
350    def _extract_with_rarfile():
351        import rarfile
352        with rarfile.RarFile(rar_path) as archive:
353            archive.extractall(path=dst)
354
355    def _extract_with_aspose():
356        import aspose.zip as az
357        with az.rar.RarArchive(rar_path) as archive:
358            archive.extract_to_directory(dst)
359
360    extractors = [
361        ('rarfile', _extract_with_rarfile), ('aspose.zip', _extract_with_aspose),
362    ] if use_rarfile else [('aspose.zip', _extract_with_aspose)]
363
364    errors = []
365    for name, extractor in extractors:
366        try:
367            extractor()
368            break
369        except Exception as err:
370            errors.append((name, err))
371            if len(errors) < len(extractors):
372                next_name = extractors[len(errors)][0]
373                warn(f"Extraction with '{name}' failed for {rar_path} ({err}). Falling back to '{next_name}'.")
374    else:
375        backends = ', '.join(f"'{name}'" for name, _ in extractors)
376        raise RuntimeError(
377            f"Failed to extract rar archive {rar_path} with {backends}. "
378            "Please ensure one of the supported backends is installed and can read this archive."
379        ) from errors[-1][1]
380
381    if remove:
382        os.remove(rar_path)
383
384
385def unzip(zip_path: str, dst: str, remove: bool = True) -> None:
386    """Unpack a zip archive.
387
388    Args:
389        zip_path: Path to the zip file.
390        dst: Where to unpack the archive.
 391        remove: Whether to remove the zip file after unpacking.
392    """
393    with zipfile.ZipFile(zip_path, "r") as f:
394        f.extractall(dst)
395    if remove:
396        os.remove(zip_path)
397
398
399def split_kwargs(function, **kwargs):
400    """@private
401    """
402    function_parameters = inspect.signature(function).parameters
403    parameter_names = list(function_parameters.keys())
404    other_kwargs = {k: v for k, v in kwargs.items() if k not in parameter_names}
405    kwargs = {k: v for k, v in kwargs.items() if k in parameter_names}
406    return kwargs, other_kwargs
407
408
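# Usage sketch (editor's addition, not part of the module): split_kwargs separates the
# kwargs that match a function's signature from the rest. Dataset functions use this to
# route arguments either to the dataset class or to the DataLoader. Shown here with a
# toy function instead of a real dataset class.

from torch_em.data.datasets.util import split_kwargs

def toy_dataset(raw_paths, label_paths, patch_shape):
    return raw_paths, label_paths, patch_shape

dataset_kwargs, loader_kwargs = split_kwargs(
    toy_dataset, raw_paths="raw.h5", label_paths="labels.h5", patch_shape=(512, 512),
    batch_size=4, shuffle=True,
)
print(dataset_kwargs)  # {'raw_paths': 'raw.h5', 'label_paths': 'labels.h5', 'patch_shape': (512, 512)}
print(loader_kwargs)   # {'batch_size': 4, 'shuffle': True}
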
409# this adds the default transforms for 'raw_transform' and 'transform'
410# in case these were not specified in the kwargs
411# this is NOT necessary if 'default_segmentation_dataset' is used, only if a dataset class
412# is used directly, e.g. in the LiveCell Loader
413def ensure_transforms(ndim, **kwargs):
414    """@private
415    """
416    if "raw_transform" not in kwargs:
417        kwargs = update_kwargs(kwargs, "raw_transform", torch_em.transform.get_raw_transform())
418    if "transform" not in kwargs:
419        kwargs = update_kwargs(kwargs, "transform", torch_em.transform.get_augmentations(ndim=ndim))
420    return kwargs
421
422
423def add_instance_label_transform(
424    kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True,
425):
426    """@private
427    """
428    if binary_is_exclusive:
429        assert sum((offsets is not None, boundaries, binary)) <= 1
430    else:
431        assert sum((offsets is not None, boundaries)) <= 1
432    if offsets is not None:
433        label_transform2 = torch_em.transform.label.AffinityTransform(offsets=offsets,
434                                                                      add_binary_target=add_binary_target,
435                                                                      add_mask=True)
436        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
437        kwargs = update_kwargs(kwargs, "label_transform2", label_transform2, msg=msg)
438        label_dtype = torch.float32
439    elif boundaries:
440        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=add_binary_target)
441        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
442        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
443        label_dtype = torch.float32
444    elif binary:
445        label_transform = torch_em.transform.label.labels_to_binary
446        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
447        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
448        label_dtype = torch.float32
449    return kwargs, label_dtype
450
451
452def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None, ensure_rgb=None):
453    """@private
454    """
 455    # Check whether 'raw_transform' and / or 'label_transform' were passed in the kwargs.
 456    # If so, the resize transform is merged with them so that both are applied together.
457    if resize_inputs:
458        assert isinstance(resize_kwargs, dict)
459
460        target_shape = resize_kwargs.get("patch_shape")
461        if len(resize_kwargs["patch_shape"]) == 3:
462            # we only need the XY dimensions to reshape the inputs along them.
463            target_shape = target_shape[1:]
464            # we provide the Z dimension value to return the desired number of slices and not the whole volume
465            kwargs["z_ext"] = resize_kwargs["patch_shape"][0]
466
467        raw_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_rgb=resize_kwargs["is_rgb"])
468        label_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_label=True)
469
470        # The patch shape provided to the dataset. Here, "None" means that the entire volume will be loaded.
471        patch_shape = None
472
 473        if ensure_rgb is None:
 474            raw_trafos = []
 475        else:
 476            assert not isinstance(ensure_rgb, bool), "'ensure_rgb' is expected to be a function."
 477            raw_trafos = [ensure_rgb]
 478
 479        if "raw_transform" in kwargs:
 480            raw_trafos.extend([raw_trafo, kwargs["raw_transform"]])
 481        else:
 482            raw_trafos.extend([raw_trafo, get_raw_transform()])
 483
 484        kwargs["raw_transform"] = Compose(*raw_trafos, is_multi_tensor=False)
 485
 486        if "label_transform" in kwargs:
 487            trafo = Compose(label_trafo, kwargs["label_transform"], is_multi_tensor=False)
 488            kwargs["label_transform"] = trafo
 489        else:
 490            kwargs["label_transform"] = label_trafo
491
492    return kwargs, patch_shape
493
494
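# Usage sketch (editor's addition, not part of the module): hooking the resize transform
# into dataset kwargs. 'resize_kwargs' carries the target patch shape and whether the raw
# data is RGB; the returned patch shape is None, i.e. whole images are loaded and then
# resized along their longest side.

from torch_em.data.datasets.util import update_kwargs_for_resize_trafo

kwargs = {}  # dataset kwargs, may already contain 'raw_transform' / 'label_transform'
patch_shape = (512, 512)
kwargs, patch_shape = update_kwargs_for_resize_trafo(
    kwargs, patch_shape, resize_inputs=True, resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
)
print(patch_shape)            # None
print(sorted(kwargs.keys()))  # ['label_transform', 'raw_transform']
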
495def generate_labeled_array_from_xml(shape: Tuple[int, ...], xml_file: str) -> np.ndarray:
496    """Generate a label mask from a contour defined in a xml annotation file.
497
498    Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb
499
500    Args:
501        shape: The image shape.
502        xml_file: The path to the xml file with contour annotations.
503
504    Returns:
505        The label mask.
506    """
507    # DOM object created by the minidom parser
508    xDoc = minidom.parse(xml_file)
509
510    # List of all Region tags
511    regions = xDoc.getElementsByTagName('Region')
512
513    # List which will store the vertices for each region
514    xy = []
515    for region in regions:
516        # Loading all the vertices in the region
517        vertices = region.getElementsByTagName('Vertex')
518
 519        # The vertices of a region will be stored in an array
520        vw = np.zeros((len(vertices), 2))
521
522        for index, vertex in enumerate(vertices):
523            # Storing the values of x and y coordinate after conversion
524            vw[index][0] = float(vertex.getAttribute('X'))
525            vw[index][1] = float(vertex.getAttribute('Y'))
526
527        # Append the vertices of a region
528        xy.append(np.int32(vw))
529
530    # Creating a completely black image
531    mask = np.zeros(shape, np.float32)
532
533    for i, contour in enumerate(xy):
534        r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
535        mask[r, c] = i
536    return mask
537
538
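# Usage sketch (editor's addition, not part of the module): rasterizing a minimal
# annotation xml with a single triangular region. The xml layout mirrors the
# 'Region' / 'Vertex' tags that the parser above expects.

import numpy as np
from torch_em.data.datasets.util import generate_labeled_array_from_xml

annotation = """<?xml version="1.0"?>
<Annotations>
  <Region>
    <Vertex X="10" Y="10"/>
    <Vertex X="40" Y="10"/>
    <Vertex X="25" Y="40"/>
  </Region>
</Annotations>"""

with open("annotation.xml", "w") as f:
    f.write(annotation)

mask = generate_labeled_array_from_xml(shape=(64, 64), xml_file="annotation.xml")
print(mask.shape, np.unique(mask))
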
539# This function could be extended to convert WSIs (or modalities with multiple resolutions).
540def convert_svs_to_array(
541    path: str, location: Tuple[int, int] = (0, 0), level: int = 0, img_size: Tuple[int, int] = None,
542) -> np.ndarray:
543    """Convert a .svs file for WSI imagging to a numpy array.
544
545    Requires the tiffslide python library.
546    The function can load multi-resolution images. You can specify the resolution level via `level`.
547
548    Args:
 549        path: The file path to the svs file.
550        location: Pixel location (x, y) in level 0 of the image.
551        level: Target level used to read the image.
 552        img_size: Size of the image. If None, the full dimensions of level 0 are used.
553
554    Returns:
555        The image as numpy array.
556    """
557    from tiffslide import TiffSlide
558
559    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"
560    _slide = TiffSlide(path)
561    if img_size is None:
562        img_size = _slide.level_dimensions[0]
563    return _slide.read_region(location=location, level=level, size=img_size, as_array=True)
564
565
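# Usage sketch (editor's addition, not part of the module): reading a downsampled level
# of a whole-slide image. The svs path is a hypothetical placeholder; higher levels are
# typically strongly downsampled overviews.

from torch_em.data.datasets.util import convert_svs_to_array

overview = convert_svs_to_array("/path/to/slide.svs", location=(0, 0), level=2, img_size=(1024, 1024))
print(overview.shape)  # e.g. (1024, 1024, 3)
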
566def download_from_cryo_et_portal(path: str, dataset_id: int, download: bool) -> str:
567    """Download data from the CryoET Data Portal.
568
569    Requires the cryoet-data-portal python library.
570
571    Args:
572        path: The path for saving the data.
573        dataset_id: The id of the data to download from the portal.
574        download: Whether to download the data if it is not saved at `path` yet.
575
576    Returns:
577        The file path to the downloaded data.
578    """
579    if Client is None or Dataset is None:
580        raise RuntimeError("Please install CryoETDataPortal via 'pip install cryoet-data-portal'")
581
582    output_path = os.path.join(path, str(dataset_id))
583    if os.path.exists(output_path):
584        return output_path
585
586    if not download:
587        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
588
589    client = Client()
590    dataset = Dataset.get_by_id(client, dataset_id)
591    dataset.download_everything(dest_path=path)
592
593    return output_path
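# Usage sketch (editor's addition, not part of the module): downloading a full dataset
# from the CryoET Data Portal. The dataset id is a hypothetical placeholder; everything
# is written to '<path>/<dataset_id>' and that directory is returned.

from torch_em.data.datasets.util import download_from_cryo_et_portal

data_path = download_from_cryo_et_portal("./data", dataset_id=10000, download=True)
print("CryoET data downloaded to:", data_path)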