torch_em.data.datasets.util

  1import os
  2import hashlib
  3import inspect
  4import zipfile
  5import requests
  6from tqdm import tqdm
  7from warnings import warn
  8from subprocess import run
  9from xml.dom import minidom
 10from packaging import version
 11from shutil import copyfileobj, which
 12
 13from typing import Optional, Tuple, Literal
 14
 15import numpy as np
 16from skimage.draw import polygon
 17
 18import torch
 19
 20import torch_em
 21from torch_em.transform import get_raw_transform
 22from torch_em.transform.generic import ResizeLongestSideInputs, Compose
 23
 24try:
 25    import gdown
 26except ImportError:
 27    gdown = None
 28
 29try:
 30    from tcia_utils import nbia
 31except ModuleNotFoundError:
 32    nbia = None
 33
 34try:
 35    from cryoet_data_portal import Client, Dataset
 36except ImportError:
 37    Client, Dataset = None, None
 38
 39try:
 40    import synapseclient
 41    import synapseutils
 42except ImportError:
 43    synapseclient, synapseutils = None, None
 44
 45
 46BIOIMAGEIO_IDS = {
 47    "covid_if": "ilastik/covid_if_training_data",
 48    "cremi": "ilastik/cremi_training_data",
 49    "dsb": "ilastik/stardist_dsb_training_data",
 50    "hpa": "",  # not on bioimageio yet
 51    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
 52    "kasthuri": "",  # not on bioimageio yet:
 53    "livecell": "ilastik/livecell_dataset",
 54    "lucchi": "",  # not on bioimageio yet:
 55    "mitoem": "ilastik/mitoem_segmentation_challenge",
 56    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
 57    "ovules": "",  # not on bioimageio yet
 58    "plantseg_root": "ilastik/plantseg_root",
 59    "plantseg_ovules": "ilastik/plantseg_ovules",
 60    "platynereis": "ilastik/platynereis_em_training_data",
 61    "snemi": "",  # not on bioimagegio yet
 62    "uro_cell": "",  # not on bioimageio yet: https://doi.org/10.1016/j.compbiomed.2020.103693
 63    "vnc": "ilastik/vnc",
 64}
 65"""@private
 66"""
 67
 68
 69def get_bioimageio_dataset_id(dataset_name):
 70    """@private
 71    """
 72    assert dataset_name in BIOIMAGEIO_IDS
 73    return BIOIMAGEIO_IDS[dataset_name]
 74
 75
 76def get_checksum(filename: str) -> str:
 77    """Get the SHA256 checksum of a file.
 78
 79    Args:
 80        filename: The filepath.
 81
 82    Returns:
 83        The checksum.
 84    """
 85    with open(filename, "rb") as f:
 86        file_ = f.read()
 87        checksum = hashlib.sha256(file_).hexdigest()
 88    return checksum
 89
 90
 91def _check_checksum(path, checksum):
 92    if checksum is not None:
 93        this_checksum = get_checksum(path)
 94        if this_checksum != checksum:
 95            raise RuntimeError(
 96                "The checksum of the download does not match the expected checksum. "
 97                f"Expected: {checksum}, got: {this_checksum}"
 98            )
 99        print("Download successful and checksums agree.")
100    else:
101        warn("The file was downloaded, but no checksum was provided, so the file may be corrupted.")
102
103
104# this needs to be extended to support download from s3 via boto,
105# if we get a resource that is available via s3 without support for http
106def download_source(path: str, url: str, download: bool, checksum: Optional[str] = None, verify: bool = True) -> None:
107    """Download data via https.
108
109    Args:
110        path: The path for saving the data.
111        url: The url of the data.
112        download: Whether to download the data if it is not saved at `path` yet.
113        checksum: The expected checksum of the data.
114        verify: Whether to verify the https address.
115    """
116    if os.path.exists(path):
117        return
118    if not download:
119        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")
120
121    with requests.get(url, stream=True, allow_redirects=True, verify=verify) as r:
122        r.raise_for_status()  # check for error
123        file_size = int(r.headers.get("Content-Length", 0))
124        desc = f"Download {url} to {path}"
125        if file_size == 0:
126            desc += " (unknown file size)"
127        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
128            copyfileobj(r_raw, f)
129
130    _check_checksum(path, checksum)
131
132
133def download_source_gdrive(
134    path: str,
135    url: str,
136    download: bool,
137    checksum: Optional[str] = None,
138    download_type: Literal["zip", "folder"] = "zip",
139    expected_samples: int = 10000,
140    quiet: bool = True,
141) -> None:
142    """Download data from google drive.
143
144    Args:
145        path: The path for saving the data.
146        url: The url of the data.
147        download: Whether to download the data if it is not saved at `path` yet.
148        checksum: The expected checksum of the data.
149        download_type: The download type, either 'zip' or 'folder'.
150        expected_samples: The maximal number of files to download from the folder.
151        quiet: Whether to download quietly.
152    """
153    if os.path.exists(path):
154        return
155
156    if not download:
157        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")
158
159    if gdown is None:
160        raise RuntimeError(
161            "Need gdown library to download data from google drive. "
162            "Please install gdown: 'conda install -c conda-forge gdown==4.6.3'."
163        )
164
165    print("Downloading the files. Might take a few minutes...")
166
167    if download_type == "zip":
168        gdown.download(url, path, quiet=quiet)
169        _check_checksum(path, checksum)
170    elif download_type == "folder":
171        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install 'gdown==4.6.3'."
172        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
173        gdown.download_folder(url=url, output=path, quiet=quiet, remaining_ok=True)
174    else:
175        raise ValueError("The 'download_type' argument expects either 'zip' or 'folder'.")
176
177    print("Download completed.")
178
179
180def download_source_empiar(path: str, access_id: str, download: bool) -> str:
181    """Download data from EMPIAR.
182
183    Requires the ascp command from the aspera CLI.
184
185    Args:
186        path: The path for saving the data.
187        access_id: The EMPIAR accession id of the data to download.
188        download: Whether to download the data if it is not saved at `path` yet.
189
190    Returns:
191        The path to the downloaded data.
192    """
193    download_path = os.path.join(path, access_id)
194
195    if os.path.exists(download_path):
196        return download_path
197    if not download:
198        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")
199
200    if which("ascp") is None:
201        raise RuntimeError(
202            "Need aspera-cli to download data from empiar. You can install it via 'conda install -c hcc aspera-cli'."
203        )
204
205    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
206    if not os.path.exists(key_file):
207        conda_root = os.environ["CONDA_PREFIX"]
208        key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh")
209
210    if not os.path.exists(key_file):
211        raise RuntimeError("Could not find the aspera ssh keyfile")
212
213    cmd = ["ascp", "-QT", "-l", "200M", "-P33001", "-i", key_file, f"emp_ext2@fasp.ebi.ac.uk:/{access_id}", path]
214    run(cmd)
215
216    return download_path
217
218
219def download_source_kaggle(path: str, dataset_name: str, download: bool, competition: bool = False):
220    """Download data from Kaggle.
221
222    Requires the Kaggle API.
223
224    Args:
225        path: The path for saving the data.
226        dataset_name: The name of the dataset to download.
227        download: Whether to download the data if it is not saved at `path` yet.
228        competition: Whether this data is from a competition and requires the kaggle.competition API.
229    """
230    if not download:
231        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
232
233    try:
234        from kaggle.api.kaggle_api_extended import KaggleApi
235    except ModuleNotFoundError:
236        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
237        msg += "After you have installed kaggle, you would need an API token. "
238        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
239        raise ModuleNotFoundError(msg)
240
241    api = KaggleApi()
242    api.authenticate()
243
244    if competition:
245        api.competition_download_files(competition=dataset_name, path=path, quiet=False)
246    else:
247        api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)
248
249
250def download_source_tcia(path, url, dst, csv_filename, download):
251    """Download data from TCIA.
252
253    Requires the tcia_utils python package.
254
255    Args:
256        path: The path for saving the data.
257        url: The URL of the TCIA manifest file (must end with '.tcia').
258        dst: The path where the downloaded image series will be saved.
259        csv_filename: The name of the csv file for storing the metadata of the downloaded series.
260        download: Whether to download the data if it is not saved at `path` yet.
261    """
262    if nbia is None:
263        raise RuntimeError("Requires the tcia_utils python package.")
264    if not download:
265        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
266    assert url.endswith(".tcia"), f"{url} is not a TCIA Manifest."
267
268    # Downloads the manifest file from the collection page.
269    manifest = requests.get(url=url)
270    with open(path, "wb") as f:
271        f.write(manifest.content)
272
273    # This part extracts the UIDs from the manifests and downloads them.
274    nbia.downloadSeries(series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename)
275
276
277def download_source_synapse(path: str, entity: str, download: bool) -> None:
278    """Download data from synapse.
279
280    Requires the synapseclient python library.
281
282    Args:
283        path: The path for saving the data.
284        entity: The name of the data to download from synapse.
285        download: Whether to download the data if it is not saved at `path` yet.
286    """
287    if not download:
288        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
289
290    if synapseclient is None:
291        raise RuntimeError(
292            "You must install 'synapseclient' to download files from 'synapse'. "
293            "Remember to create an account and generate an authentication code for your account. "
294            "Please follow the documentation for details on creating the '~/.synapseConfig' file here: "
295            "https://python-docs.synapse.org/tutorials/authentication/."
296        )
297
298    assert entity.startswith("syn"), "The entity name does not look as expected. It should be something like 'syn123'."
299
300    # Download all files in the folder.
301    syn = synapseclient.Synapse()
302    syn.login()  # Since we do not pass any credentials here, it fetches all details from '~/.synapseConfig'.
303    synapseutils.syncFromSynapse(syn=syn, entity=entity, path=path)
304
305
306def update_kwargs(kwargs, key, value, msg=None):
307    """@private
308    """
309    if key in kwargs:
310        msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg
311        warn(msg)
312    kwargs[key] = value
313    return kwargs
314
315
316def unzip_tarfile(tar_path: str, dst: str, remove: bool = True) -> None:
317    """Unpack a tar archive.
318
319    Args:
320        tar_path: Path to the tar file.
321        dst: Where to unpack the archive.
322        remove: Whether to remove the tar file after unpacking.
323    """
324    import tarfile
325
326    if tar_path.endswith(".tar.gz") or tar_path.endswith(".tgz"):
327        access_mode = "r:gz"
328    elif tar_path.endswith(".tar"):
329        access_mode = "r:"
330    else:
331        raise ValueError(f"The provided file isn't a supported archive to unpack. Please check the file: {tar_path}.")
332
333    tar = tarfile.open(tar_path, access_mode)
334    tar.extractall(dst)
335    tar.close()
336
337    if remove:
338        os.remove(tar_path)
339
340
341def unzip_rarfile(rar_path: str, dst: str, remove: bool = True, use_rarfile: bool = True) -> None:
342    """Unpack a rar archive.
343
344    Args:
345        rar_path: Path to the rar file.
346        dst: Where to unpack the archive.
347        remove: Whether to remove the rar file after unpacking.
348        use_rarfile: Whether to use the rarfile library or aspose.zip.
349    """
350    if use_rarfile:
351        import rarfile
352        with rarfile.RarFile(rar_path) as f:
353            f.extractall(path=dst)
354    else:
355        import aspose.zip as az
356        with az.rar.RarArchive(rar_path) as archive:
357            archive.extract_to_directory(dst)
358
359    if remove:
360        os.remove(rar_path)
361
362
363def unzip(zip_path: str, dst: str, remove: bool = True) -> None:
364    """Unpack a zip archive.
365
366    Args:
367        zip_path: Path to the zip file.
368        dst: Where to unpack the archive.
369        remove: Whether to remove the zip file after unpacking.
370    """
371    with zipfile.ZipFile(zip_path, "r") as f:
372        f.extractall(dst)
373    if remove:
374        os.remove(zip_path)
375
376
377def split_kwargs(function, **kwargs):
378    """@private
379    """
380    function_parameters = inspect.signature(function).parameters
381    parameter_names = list(function_parameters.keys())
382    other_kwargs = {k: v for k, v in kwargs.items() if k not in parameter_names}
383    kwargs = {k: v for k, v in kwargs.items() if k in parameter_names}
384    return kwargs, other_kwargs
385
386
387# this adds the default transforms for 'raw_transform' and 'transform'
388# in case these were not specified in the kwargs
389# this is NOT necessary if 'default_segmentation_dataset' is used, only if a dataset class
390# is used directly, e.g. in the LiveCell Loader
391def ensure_transforms(ndim, **kwargs):
392    """@private
393    """
394    if "raw_transform" not in kwargs:
395        kwargs = update_kwargs(kwargs, "raw_transform", torch_em.transform.get_raw_transform())
396    if "transform" not in kwargs:
397        kwargs = update_kwargs(kwargs, "transform", torch_em.transform.get_augmentations(ndim=ndim))
398    return kwargs
399
400
401def add_instance_label_transform(
402    kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True,
403):
404    """@private
405    """
406    if binary_is_exclusive:
407        assert sum((offsets is not None, boundaries, binary)) <= 1
408    else:
409        assert sum((offsets is not None, boundaries)) <= 1
410    if offsets is not None:
411        label_transform2 = torch_em.transform.label.AffinityTransform(offsets=offsets,
412                                                                      add_binary_target=add_binary_target,
413                                                                      add_mask=True)
414        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
415        kwargs = update_kwargs(kwargs, "label_transform2", label_transform2, msg=msg)
416        label_dtype = torch.float32
417    elif boundaries:
418        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=add_binary_target)
419        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
420        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
421        label_dtype = torch.float32
422    elif binary:
423        label_transform = torch_em.transform.label.labels_to_binary
424        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
425        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
426        label_dtype = torch.float32
427    return kwargs, label_dtype
428
429
430def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None, ensure_rgb=None):
431    """@private
432    """
433    # Check whether 'raw_transform' and 'label_transform' are passed in the kwargs.
434    # If so, automatically merge them with the resize transforms so that they are applied together.
435    if resize_inputs:
436        assert isinstance(resize_kwargs, dict)
437
438        target_shape = resize_kwargs.get("patch_shape")
439        if len(resize_kwargs["patch_shape"]) == 3:
440            # we only need the XY dimensions to reshape the inputs along them.
441            target_shape = target_shape[1:]
442            # we provide the Z dimension value to return the desired number of slices and not the whole volume
443            kwargs["z_ext"] = resize_kwargs["patch_shape"][0]
444
445        raw_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_rgb=resize_kwargs["is_rgb"])
446        label_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_label=True)
447
448        # The patch shape provided to the dataset. Here, "None" means that the entire volume will be loaded.
449        patch_shape = None
450
451        if ensure_rgb is None:
452            raw_trafos = []
453        else:
454            assert not isinstance(ensure_rgb, bool), "'ensure_rgb' is expected to be a function."
455            raw_trafos = [ensure_rgb]
456
457        if "raw_transform" in kwargs:
458            raw_trafos.extend([raw_trafo, kwargs["raw_transform"]])
459        else:
460            raw_trafos.extend([raw_trafo, get_raw_transform()])
461
462        kwargs["raw_transform"] = Compose(*raw_trafos, is_multi_tensor=False)
463
464        if "label_transform" in kwargs:
465            trafo = Compose(label_trafo, kwargs["label_transform"], is_multi_tensor=False)
466            kwargs["label_transform"] = trafo
467        else:
468            kwargs["label_transform"] = label_trafo
469
470    return kwargs, patch_shape
471
472
473def generate_labeled_array_from_xml(shape: Tuple[int, ...], xml_file: str) -> np.ndarray:
474    """Generate a label mask from a contour defined in a xml annotation file.
475
476    Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb
477
478    Args:
479        shape: The image shape.
480        xml_file: The path to the xml file with contour annotations.
481
482    Returns:
483        The label mask.
484    """
485    # DOM object created by the minidom parser
486    xDoc = minidom.parse(xml_file)
487
488    # List of all Region tags
489    regions = xDoc.getElementsByTagName('Region')
490
491    # List which will store the vertices for each region
492    xy = []
493    for region in regions:
494        # Loading all the vertices in the region
495        vertices = region.getElementsByTagName('Vertex')
496
497        # The vertices of a region will be stored in a array
498        vw = np.zeros((len(vertices), 2))
499
500        for index, vertex in enumerate(vertices):
501            # Storing the values of x and y coordinate after conversion
502            vw[index][0] = float(vertex.getAttribute('X'))
503            vw[index][1] = float(vertex.getAttribute('Y'))
504
505        # Append the vertices of a region
506        xy.append(np.int32(vw))
507
508    # Creating a completely black image
509    mask = np.zeros(shape, np.float32)
510
511    for i, contour in enumerate(xy):
512        r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
513        mask[r, c] = i
514    return mask
515
516
517# This function could be extended to convert WSIs (or modalities with multiple resolutions).
518def convert_svs_to_array(
519    path: str, location: Tuple[int, int] = (0, 0), level: int = 0, img_size: Tuple[int, int] = None,
520) -> np.ndarray:
521    """Convert a .svs file for WSI imagging to a numpy array.
522
523    Requires the tiffslide python library.
524    The function can load multi-resolution images. You can specify the resolution level via `level`.
525
526    Args:
527        path: The file path to the svs file.
528        location: Pixel location (x, y) in level 0 of the image.
529        level: Target level used to read the image.
530        img_size: Size of the image. If None, the shape of the image at `level` is used.
531
532    Returns:
533        The image as numpy array.
534    """
535    from tiffslide import TiffSlide
536
537    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"
538    _slide = TiffSlide(path)
539    if img_size is None:
540        img_size = _slide.level_dimensions[0]
541    return _slide.read_region(location=location, level=level, size=img_size, as_array=True)
542
543
544def download_from_cryo_et_portal(path: str, dataset_id: int, download: bool) -> str:
545    """Download data from the CryoET Data Portal.
546
547    Requires the cryoet-data-portal python library.
548
549    Args:
550        path: The path for saving the data.
551        dataset_id: The id of the data to download from the portal.
552        download: Whether to download the data if it is not saved at `path` yet.
553
554    Returns:
555        The file path to the downloaded data.
556    """
557    if Client is None or Dataset is None:
558        raise RuntimeError("Please install CryoETDataPortal via 'pip install cryoet-data-portal'")
559
560    output_path = os.path.join(path, str(dataset_id))
561    if os.path.exists(output_path):
562        return output_path
563
564    if not download:
565        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
566
567    client = Client()
568    dataset = Dataset.get_by_id(client, dataset_id)
569    dataset.download_everything(dest_path=path)
570
571    return output_path
def get_checksum(filename: str) -> str:

Get the SHA256 checksum of a file.

Arguments:
  • filename: The filepath.
Returns:
  • The checksum.
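
A minimal usage sketch; the file path below is a placeholder for a file you have already downloaded:

    from torch_em.data.datasets.util import get_checksum

    # Compute the SHA256 hex digest of a local file (placeholder path).
    sha = get_checksum("/tmp/data.zip")
    print(sha)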

def download_source( path: str, url: str, download: bool, checksum: Optional[str] = None, verify: bool = True) -> None:

Download data via https.

Arguments:
  • path: The path for saving the data.
  • url: The url of the data.
  • download: Whether to download the data if it is not saved at path yet.
  • checksum: The expected checksum of the data.
  • verify: Whether to verify the https address.
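
A short sketch of a typical call; the URL and output path are placeholders. Passing the expected SHA256 digest as `checksum` makes the download fail loudly if the file is corrupted, otherwise only a warning is emitted:

    from torch_em.data.datasets.util import download_source, get_checksum

    url = "https://example.com/data.zip"   # placeholder URL
    path = "/tmp/data.zip"                 # placeholder output path
    download_source(path, url, download=True, checksum=None)
    print(get_checksum(path))  # pin this value as `checksum` for reproducible downloads
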
def download_source_gdrive( path: str, url: str, download: bool, checksum: Optional[str] = None, download_type: Literal['zip', 'folder'] = 'zip', expected_samples: int = 10000, quiet: bool = True) -> None:

Download data from google drive.

Arguments:
  • path: The path for saving the data.
  • url: The url of the data.
  • download: Whether to download the data if it is not saved at path yet.
  • checksum: The expected checksum of the data.
  • download_type: The download type, either 'zip' or 'folder'.
  • expected_samples: The maximal number of files to download from the folder.
  • quiet: Whether to download quietly.
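
A sketch of both download types; the drive URLs, ids and output paths are placeholders:

    from torch_em.data.datasets.util import download_source_gdrive

    # Download a single zip file shared on google drive (placeholder id).
    download_source_gdrive("/tmp/data.zip", "https://drive.google.com/uc?id=FILE_ID", download=True)

    # Download a shared folder; this requires gdown==4.6.3.
    download_source_gdrive(
        "/tmp/data_folder", "https://drive.google.com/drive/folders/FOLDER_ID",
        download=True, download_type="folder", expected_samples=500,
    )
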
def download_source_empiar(path: str, access_id: str, download: bool) -> str:

Download data from EMPIAR.

Requires the ascp command from the aspera CLI.

Arguments:
  • path: The path for saving the data.
  • access_id: The EMPIAR accession id of the data to download.
  • download: Whether to download the data if it is not saved at path yet.
Returns:
  • The path to the downloaded data.
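
A sketch assuming the aspera CLI (ascp) is installed; the accession id is a placeholder:

    from torch_em.data.datasets.util import download_source_empiar

    # The data ends up in <path>/<access_id>.
    data_path = download_source_empiar("/tmp/empiar", access_id="10XXX", download=True)
    print(data_path)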

def download_source_kaggle( path: str, dataset_name: str, download: bool, competition: bool = False):

Download data from Kaggle.

Requires the Kaggle API.

Arguments:
  • path: The path for saving the data.
  • dataset_name: The name of the dataset to download.
  • download: Whether to download the data if it is not saved at path yet.
  • competition: Whether this data is from a competition and requires the kaggle.competition API.
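
A sketch assuming a configured Kaggle API token; the dataset name is a placeholder and the downloaded zip is unpacked with `unzip`:

    import os
    from torch_em.data.datasets.util import download_source_kaggle, unzip

    download_source_kaggle("/tmp/kaggle_data", dataset_name="some-user/some-dataset", download=True)
    # Kaggle stores the download as a zip file named after the dataset (placeholder name).
    unzip(os.path.join("/tmp/kaggle_data", "some-dataset.zip"), "/tmp/kaggle_data", remove=True)
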
def download_source_tcia(path, url, dst, csv_filename, download):

Download data from TCIA.

Requires the tcia_utils python package.

Arguments:
  • path: The path for saving the data.
  • url: The URL of the TCIA manifest file (must end with '.tcia').
  • dst: The path where the downloaded image series will be saved.
  • csv_filename: The name of the csv file for storing the metadata of the downloaded series.
  • download: Whether to download the data if it is not saved at path yet.
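
A sketch with placeholder paths and a placeholder manifest URL; `path` is where the .tcia manifest is written and `dst` is where the image series are downloaded:

    from torch_em.data.datasets.util import download_source_tcia

    download_source_tcia(
        path="/tmp/tcia/manifest.tcia",             # placeholder manifest path
        url="https://example.org/collection.tcia",  # placeholder manifest URL
        dst="/tmp/tcia/images",
        csv_filename="/tmp/tcia/series.csv",
        download=True,
    )
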
def download_source_synapse(path: str, entity: str, download: bool) -> None:

Download data from synapse.

Requires the synapseclient python library.

Arguments:
  • path: The path for saving the data.
  • entity: The name of the data to download from synapse.
  • download: Whether to download the data if it is not saved at path yet.
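
A sketch assuming a valid '~/.synapseConfig'; 'syn123' is the placeholder entity format from the docstring:

    from torch_em.data.datasets.util import download_source_synapse

    download_source_synapse("/tmp/synapse_data", entity="syn123", download=True)
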
def unzip_tarfile(tar_path: str, dst: str, remove: bool = True) -> None:

Unpack a tar archive.

Arguments:
  • tar_path: Path to the tar file.
  • dst: Where to unpack the archive.
  • remove: Whether to remove the tar file after unpacking.
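
A minimal sketch with a placeholder archive path:

    from torch_em.data.datasets.util import unzip_tarfile

    # Unpack the archive and keep the tar file around.
    unzip_tarfile("/tmp/data.tar.gz", "/tmp/data", remove=False)
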
def unzip_rarfile( rar_path: str, dst: str, remove: bool = True, use_rarfile: bool = True) -> None:

Unpack a rar archive.

Arguments:
  • rar_path: Path to the rar file.
  • dst: Where to unpack the archive.
  • remove: Whether to remove the rar file after unpacking.
  • use_rarfile: Whether to use the rarfile library or aspose.zip.
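
A minimal sketch with a placeholder archive path; set use_rarfile=False to use aspose.zip instead:

    from torch_em.data.datasets.util import unzip_rarfile

    unzip_rarfile("/tmp/data.rar", "/tmp/data", remove=False, use_rarfile=True)
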
def unzip(zip_path: str, dst: str, remove: bool = True) -> None:

Unpack a zip archive.

Arguments:
  • zip_path: Path to the zip file.
  • dst: Where to unpack the archive.
  • remove: Whether to remove the zip file after unpacking.
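
A minimal sketch with a placeholder zip path:

    from torch_em.data.datasets.util import unzip

    unzip("/tmp/data.zip", "/tmp/data", remove=False)
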
def generate_labeled_array_from_xml(shape: Tuple[int, ...], xml_file: str) -> numpy.ndarray:

Generate a label mask from contours defined in an xml annotation file.

Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

Arguments:
  • shape: The image shape.
  • xml_file: The path to the xml file with contour annotations.
Returns:
  • The label mask.
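
A sketch with a placeholder annotation file; the shape must match the annotated image:

    from torch_em.data.datasets.util import generate_labeled_array_from_xml

    # Rasterize the polygon annotations into a label mask (placeholder path and shape).
    mask = generate_labeled_array_from_xml(shape=(1000, 1000), xml_file="/tmp/annotations.xml")
    print(mask.shape)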

def convert_svs_to_array( path: str, location: Tuple[int, int] = (0, 0), level: int = 0, img_size: Tuple[int, int] = None) -> numpy.ndarray:

Convert a .svs file for WSI imaging to a numpy array.

Requires the tiffslide python library. The function can load multi-resolution images. You can specify the resolution level via level.

Arguments:
  • path: The file path to the svs file.
  • location: Pixel location (x, y) in level 0 of the image.
  • level: Target level used to read the image.
  • img_size: Size of the image. If None, the shape of the image at level is used.
Returns:
  • The image as numpy array.
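
A sketch with a placeholder slide path, reading the full slide at the highest resolution level:

    from torch_em.data.datasets.util import convert_svs_to_array

    image = convert_svs_to_array("/tmp/slide.svs", location=(0, 0), level=0)
    print(image.shape)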

def download_from_cryo_et_portal(path: str, dataset_id: int, download: bool) -> str:

Download data from the CryoET Data Portal.

Requires the cryoet-data-portal python library.

Arguments:
  • path: The path for saving the data.
  • dataset_id: The id of the data to download from the portal.
  • download: Whether to download the data if it is not saved at path yet.
Returns:
  • The file path to the downloaded data.
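
A sketch with a placeholder dataset id; the dataset is downloaded to <path>/<dataset_id>:

    from torch_em.data.datasets.util import download_from_cryo_et_portal

    data_path = download_from_cryo_et_portal("/tmp/cryoet", dataset_id=10000, download=True)
    print(data_path)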