torch_em.data.datasets.util
import os
import hashlib
import inspect
import zipfile
import requests
from tqdm import tqdm
from warnings import warn
from subprocess import run
from xml.dom import minidom
from packaging import version
from shutil import copyfileobj, which

from typing import Optional, Tuple, Literal

import numpy as np
from skimage.draw import polygon

import torch

import torch_em
from torch_em.transform import get_raw_transform
from torch_em.transform.generic import ResizeLongestSideInputs, Compose

try:
    import gdown
except ImportError:
    gdown = None

try:
    from tcia_utils import nbia
except ModuleNotFoundError:
    nbia = None

try:
    from cryoet_data_portal import Client, Dataset
except ImportError:
    Client, Dataset = None, None

try:
    import synapseclient
    import synapseutils
except ImportError:
    synapseclient, synapseutils = None, None


BIOIMAGEIO_IDS = {
    "covid_if": "ilastik/covid_if_training_data",
    "cremi": "ilastik/cremi_training_data",
    "dsb": "ilastik/stardist_dsb_training_data",
    "hpa": "",  # not on bioimageio yet
    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
    "kasthuri": "",  # not on bioimageio yet:
    "livecell": "ilastik/livecell_dataset",
    "lucchi": "",  # not on bioimageio yet:
    "mitoem": "ilastik/mitoem_segmentation_challenge",
    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
    "ovules": "",  # not on bioimageio yet
    "plantseg_root": "ilastik/plantseg_root",
    "plantseg_ovules": "ilastik/plantseg_ovules",
    "platynereis": "ilastik/platynereis_em_training_data",
    "snemi": "",  # not on bioimageio yet
    "uro_cell": "",  # not on bioimageio yet: https://doi.org/10.1016/j.compbiomed.2020.103693
    "vnc": "ilastik/vnc",
}
"""@private
"""


def get_bioimageio_dataset_id(dataset_name):
    """@private
    """
    assert dataset_name in BIOIMAGEIO_IDS
    return BIOIMAGEIO_IDS[dataset_name]


def get_checksum(filename: str) -> str:
    """Get the SHA256 checksum of a file.

    Args:
        filename: The filepath.

    Returns:
        The checksum.
    """
    # Stream the file in chunks so that large downloads do not have to be
    # held in memory at once. The resulting digest is identical.
    sha = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            sha.update(chunk)
    return sha.hexdigest()


def _check_checksum(path, checksum):
    # Compare the downloaded file against the expected checksum (if any was given).
    if checksum is not None:
        this_checksum = get_checksum(path)
        if this_checksum != checksum:
            raise RuntimeError(
                "The checksum of the download does not match the expected checksum."
                f"Expected: {checksum}, got: {this_checksum}"
            )
        print("Download successful and checksums agree.")
    else:
        warn("The file was downloaded, but no checksum was provided, so the file may be corrupted.")


# this needs to be extended to support download from s3 via boto,
# if we get a resource that is available via s3 without support for http
def download_source(path: str, url: str, download: bool, checksum: Optional[str] = None, verify: bool = True) -> None:
    """Download data via https.

    Args:
        path: The path for saving the data.
        url: The url of the data.
        download: Whether to download the data if it is not saved at `path` yet.
        checksum: The expected checksum of the data.
        verify: Whether to verify the https address.
    """
    if os.path.exists(path):
        return
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    with requests.get(url, stream=True, allow_redirects=True, verify=verify) as r:
        r.raise_for_status()  # check for error
        file_size = int(r.headers.get("Content-Length", 0))
        desc = f"Download {url} to {path}"
        if file_size == 0:
            desc += " (unknown file size)"
        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
            copyfileobj(r_raw, f)

    _check_checksum(path, checksum)


def download_source_gdrive(
    path: str,
    url: str,
    download: bool,
    checksum: Optional[str] = None,
    download_type: Literal["zip", "folder"] = "zip",
    expected_samples: int = 10000,
    quiet: bool = True,
) -> None:
    """Download data from google drive.

    Args:
        path: The path for saving the data.
        url: The url of the data.
        download: Whether to download the data if it is not saved at `path` yet.
        checksum: The expected checksum of the data.
        download_type: The download type, either 'zip' or 'folder'.
        expected_samples: The maximal number of samples in the folder.
        quiet: Whether to download quietly.
    """
    if os.path.exists(path):
        return

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if gdown is None:
        raise RuntimeError(
            "Need gdown library to download data from google drive. "
            "Please install gdown: 'conda install -c conda-forge gdown==4.6.3'."
        )

    print("Downloading the files. Might take a few minutes...")

    if download_type == "zip":
        gdown.download(url, path, quiet=quiet)
        _check_checksum(path, checksum)
    elif download_type == "folder":
        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install 'gdown==4.6.3'."
        # Patch gdown's module-level limit so folders with many files can be fetched.
        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
        gdown.download_folder(url=url, output=path, quiet=quiet, remaining_ok=True)
    else:
        # NOTE: the previous message referred to a nonexistent 'download_path' argument.
        raise ValueError("`download_type` argument expects either `zip`/`folder`")

    print("Download completed.")


def download_source_empiar(path: str, access_id: str, download: bool) -> str:
    """Download data from EMPIAR.

    Requires the ascp command from the aspera CLI.

    Args:
        path: The path for saving the data.
        access_id: The EMPIAR accession id of the data to download.
        download: Whether to download the data if it is not saved at `path` yet.

    Returns:
        The path to the downloaded data.
    """
    download_path = os.path.join(path, access_id)

    if os.path.exists(download_path):
        return download_path
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if which("ascp") is None:
        raise RuntimeError(
            "Need aspera-cli to download data from empiar. You can install it via 'conda install -c hcc aspera-cli'."
        )

    # Look for the aspera ssh key first in the user install, then in the conda environment.
    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
    if not os.path.exists(key_file):
        conda_root = os.environ["CONDA_PREFIX"]
        key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh")

    if not os.path.exists(key_file):
        raise RuntimeError("Could not find the aspera ssh keyfile")

    cmd = ["ascp", "-QT", "-l", "200M", "-P33001", "-i", key_file, f"emp_ext2@fasp.ebi.ac.uk:/{access_id}", path]
    run(cmd)

    return download_path


def download_source_kaggle(path: str, dataset_name: str, download: bool, competition: bool = False):
    """Download data from Kaggle.

    Requires the Kaggle API.

    Args:
        path: The path for saving the data.
        dataset_name: The name of the dataset to download.
        download: Whether to download the data if it is not saved at `path` yet.
        competition: Whether this data is from a competition and requires the kaggle.competition API.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
    except ModuleNotFoundError:
        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
        msg += "After you have installed kaggle, you would need an API token. "
        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
        raise ModuleNotFoundError(msg)

    api = KaggleApi()
    api.authenticate()

    if competition:
        api.competition_download_files(competition=dataset_name, path=path, quiet=False)
    else:
        api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)


def download_source_tcia(path: str, url: str, dst: str, csv_filename: str, download: bool) -> None:
    """Download data from TCIA.

    Requires the tcia_utils python package.

    Args:
        path: The path for saving the TCIA manifest file.
        url: The URL to the TCIA dataset. Must point to a '.tcia' manifest.
        dst: The directory where the downloaded image series are saved.
        csv_filename: The csv file used for storing metadata about the downloaded series.
        download: Whether to download the data if it is not saved at `path` yet.
    """
    if nbia is None:
        raise RuntimeError("Requires the tcia_utils python package.")
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
    assert url.endswith(".tcia"), f"{url} is not a TCIA Manifest."

    # Downloads the manifest file from the collection page.
    manifest = requests.get(url=url)
    with open(path, "wb") as f:
        f.write(manifest.content)

    # This part extracts the UIDs from the manifests and downloads them.
    nbia.downloadSeries(series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename)


def download_source_synapse(path: str, entity: str, download: bool) -> None:
    """Download data from synapse.

    Requires the synapseclient python library.

    Args:
        path: The path for saving the data.
        entity: The name of the data to download from synapse.
        download: Whether to download the data if it is not saved at `path` yet.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    if synapseclient is None:
        raise RuntimeError(
            "You must install 'synapseclient' to download files from 'synapse'. "
            "Remember to create an account and generate an authentication code for your account. "
            "Please follow the documentation for details on creating the '~/.synapseConfig' file here: "
            "https://python-docs.synapse.org/tutorials/authentication/."
        )

    assert entity.startswith("syn"), "The entity name does not look as expected. It should be something like 'syn123'."

    # Download all files in the folder.
    syn = synapseclient.Synapse()
    syn.login()  # Since we do not pass any credentials here, it fetches all details from '~/.synapseConfig'.
    synapseutils.syncFromSynapse(syn=syn, entity=entity, path=path)


def update_kwargs(kwargs, key, value, msg=None):
    """@private
    """
    if key in kwargs:
        msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg
        warn(msg)
    kwargs[key] = value
    return kwargs


def unzip_tarfile(tar_path: str, dst: str, remove: bool = True) -> None:
    """Unpack a tar archive.

    Args:
        tar_path: Path to the tar file.
        dst: Where to unpack the archive.
        remove: Whether to remove the tar file after unpacking.
    """
    import tarfile

    if tar_path.endswith(".tar.gz") or tar_path.endswith(".tgz"):
        access_mode = "r:gz"
    elif tar_path.endswith(".tar"):
        access_mode = "r:"
    else:
        raise ValueError(f"The provided file isn't a supported archive to unpack. Please check the file: {tar_path}.")

    # Context manager so that the archive is closed even if extraction fails.
    with tarfile.open(tar_path, access_mode) as tar:
        tar.extractall(dst)

    if remove:
        os.remove(tar_path)


def unzip_rarfile(rar_path: str, dst: str, remove: bool = True, use_rarfile: bool = True) -> None:
    """Unpack a rar archive.

    The available extraction backends are tried in order; if one fails, the
    next is used as fallback, and an error is raised only after all failed.

    Args:
        rar_path: Path to the rar file.
        dst: Where to unpack the archive.
        remove: Whether to remove the rar file after unpacking.
        use_rarfile: Whether to use the rarfile library or aspose.zip.
    """
    def _extract_with_rarfile():
        import rarfile
        with rarfile.RarFile(rar_path) as archive:
            archive.extractall(path=dst)

    def _extract_with_aspose():
        import aspose.zip as az
        with az.rar.RarArchive(rar_path) as archive:
            archive.extract_to_directory(dst)

    extractors = [
        ('rarfile', _extract_with_rarfile), ('aspose.zip', _extract_with_aspose),
    ] if use_rarfile else [('aspose.zip', _extract_with_aspose)]

    errors = []
    for name, extractor in extractors:
        try:
            extractor()
            break  # Extraction succeeded, do not try further backends.
        except Exception as err:
            errors.append((name, err))
            if len(errors) < len(extractors):
                # At least one backend is left to try: warn and fall through.
                next_name = extractors[len(errors)][0]
                warn(f"Extraction with '{name}' failed for {rar_path} ({err}). Falling back to '{next_name}'.")
            else:
                # All backends failed; chain the last error for debugging.
                backends = ', '.join(f"'{name}'" for name, _ in extractors)
                raise RuntimeError(
                    f"Failed to extract rar archive {rar_path} with {backends}. "
                    "Please ensure one of the supported backends is installed and can read this archive."
                ) from errors[-1][1]

    if remove:
        os.remove(rar_path)


def unzip(zip_path: str, dst: str, remove: bool = True) -> None:
    """Unpack a zip archive.

    Args:
        zip_path: Path to the zip file.
        dst: Where to unpack the archive.
        remove: Whether to remove the zip file after unpacking.
    """
    with zipfile.ZipFile(zip_path, "r") as f:
        f.extractall(dst)
    if remove:
        os.remove(zip_path)


def split_kwargs(function, **kwargs):
    """@private
    """
    function_parameters = inspect.signature(function).parameters
    parameter_names = list(function_parameters.keys())
    other_kwargs = {k: v for k, v in kwargs.items() if k not in parameter_names}
    kwargs = {k: v for k, v in kwargs.items() if k in parameter_names}
    return kwargs, other_kwargs


# this adds the default transforms for 'raw_transform' and 'transform'
# in case these were not specified in the kwargs
# this is NOT necessary if 'default_segmentation_dataset' is used, only if a dataset class
# is used directly, e.g. in the LiveCell Loader
def ensure_transforms(ndim, **kwargs):
    """@private
    """
    if "raw_transform" not in kwargs:
        kwargs = update_kwargs(kwargs, "raw_transform", torch_em.transform.get_raw_transform())
    if "transform" not in kwargs:
        kwargs = update_kwargs(kwargs, "transform", torch_em.transform.get_augmentations(ndim=ndim))
    return kwargs


def add_instance_label_transform(
    kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True,
):
    """@private
    """
    if binary_is_exclusive:
        assert sum((offsets is not None, boundaries, binary)) <= 1
    else:
        assert sum((offsets is not None, boundaries)) <= 1
    if offsets is not None:
        label_transform2 = torch_em.transform.label.AffinityTransform(
            offsets=offsets, add_binary_target=add_binary_target, add_mask=True
        )
        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform2", label_transform2, msg=msg)
        label_dtype = torch.float32
    elif boundaries:
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=add_binary_target)
        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    elif binary:
        label_transform = torch_em.transform.label.labels_to_binary
        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    return kwargs, label_dtype


def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None, ensure_rgb=None):
    """@private
    """
    # Checks for raw_transform and label_transform incoming values.
    # If yes, it will automatically merge these two transforms to apply them together.
    if resize_inputs:
        assert isinstance(resize_kwargs, dict)

        target_shape = resize_kwargs.get("patch_shape")
        if len(resize_kwargs["patch_shape"]) == 3:
            # we only need the XY dimensions to reshape the inputs along them.
            target_shape = target_shape[1:]
            # we provide the Z dimension value to return the desired number of slices and not the whole volume
            kwargs["z_ext"] = resize_kwargs["patch_shape"][0]

        raw_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_rgb=resize_kwargs["is_rgb"])
        label_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_label=True)

        # The patch shape provided to the dataset. Here, "None" means that the entire volume will be loaded.
        patch_shape = None

        if ensure_rgb is None:
            raw_trafos = []
        else:
            assert not isinstance(ensure_rgb, bool), "'ensure_rgb' is expected to be a function."
            raw_trafos = [ensure_rgb]

        if "raw_transform" in kwargs:
            raw_trafos.extend([raw_trafo, kwargs["raw_transform"]])
        else:
            raw_trafos.extend([raw_trafo, get_raw_transform()])

        kwargs["raw_transform"] = Compose(*raw_trafos, is_multi_tensor=False)

        if "label_transform" in kwargs:
            trafo = Compose(label_trafo, kwargs["label_transform"], is_multi_tensor=False)
            kwargs["label_transform"] = trafo
        else:
            kwargs["label_transform"] = label_trafo

    return kwargs, patch_shape


def generate_labeled_array_from_xml(shape: Tuple[int, ...], xml_file: str) -> np.ndarray:
    """Generate a label mask from a contour defined in a xml annotation file.

    Function adapted from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

    Args:
        shape: The image shape.
        xml_file: The path to the xml file with contour annotations.

    Returns:
        The label mask, with 0 as background and regions labeled from 1.
    """
    # DOM object created by the minidom parser
    xDoc = minidom.parse(xml_file)

    # List of all Region tags
    regions = xDoc.getElementsByTagName('Region')

    # List which will store the vertices for each region
    xy = []
    for region in regions:
        # Loading all the vertices in the region
        vertices = region.getElementsByTagName('Vertex')

        # The vertices of a region will be stored in a array
        vw = np.zeros((len(vertices), 2))

        for index, vertex in enumerate(vertices):
            # Storing the values of x and y coordinate after conversion
            vw[index][0] = float(vertex.getAttribute('X'))
            vw[index][1] = float(vertex.getAttribute('Y'))

        # Append the vertices of a region
        xy.append(np.int32(vw))

    # Creating a completely black image
    mask = np.zeros(shape, np.float32)

    # Labels start at 1: enumerating from 0 would make the first region
    # indistinguishable from the background.
    for i, contour in enumerate(xy, start=1):
        r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
        mask[r, c] = i
    return mask


# This function could be extended to convert WSIs (or modalities with multiple resolutions).
def convert_svs_to_array(
    path: str, location: Tuple[int, int] = (0, 0), level: int = 0, img_size: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
    """Convert a .svs file for WSI imaging to a numpy array.

    Requires the tiffslide python library.
    The function can load multi-resolution images. You can specify the resolution level via `level`.

    Args:
        path: File path to the svs file.
        location: Pixel location (x, y) in level 0 of the image.
        level: Target level used to read the image.
        img_size: Size of the image. If None, the shape of the image at `level` is used.

    Returns:
        The image as numpy array.
    """
    from tiffslide import TiffSlide

    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"
    _slide = TiffSlide(path)
    if img_size is None:
        # Use the dimensions of the requested level; the previous implementation
        # always used level 0, contradicting the documented behavior.
        img_size = _slide.level_dimensions[level]
    return _slide.read_region(location=location, level=level, size=img_size, as_array=True)


def download_from_cryo_et_portal(path: str, dataset_id: int, download: bool) -> str:
    """Download data from the CryoET Data Portal.

    Requires the cryoet-data-portal python library.

    Args:
        path: The path for saving the data.
        dataset_id: The id of the data to download from the portal.
        download: Whether to download the data if it is not saved at `path` yet.

    Returns:
        The file path to the downloaded data.
    """
    if Client is None or Dataset is None:
        raise RuntimeError("Please install CryoETDataPortal via 'pip install cryoet-data-portal'")

    output_path = os.path.join(path, str(dataset_id))
    if os.path.exists(output_path):
        return output_path

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    client = Client()
    dataset = Dataset.get_by_id(client, dataset_id)
    dataset.download_everything(dest_path=path)

    return output_path
def get_checksum(filename: str) -> str:
    """Get the SHA256 checksum of a file.

    Args:
        filename: The filepath.

    Returns:
        The checksum.
    """
    # Stream the file in chunks so that large downloads do not have to be
    # held in memory at once. The resulting digest is identical to hashing
    # the whole file content in one go.
    sha = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            sha.update(chunk)
    return sha.hexdigest()
Get the SHA256 checksum of a file.
Arguments:
- filename: The filepath.
Returns:
The checksum.
def download_source(path: str, url: str, download: bool, checksum: Optional[str] = None, verify: bool = True) -> None:
    """Download data via https.

    Args:
        path: The path for saving the data.
        url: The url of the data.
        download: Whether to download the data if it is not saved at `path` yet.
        checksum: The expected checksum of the data.
        verify: Whether to verify the https address.
    """
    # Nothing to do if the file has already been downloaded.
    if os.path.exists(path):
        return

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    with requests.get(url, stream=True, allow_redirects=True, verify=verify) as response:
        response.raise_for_status()  # check for error
        total_bytes = int(response.headers.get("Content-Length", 0))
        desc = f"Download {url} to {path}"
        if not total_bytes:
            desc += " (unknown file size)"
        # Wrap the raw stream so tqdm reports progress while copying to disk.
        with tqdm.wrapattr(response.raw, "read", total=total_bytes, desc=desc) as stream:
            with open(path, "wb") as out_file:
                copyfileobj(stream, out_file)

    _check_checksum(path, checksum)
Download data via https.
Arguments:
- path: The path for saving the data.
- url: The url of the data.
- download: Whether to download the data if it is not saved at
`path` yet.
- checksum: The expected checksum of the data.
- verify: Whether to verify the https address.
def download_source_gdrive(
    path: str,
    url: str,
    download: bool,
    checksum: Optional[str] = None,
    download_type: Literal["zip", "folder"] = "zip",
    expected_samples: int = 10000,
    quiet: bool = True,
) -> None:
    """Download data from google drive.

    Args:
        path: The path for saving the data.
        url: The url of the data.
        download: Whether to download the data if it is not saved at `path` yet.
        checksum: The expected checksum of the data.
        download_type: The download type, either 'zip' or 'folder'.
        expected_samples: The maximal number of samples in the folder.
        quiet: Whether to download quietly.
    """
    if os.path.exists(path):
        return

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if gdown is None:
        raise RuntimeError(
            "Need gdown library to download data from google drive. "
            "Please install gdown: 'conda install -c conda-forge gdown==4.6.3'."
        )

    print("Downloading the files. Might take a few minutes...")

    if download_type == "zip":
        gdown.download(url, path, quiet=quiet)
        _check_checksum(path, checksum)
    elif download_type == "folder":
        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install 'gdown==4.6.3'."
        # Patch gdown's module-level file limit so large folders can be fetched.
        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
        gdown.download_folder(url=url, output=path, quiet=quiet, remaining_ok=True)
    else:
        # Fixed: the previous message referred to a nonexistent 'download_path' argument.
        raise ValueError("`download_type` argument expects either `zip`/`folder`")

    print("Download completed.")
Download data from google drive.
Arguments:
- path: The path for saving the data.
- url: The url of the data.
- download: Whether to download the data if it is not saved at
`path` yet.
- checksum: The expected checksum of the data.
- download_type: The download type, either 'zip' or 'folder'.
- expected_samples: The maximal number of samples in the folder.
- quiet: Whether to download quietly.
def download_source_empiar(path: str, access_id: str, download: bool) -> str:
    """Download data from EMPIAR.

    Requires the ascp command from the aspera CLI.

    Args:
        path: The path for saving the data.
        access_id: The EMPIAR accession id of the data to download.
        download: Whether to download the data if it is not saved at `path` yet.

    Returns:
        The path to the downloaded data.
    """
    download_path = os.path.join(path, access_id)
    if os.path.exists(download_path):
        return download_path

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if which("ascp") is None:
        raise RuntimeError(
            "Need aspera-cli to download data from empiar. You can install it via 'conda install -c hcc aspera-cli'."
        )

    # Look for the aspera ssh key: first the user installation, then the conda environment.
    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
    if not os.path.exists(key_file):
        key_file = os.path.join(os.environ["CONDA_PREFIX"], "etc/asperaweb_id_dsa.openssh")
    if not os.path.exists(key_file):
        raise RuntimeError("Could not find the aspera ssh keyfile")

    run(["ascp", "-QT", "-l", "200M", "-P33001", "-i", key_file, f"emp_ext2@fasp.ebi.ac.uk:/{access_id}", path])
    return download_path
Download data from EMPIAR.
Requires the ascp command from the aspera CLI.
Arguments:
- path: The path for saving the data.
- access_id: The EMPIAR accession id of the data to download.
- download: Whether to download the data if it is not saved at
`path` yet.
Returns:
The path to the downloaded data.
def download_source_kaggle(path: str, dataset_name: str, download: bool, competition: bool = False):
    """Download data from Kaggle.

    Requires the Kaggle API.

    Args:
        path: The path for saving the data.
        dataset_name: The name of the dataset to download.
        download: Whether to download the data if it is not saved at `path` yet.
        competition: Whether this data is from a competition and requires the kaggle.competition API.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
    except ModuleNotFoundError:
        msg = (
            "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
            "After you have installed kaggle, you would need an API token. "
            "Follow the instructions at https://www.kaggle.com/docs/api."
        )
        raise ModuleNotFoundError(msg)

    kaggle_api = KaggleApi()
    kaggle_api.authenticate()

    # Competition data lives behind a different API endpoint than regular datasets.
    if competition:
        kaggle_api.competition_download_files(competition=dataset_name, path=path, quiet=False)
    else:
        kaggle_api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)
Download data from Kaggle.
Requires the Kaggle API.
Arguments:
- path: The path for saving the data.
- dataset_name: The name of the dataset to download.
- download: Whether to download the data if it is not saved at
`path` yet.
- competition: Whether this data is from a competition and requires the kaggle.competition API.
def download_source_tcia(path: str, url: str, dst: str, csv_filename: str, download: bool) -> None:
    """Download data from TCIA.

    Requires the tcia_utils python package.

    Args:
        path: The path for saving the TCIA manifest file.
        url: The URL to the TCIA dataset. Must point to a '.tcia' manifest.
        dst: The directory where the downloaded image series are saved.
        csv_filename: The csv file used for storing metadata about the downloaded series
            (passed through to `nbia.downloadSeries` — verify details against the tcia_utils docs).
        download: Whether to download the data if it is not saved at `path` yet.
    """
    if nbia is None:
        raise RuntimeError("Requires the tcia_utils python package.")
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
    assert url.endswith(".tcia"), f"{url} is not a TCIA Manifest."

    # Downloads the manifest file from the collection page.
    manifest = requests.get(url=url)
    with open(path, "wb") as f:
        f.write(manifest.content)

    # This part extracts the UIDs from the manifests and downloads them.
    nbia.downloadSeries(series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename)
Download data from TCIA.
Requires the tcia_utils python package.
Arguments:
- path: The path for saving the data.
- url: The URL to the TCIA dataset.
- dst: The directory where the downloaded image series are saved.
- csv_filename: The csv file used for storing metadata about the downloaded series (passed through to `nbia.downloadSeries`).
- download: Whether to download the data if it is not saved at
`path` yet.
def download_source_synapse(path: str, entity: str, download: bool) -> None:
    """Download data from synapse.

    Requires the synapseclient python library.

    Args:
        path: The path for saving the data.
        entity: The name of the data to download from synapse.
        download: Whether to download the data if it is not saved at `path` yet.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    if synapseclient is None:
        raise RuntimeError(
            "You must install 'synapseclient' to download files from 'synapse'. "
            "Remember to create an account and generate an authentication code for your account. "
            "Please follow the documentation for details on creating the '~/.synapseConfig' file here: "
            "https://python-docs.synapse.org/tutorials/authentication/."
        )

    assert entity.startswith("syn"), "The entity name does not look as expected. It should be something like 'syn123'."

    # Download all files in the folder.
    client = synapseclient.Synapse()
    # No credentials are passed here; they are read from '~/.synapseConfig'.
    client.login()
    synapseutils.syncFromSynapse(syn=client, entity=entity, path=path)
Download data from synapse.
Requires the synapseclient python library.
Arguments:
- path: The path for saving the data.
- entity: The name of the data to download from synapse.
- download: Whether to download the data if it is not saved at
`path` yet.
def unzip_tarfile(tar_path: str, dst: str, remove: bool = True) -> None:
    """Unpack a tar archive.

    Args:
        tar_path: Path to the tar file.
        dst: Where to unpack the archive.
        remove: Whether to remove the tar file after unpacking.

    Raises:
        ValueError: If the file does not have a supported tar extension.
    """
    import tarfile

    if tar_path.endswith(".tar.gz") or tar_path.endswith(".tgz"):
        access_mode = "r:gz"
    elif tar_path.endswith(".tar"):
        access_mode = "r:"
    else:
        raise ValueError(f"The provided file isn't a supported archive to unpack. Please check the file: {tar_path}.")

    # Use a context manager so the archive is closed even if extraction fails
    # (the previous open/extract/close sequence leaked the handle on error).
    with tarfile.open(tar_path, access_mode) as tar:
        tar.extractall(dst)

    if remove:
        os.remove(tar_path)
Unpack a tar archive.
Arguments:
- tar_path: Path to the tar file.
- dst: Where to unpack the archive.
- remove: Whether to remove the tar file after unpacking.
def unzip_rarfile(rar_path: str, dst: str, remove: bool = True, use_rarfile: bool = True) -> None:
    """Unpack a rar archive.

    The available extraction backends are tried in order; if one fails the
    next is used as fallback, and an error is raised only after all failed.

    Args:
        rar_path: Path to the rar file.
        dst: Where to unpack the archive.
        remove: Whether to remove the rar file after unpacking.
        use_rarfile: Whether to use the rarfile library or aspose.zip.
    """
    def _extract_with_rarfile():
        # Backend 1: the 'rarfile' library.
        import rarfile
        with rarfile.RarFile(rar_path) as archive:
            archive.extractall(path=dst)

    def _extract_with_aspose():
        # Backend 2: 'aspose.zip'.
        import aspose.zip as az
        with az.rar.RarArchive(rar_path) as archive:
            archive.extract_to_directory(dst)

    # Candidate backends in the order they should be tried.
    extractors = [
        ('rarfile', _extract_with_rarfile), ('aspose.zip', _extract_with_aspose),
    ] if use_rarfile else [('aspose.zip', _extract_with_aspose)]

    errors = []
    for name, extractor in extractors:
        try:
            extractor()
            break  # Extraction succeeded, do not try further backends.
        except Exception as err:
            errors.append((name, err))
            if len(errors) < len(extractors):
                # At least one backend is left to try: warn and fall through to it.
                next_name = extractors[len(errors)][0]
                warn(f"Extraction with '{name}' failed for {rar_path} ({err}). Falling back to '{next_name}'.")
            else:
                # All backends failed; raise and chain the last error for debugging.
                backends = ', '.join(f"'{name}'" for name, _ in extractors)
                raise RuntimeError(
                    f"Failed to extract rar archive {rar_path} with {backends}. "
                    "Please ensure one of the supported backends is installed and can read this archive."
                ) from errors[-1][1]

    if remove:
        os.remove(rar_path)
Unpack a rar archive.
Arguments:
- rar_path: Path to the rar file.
- dst: Where to unpack the archive.
- remove: Whether to remove the rar file after unpacking.
- use_rarfile: Whether to use the rarfile library or aspose.zip.
def unzip(zip_path: str, dst: str, remove: bool = True) -> None:
    """Unpack a zip archive.

    Args:
        zip_path: Path to the zip file.
        dst: Where to unpack the archive.
        remove: Whether to remove the zip file after unpacking.
    """
    archive = zipfile.ZipFile(zip_path, "r")
    try:
        archive.extractall(dst)
    finally:
        # Always release the file handle, even when extraction fails.
        archive.close()
    if remove:
        os.remove(zip_path)
Unpack a zip archive.
Arguments:
- zip_path: Path to the zip file.
- dst: Where to unpack the archive.
- remove: Whether to remove the zip file after unpacking.
def generate_labeled_array_from_xml(shape: Tuple[int, ...], xml_file: str) -> np.ndarray:
    """Generate a label mask from a contour defined in a xml annotation file.

    Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

    Args:
        shape: The image shape.
        xml_file: The path to the xml file with contour annotations.

    Returns:
        The label mask, with 0 as background and consecutive integer ids for the regions.
    """
    # DOM object created by the minidom parser.
    xDoc = minidom.parse(xml_file)

    # List of all Region tags.
    regions = xDoc.getElementsByTagName('Region')

    # Collect the vertices for each region as an integer (n_vertices, 2) array.
    xy = []
    for region in regions:
        vertices = region.getElementsByTagName('Vertex')
        vw = np.zeros((len(vertices), 2))
        for index, vertex in enumerate(vertices):
            # Store the x and y coordinates after conversion to float.
            vw[index][0] = float(vertex.getAttribute('X'))
            vw[index][1] = float(vertex.getAttribute('Y'))
        xy.append(np.int32(vw))

    # Start from a completely black (all background) image.
    mask = np.zeros(shape, np.float32)

    # Start the region ids at 1, so that the first region is not
    # painted with the background value 0 and lost.
    for i, contour in enumerate(xy, start=1):
        r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
        mask[r, c] = i
    return mask
Generate a label mask from a contour defined in a xml annotation file.
Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb
Arguments:
- shape: The image shape.
- xml_file: The path to the xml file with contour annotations.
Returns:
The label mask.
def convert_svs_to_array(
    path: str, location: Tuple[int, int] = (0, 0), level: int = 0, img_size: Tuple[int, int] = None,
) -> np.ndarray:
    """Convert a .svs file for WSI imaging to a numpy array.

    Requires the tiffslide python library.
    The function can load multi-resolution images. You can specify the resolution level via `level`.

    Args:
        path: File path to the svs file.
        location: Pixel location (x, y) in level 0 of the image.
        level: Target level used to read the image.
        img_size: Size of the image. If None, the shape of the image at `level` is used.

    Returns:
        The image as numpy array.
    """
    from tiffslide import TiffSlide

    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"
    _slide = TiffSlide(path)
    if img_size is None:
        # Use the dimensions of the requested level; previously level 0 was always
        # used here, which returned the wrong size whenever level > 0.
        img_size = _slide.level_dimensions[level]
    return _slide.read_region(location=location, level=level, size=img_size, as_array=True)
Convert a .svs file for WSI imaging to a numpy array.
Requires the tiffslide python library.
The function can load multi-resolution images. You can specify the resolution level via `level`.
Arguments:
- path: File path to the svs file.
- location: Pixel location (x, y) in level 0 of the image.
- level: Target level used to read the image.
- img_size: Size of the image. If None, the shape of the image at `level` is used.
Returns:
The image as numpy array.
def download_from_cryo_et_portal(path: str, dataset_id: int, download: bool) -> str:
    """Download data from the CryoET Data Portal.

    Requires the cryoet-data-portal python library.

    Args:
        path: The path for saving the data.
        dataset_id: The id of the data to download from the portal.
        download: Whether to download the data if it is not saved at `path` yet.

    Returns:
        The file path to the downloaded data.
    """
    # The portal client is an optional dependency, imported at module level.
    if Client is None or Dataset is None:
        raise RuntimeError("Please install CryoETDataPortal via 'pip install cryoet-data-portal'")

    target_dir = os.path.join(path, str(dataset_id))

    # Return early if the dataset folder is already present.
    if os.path.exists(target_dir):
        return target_dir
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    dataset = Dataset.get_by_id(Client(), dataset_id)
    dataset.download_everything(dest_path=path)
    return target_dir
Download data from the CryoET Data Portal.
Requires the cryoet-data-portal python library.
Arguments:
- path: The path for saving the data.
- dataset_id: The id of the data to download from the portal.
- download: Whether to download the data if it is not saved at `path` yet.
Returns:
The file path to the downloaded data.