torch_em.data.datasets.util
import os
import hashlib
import inspect
import zipfile
import requests
from tqdm import tqdm
from warnings import warn
from subprocess import run
from xml.dom import minidom
from packaging import version
from shutil import copyfileobj, which

from typing import Optional, Tuple, Literal

import numpy as np
from skimage.draw import polygon

import torch

import torch_em
from torch_em.transform import get_raw_transform
from torch_em.transform.generic import ResizeLongestSideInputs, Compose

try:
    import gdown
except ImportError:
    gdown = None

try:
    from tcia_utils import nbia
except ModuleNotFoundError:
    nbia = None

try:
    from cryoet_data_portal import Client, Dataset
except ImportError:
    Client, Dataset = None, None

try:
    import synapseclient
    import synapseutils
except ImportError:
    synapseclient, synapseutils = None, None


BIOIMAGEIO_IDS = {
    "covid_if": "ilastik/covid_if_training_data",
    "cremi": "ilastik/cremi_training_data",
    "dsb": "ilastik/stardist_dsb_training_data",
    "hpa": "",  # not on bioimageio yet
    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
    "kasthuri": "",  # not on bioimageio yet
    "livecell": "ilastik/livecell_dataset",
    "lucchi": "",  # not on bioimageio yet
    "mitoem": "ilastik/mitoem_segmentation_challenge",
    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
    "ovules": "",  # not on bioimageio yet
    "plantseg_root": "ilastik/plantseg_root",
    "plantseg_ovules": "ilastik/plantseg_ovules",
    "platynereis": "ilastik/platynereis_em_training_data",
    "snemi": "",  # not on bioimageio yet
    "uro_cell": "",  # not on bioimageio yet: https://doi.org/10.1016/j.compbiomed.2020.103693
    "vnc": "ilastik/vnc",
}
"""@private
"""


def get_bioimageio_dataset_id(dataset_name):
    """@private
    """
    assert dataset_name in BIOIMAGEIO_IDS
    return BIOIMAGEIO_IDS[dataset_name]


def get_checksum(filename: str) -> str:
    """Get the SHA256 checksum of a file.

    Args:
        filename: The filepath.

    Returns:
        The checksum.
    """
    with open(filename, "rb") as f:
        file_ = f.read()
        checksum = hashlib.sha256(file_).hexdigest()
    return checksum


def _check_checksum(path, checksum):
    if checksum is not None:
        this_checksum = get_checksum(path)
        if this_checksum != checksum:
            raise RuntimeError(
                "The checksum of the download does not match the expected checksum. "
                f"Expected: {checksum}, got: {this_checksum}"
            )
        print("Download successful and checksums agree.")
    else:
        warn("The file was downloaded, but no checksum was provided, so the file may be corrupted.")


# This needs to be extended to support download from s3 via boto,
# if we get a resource that is available via s3 without support for http.
def download_source(path: str, url: str, download: bool, checksum: Optional[str] = None, verify: bool = True) -> None:
    """Download data via https.

    Args:
        path: The path for saving the data.
        url: The url of the data.
        download: Whether to download the data if it is not saved at `path` yet.
        checksum: The expected checksum of the data.
        verify: Whether to verify the https address.
    """
    if os.path.exists(path):
        return
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    with requests.get(url, stream=True, allow_redirects=True, verify=verify) as r:
        r.raise_for_status()  # check for error
        file_size = int(r.headers.get("Content-Length", 0))
        desc = f"Download {url} to {path}"
        if file_size == 0:
            desc += " (unknown file size)"
        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
            copyfileobj(r_raw, f)

    _check_checksum(path, checksum)


def download_source_gdrive(
    path: str,
    url: str,
    download: bool,
    checksum: Optional[str] = None,
    download_type: Literal["zip", "folder"] = "zip",
    expected_samples: int = 10000,
    quiet: bool = True,
) -> None:
    """Download data from google drive.

    Args:
        path: The path for saving the data.
        url: The url of the data.
        download: Whether to download the data if it is not saved at `path` yet.
        checksum: The expected checksum of the data.
        download_type: The download type, either 'zip' or 'folder'.
        expected_samples: The maximal number of samples in the folder.
        quiet: Whether to download quietly.
    """
    if os.path.exists(path):
        return

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if gdown is None:
        raise RuntimeError(
            "Need gdown library to download data from google drive. "
            "Please install gdown: 'conda install -c conda-forge gdown==4.6.3'."
        )

    print("Downloading the files. Might take a few minutes...")

    if download_type == "zip":
        gdown.download(url, path, quiet=quiet)
        _check_checksum(path, checksum)
    elif download_type == "folder":
        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install 'gdown==4.6.3'."
        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
        gdown.download_folder(url=url, output=path, quiet=quiet, remaining_ok=True)
    else:
        raise ValueError("The 'download_type' argument expects either 'zip' or 'folder'.")

    print("Download completed.")


def download_source_empiar(path: str, access_id: str, download: bool) -> str:
    """Download data from EMPIAR.

    Requires the ascp command from the aspera CLI.

    Args:
        path: The path for saving the data.
        access_id: The EMPIAR accession id of the data to download.
        download: Whether to download the data if it is not saved at `path` yet.

    Returns:
        The path to the downloaded data.
    """
    download_path = os.path.join(path, access_id)

    if os.path.exists(download_path):
        return download_path
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if which("ascp") is None:
        raise RuntimeError(
            "Need aspera-cli to download data from empiar. You can install it via 'conda install -c hcc aspera-cli'."
        )

    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
    if not os.path.exists(key_file):
        conda_root = os.environ["CONDA_PREFIX"]
        key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh")

    if not os.path.exists(key_file):
        raise RuntimeError("Could not find the aspera ssh keyfile")

    cmd = ["ascp", "-QT", "-l", "200M", "-P33001", "-i", key_file, f"emp_ext2@fasp.ebi.ac.uk:/{access_id}", path]
    run(cmd)

    return download_path


def download_source_kaggle(path: str, dataset_name: str, download: bool, competition: bool = False):
    """Download data from Kaggle.

    Requires the Kaggle API.

    Args:
        path: The path for saving the data.
        dataset_name: The name of the dataset to download.
        download: Whether to download the data if it is not saved at `path` yet.
        competition: Whether this data is from a competition and requires the kaggle.competition API.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
    except ModuleNotFoundError:
        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
        msg += "After you have installed kaggle, you would need an API token. "
        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
        raise ModuleNotFoundError(msg)

    api = KaggleApi()
    api.authenticate()

    if competition:
        api.competition_download_files(competition=dataset_name, path=path, quiet=False)
    else:
        api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)


def download_source_tcia(path, url, dst, csv_filename, download):
    """Download data from TCIA.

    Requires the tcia_utils python package.

    Args:
        path: The path for saving the data.
        url: The URL to the TCIA dataset.
        dst: The directory where the downloaded image series will be stored.
        csv_filename: The name of the csv file where the metadata of the downloaded series is stored.
        download: Whether to download the data if it is not saved at `path` yet.
    """
    if nbia is None:
        raise RuntimeError("Requires the tcia_utils python package.")
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
    assert url.endswith(".tcia"), f"{url} is not a TCIA Manifest."

    # Downloads the manifest file from the collection page.
    manifest = requests.get(url=url)
    with open(path, "wb") as f:
        f.write(manifest.content)

    # This part extracts the UIDs from the manifests and downloads them.
    nbia.downloadSeries(series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename)


def download_source_synapse(path: str, entity: str, download: bool) -> None:
    """Download data from synapse.

    Requires the synapseclient python library.

    Args:
        path: The path for saving the data.
        entity: The name of the data to download from synapse.
        download: Whether to download the data if it is not saved at `path` yet.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    if synapseclient is None:
        raise RuntimeError(
            "You must install 'synapseclient' to download files from 'synapse'. "
            "Remember to create an account and generate an authentication code for your account. "
            "Please follow the documentation for details on creating the '~/.synapseConfig' file here: "
            "https://python-docs.synapse.org/tutorials/authentication/."
        )

    assert entity.startswith("syn"), "The entity name does not look as expected. It should be something like 'syn123'."

    # Download all files in the folder.
    syn = synapseclient.Synapse()
    syn.login()  # Since we do not pass any credentials here, it fetches all details from '~/.synapseConfig'.
    synapseutils.syncFromSynapse(syn=syn, entity=entity, path=path)


def update_kwargs(kwargs, key, value, msg=None):
    """@private
    """
    if key in kwargs:
        msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg
        warn(msg)
    kwargs[key] = value
    return kwargs


def unzip_tarfile(tar_path: str, dst: str, remove: bool = True) -> None:
    """Unpack a tar archive.

    Args:
        tar_path: Path to the tar file.
        dst: Where to unpack the archive.
        remove: Whether to remove the tar file after unpacking.
    """
    import tarfile

    if tar_path.endswith(".tar.gz") or tar_path.endswith(".tgz"):
        access_mode = "r:gz"
    elif tar_path.endswith(".tar"):
        access_mode = "r:"
    else:
        raise ValueError(f"The provided file isn't a supported archive to unpack. Please check the file: {tar_path}.")

    tar = tarfile.open(tar_path, access_mode)
    tar.extractall(dst)
    tar.close()

    if remove:
        os.remove(tar_path)


def unzip_rarfile(rar_path: str, dst: str, remove: bool = True, use_rarfile: bool = True) -> None:
    """Unpack a rar archive.

    Args:
        rar_path: Path to the rar file.
        dst: Where to unpack the archive.
        remove: Whether to remove the rar file after unpacking.
        use_rarfile: Whether to use the rarfile library or aspose.zip.
    """
    if use_rarfile:
        import rarfile
        with rarfile.RarFile(rar_path) as f:
            f.extractall(path=dst)
    else:
        import aspose.zip as az
        with az.rar.RarArchive(rar_path) as archive:
            archive.extract_to_directory(dst)

    if remove:
        os.remove(rar_path)


def unzip(zip_path: str, dst: str, remove: bool = True) -> None:
    """Unpack a zip archive.

    Args:
        zip_path: Path to the zip file.
        dst: Where to unpack the archive.
        remove: Whether to remove the zip file after unpacking.
    """
    with zipfile.ZipFile(zip_path, "r") as f:
        f.extractall(dst)
    if remove:
        os.remove(zip_path)


def split_kwargs(function, **kwargs):
    """@private
    """
    function_parameters = inspect.signature(function).parameters
    parameter_names = list(function_parameters.keys())
    other_kwargs = {k: v for k, v in kwargs.items() if k not in parameter_names}
    kwargs = {k: v for k, v in kwargs.items() if k in parameter_names}
    return kwargs, other_kwargs


# This adds the default transforms for 'raw_transform' and 'transform'
# in case these were not specified in the kwargs.
# This is NOT necessary if 'default_segmentation_dataset' is used, only if a dataset class
# is used directly, e.g. in the LiveCell loader.
def ensure_transforms(ndim, **kwargs):
    """@private
    """
    if "raw_transform" not in kwargs:
        kwargs = update_kwargs(kwargs, "raw_transform", torch_em.transform.get_raw_transform())
    if "transform" not in kwargs:
        kwargs = update_kwargs(kwargs, "transform", torch_em.transform.get_augmentations(ndim=ndim))
    return kwargs


def add_instance_label_transform(
    kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True,
):
    """@private
    """
    if binary_is_exclusive:
        assert sum((offsets is not None, boundaries, binary)) <= 1
    else:
        assert sum((offsets is not None, boundaries)) <= 1
    if offsets is not None:
        label_transform2 = torch_em.transform.label.AffinityTransform(
            offsets=offsets, add_binary_target=add_binary_target, add_mask=True
        )
        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform2", label_transform2, msg=msg)
        label_dtype = torch.float32
    elif boundaries:
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=add_binary_target)
        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    elif binary:
        label_transform = torch_em.transform.label.labels_to_binary
        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    return kwargs, label_dtype


def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None, ensure_rgb=None):
    """@private
    """
    # Checks for incoming 'raw_transform' and 'label_transform' values.
    # If they are passed, they are automatically merged with the resize transforms so that both are applied.
    if resize_inputs:
        assert isinstance(resize_kwargs, dict)

        target_shape = resize_kwargs.get("patch_shape")
        if len(resize_kwargs["patch_shape"]) == 3:
            # We only need the XY dimensions to reshape the inputs along them.
            target_shape = target_shape[1:]
            # We provide the Z dimension value to return the desired number of slices and not the whole volume.
            kwargs["z_ext"] = resize_kwargs["patch_shape"][0]

        raw_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_rgb=resize_kwargs["is_rgb"])
        label_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_label=True)

        # The patch shape provided to the dataset. Here, "None" means that the entire volume will be loaded.
        patch_shape = None

        if ensure_rgb is None:
            raw_trafos = []
        else:
            assert not isinstance(ensure_rgb, bool), "'ensure_rgb' is expected to be a function."
            raw_trafos = [ensure_rgb]

        if "raw_transform" in kwargs:
            raw_trafos.extend([raw_trafo, kwargs["raw_transform"]])
        else:
            raw_trafos.extend([raw_trafo, get_raw_transform()])

        kwargs["raw_transform"] = Compose(*raw_trafos, is_multi_tensor=False)

        if "label_transform" in kwargs:
            trafo = Compose(label_trafo, kwargs["label_transform"], is_multi_tensor=False)
            kwargs["label_transform"] = trafo
        else:
            kwargs["label_transform"] = label_trafo

    return kwargs, patch_shape


def generate_labeled_array_from_xml(shape: Tuple[int, ...], xml_file: str) -> np.ndarray:
    """Generate a label mask from a contour defined in an xml annotation file.

    Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

    Args:
        shape: The image shape.
        xml_file: The path to the xml file with contour annotations.

    Returns:
        The label mask.
    """
    # DOM object created by the minidom parser.
    xDoc = minidom.parse(xml_file)

    # List of all Region tags.
    regions = xDoc.getElementsByTagName('Region')

    # List which will store the vertices for each region.
    xy = []
    for region in regions:
        # Loading all the vertices in the region.
        vertices = region.getElementsByTagName('Vertex')

        # The vertices of a region will be stored in an array.
        vw = np.zeros((len(vertices), 2))

        for index, vertex in enumerate(vertices):
            # Storing the values of the x and y coordinates after conversion.
            vw[index][0] = float(vertex.getAttribute('X'))
            vw[index][1] = float(vertex.getAttribute('Y'))

        # Append the vertices of a region.
        xy.append(np.int32(vw))

    # Creating a completely black image.
    mask = np.zeros(shape, np.float32)

    for i, contour in enumerate(xy):
        r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
        mask[r, c] = i
    return mask


# This function could be extended to convert WSIs (or modalities with multiple resolutions).
def convert_svs_to_array(
    path: str, location: Tuple[int, int] = (0, 0), level: int = 0, img_size: Tuple[int, int] = None,
) -> np.ndarray:
    """Convert a .svs file for WSI imaging to a numpy array.

    Requires the tiffslide python library.
    The function can load multi-resolution images. You can specify the resolution level via `level`.

    Args:
        path: File path to the svs file.
        location: Pixel location (x, y) in level 0 of the image.
        level: Target level used to read the image.
        img_size: Size of the image. If None, the shape of the image at `level` is used.

    Returns:
        The image as numpy array.
    """
    from tiffslide import TiffSlide

    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"
    _slide = TiffSlide(path)
    if img_size is None:
        img_size = _slide.level_dimensions[0]
    return _slide.read_region(location=location, level=level, size=img_size, as_array=True)


def download_from_cryo_et_portal(path: str, dataset_id: int, download: bool) -> str:
    """Download data from the CryoET Data Portal.

    Requires the cryoet-data-portal python library.

    Args:
        path: The path for saving the data.
        dataset_id: The id of the data to download from the portal.
        download: Whether to download the data if it is not saved at `path` yet.

    Returns:
        The file path to the downloaded data.
    """
    if Client is None or Dataset is None:
        raise RuntimeError("Please install CryoETDataPortal via 'pip install cryoet-data-portal'")

    output_path = os.path.join(path, str(dataset_id))
    if os.path.exists(output_path):
        return output_path

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    client = Client()
    dataset = Dataset.get_by_id(client, dataset_id)
    dataset.download_everything(dest_path=path)

    return output_path
def get_checksum(filename: str) -> str:
Get the SHA256 checksum of a file.
Arguments:
- filename: The filepath.
Returns:
The checksum.
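A minimal usage sketch: verify a finished download against a known reference hash (the file path and hash below are placeholders):

from torch_em.data.datasets.util import get_checksum

expected = "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"  # placeholder hash
if get_checksum("./data/archive.zip") != expected:  # placeholder path
    raise RuntimeError("Checksum mismatch, the file may be corrupted.")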
def download_source(path: str, url: str, download: bool, checksum: Optional[str] = None, verify: bool = True) -> None:
Download data via https.
Arguments:
- path: The path for saving the data.
- url: The url of the data.
- download: Whether to download the data if it is not saved at `path` yet.
- checksum: The expected checksum of the data.
- verify: Whether to verify the https address.
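A typical call downloads a single file and verifies it; the URL and path below are placeholders:

from torch_em.data.datasets.util import download_source

download_source(
    path="./data/images.zip",              # placeholder path
    url="https://example.org/images.zip",  # placeholder URL
    download=True,
    checksum=None,  # pass the expected SHA256 hash here to verify the download
)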
def download_source_gdrive(
    path: str,
    url: str,
    download: bool,
    checksum: Optional[str] = None,
    download_type: Literal["zip", "folder"] = "zip",
    expected_samples: int = 10000,
    quiet: bool = True,
) -> None:
Download data from google drive.
Arguments:
- path: The path for saving the data.
- url: The url of the data.
- download: Whether to download the data if it is not saved at `path` yet.
- checksum: The expected checksum of the data.
- download_type: The download type, either 'zip' or 'folder'.
- expected_samples: The maximal number of samples in the folder.
- quiet: Whether to download quietly.
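A sketch for fetching a zipped dataset shared via a Google Drive link (the link and path are placeholders):

from torch_em.data.datasets.util import download_source_gdrive

download_source_gdrive(
    path="./data/dataset.zip",                     # placeholder path
    url="https://drive.google.com/uc?id=FILE_ID",  # placeholder link
    download=True,
    download_type="zip",
)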
def download_source_empiar(path: str, access_id: str, download: bool) -> str:
Download data from EMPIAR.
Requires the ascp command from the aspera CLI.
Arguments:
- path: The path for saving the data.
- access_id: The EMPIAR accession id of the data to download.
- download: Whether to download the data if it is not saved at `path` yet.
Returns:
The path to the downloaded data.
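A sketch for fetching an EMPIAR entry into a target folder via its accession id (the id and path are placeholders):

from torch_em.data.datasets.util import download_source_empiar

download_path = download_source_empiar(path="./data", access_id="10988", download=True)  # placeholder accession id
print("The data was downloaded to", download_path)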
def download_source_kaggle(path: str, dataset_name: str, download: bool, competition: bool = False):
Download data from Kaggle.
Requires the Kaggle API.
Arguments:
- path: The path for saving the data.
- dataset_name: The name of the dataset to download.
- download: Whether to download the data if it is not saved at `path` yet.
- competition: Whether this data is from a competition and requires the kaggle.competition API.
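A sketch for fetching a regular Kaggle dataset, which is referenced by its '<owner>/<dataset>' slug (the slug and path are placeholders); competition data instead takes the competition name together with competition=True:

from torch_em.data.datasets.util import download_source_kaggle

download_source_kaggle(path="./data", dataset_name="some-user/some-dataset", download=True)  # placeholder slug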
def download_source_tcia(path, url, dst, csv_filename, download):
Download data from TCIA.
Requires the tcia_utils python package.
Arguments:
- path: The path for saving the data.
- url: The URL to the TCIA dataset.
- dst: The directory where the downloaded image series will be stored.
- csv_filename: The name of the csv file where the metadata of the downloaded series is stored.
- download: Whether to download the data if it is not saved at `path` yet.
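A sketch of a call with a manifest URL and a destination folder for the image series (all paths and the URL are placeholders; the exact csv_filename semantics follow tcia_utils):

from torch_em.data.datasets.util import download_source_tcia

download_source_tcia(
    path="./data/manifest.tcia",                                       # placeholder manifest path
    url="https://www.cancerimagingarchive.net/path/to/manifest.tcia",  # placeholder manifest URL
    dst="./data/series",
    csv_filename="./data/series_metadata.csv",
    download=True,
)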
def download_source_synapse(path: str, entity: str, download: bool) -> None:
Download data from synapse.
Requires the synapseclient python library.
Arguments:
- path: The path for saving the data.
- entity: The name of the data to download from synapse.
- download: Whether to download the data if it is not saved at `path` yet.
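Given a configured '~/.synapseConfig', a folder entity can be synced to a local directory (the entity id and path are placeholders):

from torch_em.data.datasets.util import download_source_synapse

download_source_synapse(path="./data", entity="syn123456", download=True)  # placeholder entity id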
def unzip_tarfile(tar_path: str, dst: str, remove: bool = True) -> None:
Unpack a tar archive.
Arguments:
- tar_path: Path to the tar file.
- dst: Where to unpack the archive.
- remove: Whether to remove the tar file after unpacking.
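The compression mode is inferred from the file suffix, so .tar, .tar.gz and .tgz archives are all handled; a small sketch with placeholder paths:

from torch_em.data.datasets.util import unzip_tarfile

unzip_tarfile(tar_path="./data/images.tar.gz", dst="./data/images", remove=False)  # keep the archive after unpacking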
def unzip_rarfile(rar_path: str, dst: str, remove: bool = True, use_rarfile: bool = True) -> None:
Unpack a rar archive.
Arguments:
- rar_path: Path to the rar file.
- dst: Where to unpack the archive.
- remove: Whether to remove the rar file after unpacking.
- use_rarfile: Whether to use the rarfile library or aspose.zip.
def unzip(zip_path: str, dst: str, remove: bool = True) -> None:
Unpack a zip archive.
Arguments:
- zip_path: Path to the zip file.
- dst: Where to unpack the archive.
- remove: Whether to remove the zip file after unpacking.
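A common pattern in the dataset loaders is to download a zip archive and unpack it right away; a sketch with placeholder URL and paths:

from torch_em.data.datasets.util import download_source, unzip

zip_path = "./data/dataset.zip"
download_source(path=zip_path, url="https://example.org/dataset.zip", download=True)  # placeholder URL
unzip(zip_path=zip_path, dst="./data", remove=True)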
def generate_labeled_array_from_xml(shape: Tuple[int, ...], xml_file: str) -> np.ndarray:
Generate a label mask from a contour defined in an xml annotation file.
Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb
Arguments:
- shape: The image shape.
- xml_file: The path to the xml file with contour annotations.
Returns:
The label mask.
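A sketch for rasterizing polygon annotations (e.g. nucleus contours) into a mask for a known image shape; the annotation path is a placeholder:

from torch_em.data.datasets.util import generate_labeled_array_from_xml

mask = generate_labeled_array_from_xml(shape=(1000, 1000), xml_file="./annotations/image_01.xml")  # placeholder path
print(mask.shape, int(mask.max()))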
def convert_svs_to_array(
    path: str, location: Tuple[int, int] = (0, 0), level: int = 0, img_size: Tuple[int, int] = None,
) -> np.ndarray:
Convert a .svs file for WSI imaging to a numpy array.
Requires the tiffslide python library.
The function can load multi-resolution images. You can specify the resolution level via `level`.
Arguments:
- path: File path to the svs file.
- location: Pixel location (x, y) in level 0 of the image.
- level: Target level used to read the image.
- img_size: Size of the image. If None, the shape of the image at `level` is used.
Returns:
The image as numpy array.
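A sketch for reading a fixed-size region from the highest-resolution level of a slide (the file path is a placeholder):

from torch_em.data.datasets.util import convert_svs_to_array

image = convert_svs_to_array(path="./data/slide_01.svs", location=(0, 0), level=0, img_size=(512, 512))  # placeholder path
print(image.shape)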
def download_from_cryo_et_portal(path: str, dataset_id: int, download: bool) -> str:
Download data from the CryoET Data Portal.
Requires the cryoet-data-portal python library.
Arguments:
- path: The path for saving the data.
- dataset_id: The id of the data to download from the portal.
- download: Whether to download the data if it is not saved at `path` yet.
Returns:
The file path to the downloaded data.
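Datasets on the portal are addressed by their numeric id; the id and path below are placeholders:

from torch_em.data.datasets.util import download_from_cryo_et_portal

data_path = download_from_cryo_et_portal(path="./data", dataset_id=10000, download=True)  # placeholder dataset id
print("The data was downloaded to", data_path)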