torch_em.data.datasets.util
import inspect
import os
import hashlib
import zipfile
import numpy as np
from tqdm import tqdm
from warnings import warn
from xml.dom import minidom
from shutil import copyfileobj, which
from subprocess import run
from packaging import version

from skimage.draw import polygon

import torch
import torch_em
import requests

try:
    import gdown
except ImportError:
    gdown = None


BIOIMAGEIO_IDS = {
    "covid_if": "ilastik/covid_if_training_data",
    "cremi": "ilastik/cremi_training_data",
    "dsb": "ilastik/stardist_dsb_training_data",
    "hpa": "",  # not on bioimageio yet
    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
    "kasthuri": "",  # not on bioimageio yet
    "livecell": "ilastik/livecell_dataset",
    "lucchi": "",  # not on bioimageio yet
    "mitoem": "ilastik/mitoem_segmentation_challenge",
    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
    "ovules": "",  # not on bioimageio yet
    "plantseg_root": "ilastik/plantseg_root",
    "plantseg_ovules": "ilastik/plantseg_ovules",
    "platynereis": "ilastik/platynereis_em_training_data",
    "snemi": "",  # not on bioimageio yet
    "uro_cell": "",  # not on bioimageio yet: https://doi.org/10.1016/j.compbiomed.2020.103693
    "vnc": "ilastik/vnc",
}


def get_bioimageio_dataset_id(dataset_name):
    assert dataset_name in BIOIMAGEIO_IDS
    return BIOIMAGEIO_IDS[dataset_name]


def get_checksum(filename):
    with open(filename, "rb") as f:
        file_ = f.read()
        checksum = hashlib.sha256(file_).hexdigest()
    return checksum


def _check_checksum(path, checksum):
    if checksum is not None:
        this_checksum = get_checksum(path)
        if this_checksum != checksum:
            raise RuntimeError(
                "The checksum of the download does not match the expected checksum. "
                f"Expected: {checksum}, got: {this_checksum}"
            )
        print("Download successful and checksums agree.")
    else:
        warn("The file was downloaded, but no checksum was provided, so the file may be corrupted.")


# This needs to be extended to support downloads from s3 via boto,
# in case we get a resource that is available via s3 without support for http.
def download_source(path, url, download, checksum=None, verify=True):
    if os.path.exists(path):
        return
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    with requests.get(url, stream=True, allow_redirects=True, verify=verify) as r:
        r.raise_for_status()  # check for errors
        file_size = int(r.headers.get("Content-Length", 0))
        desc = f"Download {url} to {path}"
        if file_size == 0:
            desc += " (unknown file size)"
        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
            copyfileobj(r_raw, f)

    _check_checksum(path, checksum)


def download_source_gdrive(path, url, download, checksum=None, download_type="zip", expected_samples=10000):
    if os.path.exists(path):
        return

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    if gdown is None:
        raise RuntimeError(
            "Need the gdown library to download data from google drive. "
            "Please install gdown and then rerun."
        )

    print("Downloading the dataset. Might take a few minutes...")

    if download_type == "zip":
        gdown.download(url, path, quiet=False)
        _check_checksum(path, checksum)
    elif download_type == "folder":
        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install `gdown==4.6.3`."
        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
        gdown.download_folder(url=url, output=path, quiet=True, remaining_ok=True)
    else:
        raise ValueError("The `download_type` argument expects either `zip` or `folder`.")
    print("Download completed.")


def download_source_empiar(path, access_id, download):
    download_path = os.path.join(path, access_id)

    if os.path.exists(download_path):
        return download_path
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    if which("ascp") is None:
        raise RuntimeError(
            "Need aspera-cli to download data from empiar. "
            "You can install it via 'mamba install -c hcc aspera-cli'."
        )

    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
    if not os.path.exists(key_file):
        conda_root = os.environ["CONDA_PREFIX"]
        key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh")

    if not os.path.exists(key_file):
        raise RuntimeError("Could not find the aspera ssh keyfile.")

    cmd = [
        "ascp", "-QT", "-l", "200M", "-P33001",
        "-i", key_file, f"emp_ext2@fasp.ebi.ac.uk:/{access_id}", path
    ]
    run(cmd)

    return download_path


def download_source_kaggle(path, dataset_name, download):
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
    except ModuleNotFoundError:
        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
        msg += "After you have installed kaggle, you will need an API token. "
        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
        raise ModuleNotFoundError(msg)

    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)


def update_kwargs(kwargs, key, value, msg=None):
    if key in kwargs:
        msg = f"{key} will be overridden in loader kwargs." if msg is None else msg
        warn(msg)
    kwargs[key] = value
    return kwargs


def unzip(zip_path, dst, remove=True):
    with zipfile.ZipFile(zip_path, "r") as f:
        f.extractall(dst)
    if remove:
        os.remove(zip_path)


def split_kwargs(function, **kwargs):
    function_parameters = inspect.signature(function).parameters
    parameter_names = list(function_parameters.keys())
    other_kwargs = {k: v for k, v in kwargs.items() if k not in parameter_names}
    kwargs = {k: v for k, v in kwargs.items() if k in parameter_names}
    return kwargs, other_kwargs


# This adds the default transforms for 'raw_transform' and 'transform'
# in case these were not specified in the kwargs.
# This is NOT necessary if 'default_segmentation_dataset' is used, only if a dataset class
# is used directly, e.g. in the LiveCell loader.
def ensure_transforms(ndim, **kwargs):
    if "raw_transform" not in kwargs:
        kwargs = update_kwargs(kwargs, "raw_transform", torch_em.transform.get_raw_transform())
    if "transform" not in kwargs:
        kwargs = update_kwargs(kwargs, "transform", torch_em.transform.get_augmentations(ndim=ndim))
    return kwargs


def add_instance_label_transform(
    kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True,
):
    if binary_is_exclusive:
        assert sum((offsets is not None, boundaries, binary)) <= 1
    else:
        assert sum((offsets is not None, boundaries)) <= 1
    if offsets is not None:
        label_transform2 = torch_em.transform.label.AffinityTransform(
            offsets=offsets, add_binary_target=add_binary_target, add_mask=True
        )
        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be overridden."
        kwargs = update_kwargs(kwargs, "label_transform2", label_transform2, msg=msg)
        label_dtype = torch.float32
    elif boundaries:
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=add_binary_target)
        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be overridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    elif binary:
        label_transform = torch_em.transform.label.labels_to_binary
        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be overridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    return kwargs, label_dtype


def generate_labeled_array_from_xml(shape, xml_file):
    """Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

    Given an image shape and the path to an xml annotation file, generate a labeled mask
    in which the region inside each contour is filled with that contour's label id.

    shape: the shape of the image on which the mask will be drawn
    xml_file: path to the xml file that contains the contour annotations

    Returns:
        A labeled mask of the given shape.
    """
    # DOM object created by the minidom parser.
    xDoc = minidom.parse(xml_file)

    # List of all Region tags.
    regions = xDoc.getElementsByTagName('Region')

    # List which will store the vertices for each region.
    xy = []
    for region in regions:
        # Loading all the vertices in the region.
        vertices = region.getElementsByTagName('Vertex')

        # The vertices of a region will be stored in an array.
        vw = np.zeros((len(vertices), 2))

        for index, vertex in enumerate(vertices):
            # Storing the values of the x and y coordinates after conversion.
            vw[index][0] = float(vertex.getAttribute('X'))
            vw[index][1] = float(vertex.getAttribute('Y'))

        # Append the vertices of a region.
        xy.append(np.int32(vw))

    # Creating a completely black image.
    mask = np.zeros(shape, np.float32)

    for i, contour in enumerate(xy):
        r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
        mask[r, c] = i + 1  # start the label ids at 1, so the first region is not confused with the background
    return mask


def convert_svs_to_array(path, location=(0, 0), level=0, img_size=None):
    """Converts .svs files to numpy array format.

    Arguments:
    - path: [str] - path to the svs file
      (the arguments below are used for multi-resolution images)
    - location: tuple[int, int] - pixel location (x, y) at level 0 of the image (default: (0, 0))
    - level: [int] - target level used to read the image (default: 0)
    - img_size: tuple[int, int] - expected size of the image
      (default: None -> uses the full shape at the requested level)

    Returns:
        The image as a numpy array.

    TODO: this can be extended to convert WSIs (or other modalities with multiple resolutions).
    """
    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format."

    from tiffslide import TiffSlide

    _slide = TiffSlide(path)

    if img_size is None:
        img_size = _slide.level_dimensions[level]  # full shape at the requested level

    img_arr = _slide.read_region(location=location, level=level, size=img_size, as_array=True)
    return img_arr
BIOIMAGEIO_IDS = {
    "covid_if": "ilastik/covid_if_training_data",
    "cremi": "ilastik/cremi_training_data",
    "dsb": "ilastik/stardist_dsb_training_data",
    "hpa": "",
    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
    "kasthuri": "",
    "livecell": "ilastik/livecell_dataset",
    "lucchi": "",
    "mitoem": "ilastik/mitoem_segmentation_challenge",
    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
    "ovules": "",
    "plantseg_root": "ilastik/plantseg_root",
    "plantseg_ovules": "ilastik/plantseg_ovules",
    "platynereis": "ilastik/platynereis_em_training_data",
    "snemi": "",
    "uro_cell": "",
    "vnc": "ilastik/vnc",
}
def get_bioimageio_dataset_id(dataset_name):
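Looks up the bioimage.io dataset id registered in BIOIMAGEIO_IDS; the dataset name must be one of its keys. A minimal usage sketch:

dataset_id = get_bioimageio_dataset_id("covid_if")
print(dataset_id)  # -> "ilastik/covid_if_training_data"

Note that some entries (e.g. "hpa") are empty strings, because those datasets are not on bioimage.io yet.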
def get_checksum(filename):
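Computes the sha256 hex digest of a file, e.g. to pass as the checksum argument of the download functions below. A sketch, where "data.zip" is a hypothetical local file:

checksum = get_checksum("data.zip")  # "data.zip" is a hypothetical path
print(checksum)  # sha256 hex digest of the file contents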
def download_source(path, url, download, checksum=None, verify=True):
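A usage sketch for a plain http(s) download; the url here is a placeholder, not a real resource:

# Hypothetical url, for illustration only.
url = "https://example.com/data.zip"
download_source("./data.zip", url, download=True, checksum=None)

With checksum=None the function warns that the file could not be verified; pass the expected sha256 digest to check the download.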
def download_source_gdrive(path, url, download, checksum=None, download_type="zip", expected_samples=10000):
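A sketch for a single-file download from google drive, assuming a hypothetical drive url; note that the folder download mode is pinned to gdown==4.6.3:

# Hypothetical drive url with a placeholder file id, for illustration only.
url = "https://drive.google.com/uc?id=<file-id>"
download_source_gdrive("./data.zip", url, download=True, download_type="zip")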
def download_source_empiar(path, access_id, download):
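Downloads an EMPIAR entry via the ascp client, which must be installed (e.g. via 'mamba install -c hcc aspera-cli'). A sketch with a hypothetical accession id:

# "11111" is a hypothetical accession id, for illustration only.
download_path = download_source_empiar("./data", access_id="11111", download=True)
# the data ends up in ./data/11111, which is also the returned path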
def download_source_kaggle(path, dataset_name, download):
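A sketch with a hypothetical kaggle dataset handle; this requires the kaggle package and a configured API token:

# Hypothetical dataset handle, for illustration only.
download_source_kaggle("./data", dataset_name="<user>/<dataset>", download=True)
# kaggle downloads a zip archive, which can then be extracted with unzip()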
def update_kwargs(kwargs, key, value, msg=None):
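Sets kwargs[key] = value and warns if the key was already present. A minimal sketch:

kwargs = {"batch_size": 2}
kwargs = update_kwargs(kwargs, "batch_size", 4)  # warns, then overrides the value
kwargs = update_kwargs(kwargs, "shuffle", True)  # added without a warning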
def unzip(zip_path, dst, remove=True):
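Extracts a zip archive to a destination folder and (by default) removes the archive afterwards. A sketch with hypothetical paths:

unzip("./data.zip", "./data", remove=True)  # "./data.zip" is a hypothetical archive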
def split_kwargs(function, **kwargs):
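Splits keyword arguments into those matching a function's signature and the rest. A minimal sketch:

def f(a, b=1):
    return a + b

matching, other = split_kwargs(f, a=1, b=2, c=3)
# matching == {"a": 1, "b": 2}, other == {"c": 3}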
def ensure_transforms(ndim, **kwargs):
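Fills in the default 'raw_transform' and 'transform' if they are missing, which is needed when a dataset class is used directly instead of via 'default_segmentation_dataset'. A minimal sketch:

kwargs = ensure_transforms(ndim=2)
# kwargs now holds the default "raw_transform" and "transform" for 2d data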
def add_instance_label_transform(kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True):
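Adds the label transform that matches the requested instance segmentation target (binary, boundaries, or affinity offsets). A sketch for boundary targets:

kwargs, label_dtype = add_instance_label_transform({}, add_binary_target=True, boundaries=True)
# kwargs["label_transform"] is now a BoundaryTransform and label_dtype is torch.float32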
def generate_labeled_array_from_xml(shape, xml_file):
Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

Given an image shape and the path to an xml annotation file, generate a labeled mask in which the region inside each contour is filled with that contour's label id.

shape: the shape of the image on which the mask will be drawn
xml_file: path to the xml file that contains the contour annotations

Returns:
A labeled mask of the given shape.
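A sketch, assuming a hypothetical annotation file in the Region/Vertex xml format this function expects:

# "annotations.xml" is a hypothetical path, for illustration only.
mask = generate_labeled_array_from_xml(shape=(1000, 1000), xml_file="annotations.xml")
# mask[y, x] holds the label id of the contour that contains the pixel, 0 for background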
def convert_svs_to_array(path, location=(0, 0), level=0, img_size=None):
Converts .svs files to numpy array format.

Arguments:
- path: [str] - path to the svs file
  (the arguments below are used for multi-resolution images)
- location: tuple[int, int] - pixel location (x, y) at level 0 of the image (default: (0, 0))
- level: [int] - target level used to read the image (default: 0)
- img_size: tuple[int, int] - expected size of the image
  (default: None -> uses the full shape at the requested level)

Returns:
The image as a numpy array.

TODO: this can be extended to convert WSIs (or other modalities with multiple resolutions).
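A sketch with a hypothetical slide path; this needs the tiffslide package:

# "slide.svs" is a hypothetical file path, for illustration only.
image = convert_svs_to_array("slide.svs", level=0)
print(image.shape)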