torch_em.data.datasets.util

import inspect
import os
import hashlib
import zipfile
import numpy as np
from tqdm import tqdm
from warnings import warn
from xml.dom import minidom
from shutil import copyfileobj, which
from subprocess import run
from packaging import version

from skimage.draw import polygon

import torch
import torch_em
import requests

try:
    import gdown
except ImportError:
    gdown = None


BIOIMAGEIO_IDS = {
    "covid_if": "ilastik/covid_if_training_data",
    "cremi": "ilastik/cremi_training_data",
    "dsb": "ilastik/stardist_dsb_training_data",
    "hpa": "",  # not on bioimageio yet
    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
    "kasthuri": "",  # not on bioimageio yet:
    "livecell": "ilastik/livecell_dataset",
    "lucchi": "",  # not on bioimageio yet:
    "mitoem": "ilastik/mitoem_segmentation_challenge",
    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
    "ovules": "",  # not on bioimageio yet
    "plantseg_root": "ilastik/plantseg_root",
    "plantseg_ovules": "ilastik/plantseg_ovules",
    "platynereis": "ilastik/platynereis_em_training_data",
    "snemi": "",  # not on bioimageio yet
    "uro_cell": "",  # not on bioimageio yet: https://doi.org/10.1016/j.compbiomed.2020.103693
    "vnc": "ilastik/vnc",
}


def get_bioimageio_dataset_id(dataset_name):
    assert dataset_name in BIOIMAGEIO_IDS
    return BIOIMAGEIO_IDS[dataset_name]


def get_checksum(filename):
    with open(filename, "rb") as f:
        file_ = f.read()
        checksum = hashlib.sha256(file_).hexdigest()
    return checksum


def _check_checksum(path, checksum):
    if checksum is not None:
        this_checksum = get_checksum(path)
        if this_checksum != checksum:
            raise RuntimeError(
                "The checksum of the download does not match the expected checksum. "
                f"Expected: {checksum}, got: {this_checksum}"
            )
        print("Download successful and checksums agree.")
    else:
        warn("The file was downloaded, but no checksum was provided, so the file may be corrupted.")


# this needs to be extended to support download from s3 via boto,
# if we get a resource that is available via s3 without support for http
def download_source(path, url, download, checksum=None, verify=True):
    if os.path.exists(path):
        return
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    with requests.get(url, stream=True, allow_redirects=True, verify=verify) as r:
        r.raise_for_status()  # check for error
        file_size = int(r.headers.get("Content-Length", 0))
        desc = f"Download {url} to {path}"
        if file_size == 0:
            desc += " (unknown file size)"
        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
            copyfileobj(r_raw, f)

    _check_checksum(path, checksum)


def download_source_gdrive(path, url, download, checksum=None, download_type="zip", expected_samples=10000):
    if os.path.exists(path):
        return

    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if gdown is None:
        raise RuntimeError(
            "Need the gdown library to download data from google drive. "
            "Please install gdown and then rerun."
        )

    print("Downloading the dataset. Might take a few minutes...")

    if download_type == "zip":
        gdown.download(url, path, quiet=False)
        _check_checksum(path, checksum)
    elif download_type == "folder":
        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install `gdown==4.6.3`."
        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
        gdown.download_folder(url=url, output=path, quiet=True, remaining_ok=True)
    else:
        raise ValueError("The `download_type` argument expects either `zip` or `folder`.")
    print("Download completed.")


def download_source_empiar(path, access_id, download):
    download_path = os.path.join(path, access_id)

    if os.path.exists(download_path):
        return download_path
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if which("ascp") is None:
        raise RuntimeError(
            "Need aspera-cli to download data from empiar. "
            "You can install it via 'mamba install -c hcc aspera-cli'."
        )

    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
    if not os.path.exists(key_file):
        conda_root = os.environ["CONDA_PREFIX"]
        key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh")

    if not os.path.exists(key_file):
        raise RuntimeError("Could not find the aspera ssh keyfile")

    cmd = [
        "ascp", "-QT", "-l", "200M", "-P33001",
        "-i", key_file, f"emp_ext2@fasp.ebi.ac.uk:/{access_id}", path
    ]
    run(cmd)

    return download_path


def download_source_kaggle(path, dataset_name, download):
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
    except ModuleNotFoundError:
        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
        msg += "After you have installed kaggle, you will need an API token. "
        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
        raise ModuleNotFoundError(msg)

    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)


def update_kwargs(kwargs, key, value, msg=None):
    if key in kwargs:
        msg = f"{key} will be overridden in loader kwargs." if msg is None else msg
        warn(msg)
    kwargs[key] = value
    return kwargs


def unzip(zip_path, dst, remove=True):
    with zipfile.ZipFile(zip_path, "r") as f:
        f.extractall(dst)
    if remove:
        os.remove(zip_path)


def split_kwargs(function, **kwargs):
    function_parameters = inspect.signature(function).parameters
    parameter_names = list(function_parameters.keys())
    other_kwargs = {k: v for k, v in kwargs.items() if k not in parameter_names}
    kwargs = {k: v for k, v in kwargs.items() if k in parameter_names}
    return kwargs, other_kwargs


# this adds the default transforms for 'raw_transform' and 'transform'
# in case these were not specified in the kwargs
# this is NOT necessary if 'default_segmentation_dataset' is used, only if a dataset class
# is used directly, e.g. in the LiveCell Loader
def ensure_transforms(ndim, **kwargs):
    if "raw_transform" not in kwargs:
        kwargs = update_kwargs(kwargs, "raw_transform", torch_em.transform.get_raw_transform())
    if "transform" not in kwargs:
        kwargs = update_kwargs(kwargs, "transform", torch_em.transform.get_augmentations(ndim=ndim))
    return kwargs


def add_instance_label_transform(
    kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True,
):
    if binary_is_exclusive:
        assert sum((offsets is not None, boundaries, binary)) <= 1
    else:
        assert sum((offsets is not None, boundaries)) <= 1
    if offsets is not None:
        label_transform2 = torch_em.transform.label.AffinityTransform(offsets=offsets,
                                                                      add_binary_target=add_binary_target,
                                                                      add_mask=True)
        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be overridden."
        kwargs = update_kwargs(kwargs, "label_transform2", label_transform2, msg=msg)
        label_dtype = torch.float32
    elif boundaries:
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=add_binary_target)
        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be overridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    elif binary:
        label_transform = torch_em.transform.label.labels_to_binary
        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be overridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    return kwargs, label_dtype


def generate_labeled_array_from_xml(shape, xml_file):
    """Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

    Given an image shape and the path to the annotations (xml file), generate a mask with the region
    inside each contour filled in.
        shape: The image shape on which the mask will be made.
        xml_file: Path (relative to the current working directory) to the xml file with the annotations.

    Returns:
        An array of the given shape with the regions inside the contours filled in.
    """
    # DOM object created by the minidom parser
    xDoc = minidom.parse(xml_file)

    # List of all Region tags
    regions = xDoc.getElementsByTagName('Region')

    # List which will store the vertices for each region
    xy = []
    for region in regions:
        # Loading all the vertices in the region
        vertices = region.getElementsByTagName('Vertex')

        # The vertices of a region will be stored in an array
        vw = np.zeros((len(vertices), 2))

        for index, vertex in enumerate(vertices):
            # Storing the values of the x and y coordinates after conversion
            vw[index][0] = float(vertex.getAttribute('X'))
            vw[index][1] = float(vertex.getAttribute('Y'))

        # Append the vertices of a region
        xy.append(np.int32(vw))

    # Creating a completely black image
    mask = np.zeros(shape, np.float32)

    for i, contour in enumerate(xy):
        r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
        mask[r, c] = i
    return mask


def convert_svs_to_array(path, location=(0, 0), level=0, img_size=None):
    """Converts .svs files to numpy array format.

    Arguments:
        - path: [str] - Path to the svs file.
        (The arguments below are used for multi-resolution images.)
        - location: tuple[int, int] - Pixel location (x, y) in level 0 of the image (default: (0, 0)).
        - level: [int] - Target level used to read the image (default: 0).
        - img_size: tuple[int, int] - Expected size of the image
                                      (default: None -> obtains the original shape at the expected level).

    Returns:
        The image as numpy array.

    TODO: this can be extended to convert WSIs (or modalities with multiple resolutions).
    """
    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"

    from tiffslide import TiffSlide

    _slide = TiffSlide(path)

    if img_size is None:
        img_size = _slide.level_dimensions[0]

    img_arr = _slide.read_region(location=location, level=level, size=img_size, as_array=True)

    return img_arr
def get_bioimageio_dataset_id(dataset_name):
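A minimal usage sketch: the function simply looks the dataset name up in BIOIMAGEIO_IDS and asserts that the name is known.

    from torch_em.data.datasets.util import get_bioimageio_dataset_id

    assert get_bioimageio_dataset_id("covid_if") == "ilastik/covid_if_training_data"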
def get_checksum(filename):
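A sketch of the typical workflow: compute the sha256 digest of a file once and then pass it as the `checksum` argument of `download_source` to validate future downloads. The file path is a placeholder.

    from torch_em.data.datasets.util import get_checksum

    checksum = get_checksum("./data/dataset.zip")  # placeholder path to a downloaded file
    print(checksum)  # 64-character hex digest that can be hard-coded in a dataset loader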
def download_source(path, url, download, checksum=None, verify=True):
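A hedged example of downloading an archive and unpacking it with `unzip`; the URL, paths and checksum are placeholders, not a real dataset.

    import os
    from torch_em.data.datasets.util import download_source, unzip

    url = "https://example.com/dataset.zip"  # placeholder URL
    zip_path = "./data/dataset.zip"

    os.makedirs("./data", exist_ok=True)
    download_source(zip_path, url, download=True, checksum=None)  # pass a sha256 digest to verify the download
    unzip(zip_path, "./data", remove=True)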
def download_source_gdrive(path, url, download, checksum=None, download_type="zip", expected_samples=10000):
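A sketch for google drive downloads, assuming gdown is installed; the drive URLs are placeholders. With download_type="zip" a single file is fetched, with download_type="folder" a whole folder is fetched (which requires gdown==4.6.3, as asserted above).

    from torch_em.data.datasets.util import download_source_gdrive

    # download a single file (e.g. a zip archive); the URL is a placeholder
    download_source_gdrive("./data.zip", "https://drive.google.com/uc?id=<FILE_ID>", download=True)

    # download a whole folder; the URL is a placeholder
    download_source_gdrive(
        "./data_folder", "https://drive.google.com/drive/folders/<FOLDER_ID>",
        download=True, download_type="folder", expected_samples=500,
    )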
def download_source_empiar(path, access_id, download):
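A sketch of an EMPIAR download via aspera; it assumes ascp (aspera-cli) is installed, and the accession id is a placeholder. The data ends up in path/access_id, which is also returned.

    from torch_em.data.datasets.util import download_source_empiar

    download_path = download_source_empiar("./data", access_id="10XXX", download=True)  # placeholder id
    print(download_path)  # ./data/10XXX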
def download_source_kaggle(path, dataset_name, download):
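A sketch of a Kaggle download; it assumes the kaggle package is installed and an API token is configured. The dataset name is a placeholder, and the downloaded archive (which the Kaggle API typically names after the dataset) still needs to be unzipped.

    from torch_em.data.datasets.util import download_source_kaggle, unzip

    download_source_kaggle("./data", dataset_name="<user>/<dataset>", download=True)  # placeholder name
    unzip("./data/<dataset>.zip", "./data", remove=True)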
def update_kwargs(kwargs, key, value, msg=None):
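A small sketch: the helper sets a key in a kwargs dict and warns if this overrides a value that was already present.

    from torch_em.data.datasets.util import update_kwargs

    kwargs = {"batch_size": 2}
    kwargs = update_kwargs(kwargs, "shuffle", True)   # just sets the key
    kwargs = update_kwargs(kwargs, "batch_size", 4)   # warns that the key will be overridden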
def unzip(zip_path, dst, remove=True):
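Usage sketch; by default the archive is deleted after extraction. The paths are placeholders.

    from torch_em.data.datasets.util import unzip

    unzip("./data/dataset.zip", "./data", remove=False)  # keep the zip file around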
def split_kwargs(function, **kwargs):
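A sketch of splitting keyword arguments into those accepted by a function and the rest; the function below is hypothetical and only used for its signature.

    from torch_em.data.datasets.util import split_kwargs

    def make_dataset(patch_shape, ndim=2):  # hypothetical function, stands in for a dataset constructor
        pass

    kwargs = {"patch_shape": (512, 512), "ndim": 2, "batch_size": 4}
    dataset_kwargs, loader_kwargs = split_kwargs(make_dataset, **kwargs)
    # dataset_kwargs == {"patch_shape": (512, 512), "ndim": 2}
    # loader_kwargs  == {"batch_size": 4}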
def ensure_transforms(ndim, **kwargs):
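A sketch: when a dataset class is instantiated directly (instead of via default_segmentation_dataset), this fills in the default raw_transform and transform if they were not passed.

    from torch_em.data.datasets.util import ensure_transforms

    kwargs = {"patch_shape": (512, 512)}
    kwargs = ensure_transforms(ndim=2, **kwargs)
    # kwargs now also contains "raw_transform" and "transform" with the torch_em defaults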
def add_instance_label_transform(kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True):
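A sketch of preparing instance labels as a training target: depending on the flags, an affinity, boundary or binary label transform is added to the loader kwargs and the label dtype is switched to float32. The offsets below are just an example.

    import torch
    from torch_em.data.datasets.util import add_instance_label_transform

    kwargs = {"patch_shape": (512, 512)}
    offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]  # example affinity offsets
    kwargs, label_dtype = add_instance_label_transform(kwargs, add_binary_target=True, offsets=offsets)
    assert label_dtype == torch.float32
    # kwargs["label_transform2"] is now an AffinityTransform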
def generate_labeled_array_from_xml(shape, xml_file):

Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

Given an image shape and the path to the annotations (xml file), generate a mask with the region inside each contour filled in.
    shape: The image shape on which the mask will be made.
    xml_file: Path (relative to the current working directory) to the xml file with the annotations.

Returns:

    An array of the given shape with the regions inside the contours filled in.
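A usage sketch, assuming an xml annotation file with the Region/Vertex structure expected above (e.g. MoNuSeg-style annotations); the shape and file path are placeholders.

    from torch_em.data.datasets.util import generate_labeled_array_from_xml

    labels = generate_labeled_array_from_xml(shape=(1000, 1000), xml_file="./annotations/image_01.xml")
    # labels has shape (1000, 1000); pixels inside the i-th annotated contour carry the value i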

def convert_svs_to_array(path, location=(0, 0), level=0, img_size=None):

Converts .svs files to numpy array format.

Arguments:
  • path: [str] - Path to the svs file. (The arguments below are used for multi-resolution images.)
  • location: tuple[int, int] - Pixel location (x, y) in level 0 of the image (default: (0, 0)).
  • level: [int] - Target level used to read the image (default: 0).
  • img_size: tuple[int, int] - Expected size of the image (default: None -> obtains the original shape at the expected level).

Returns:

The image as numpy array.

TODO: this can be extended to convert WSIs (or modalities with multiple resolutions).
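A usage sketch, assuming tiffslide is installed; the file path is a placeholder. Without img_size the full level-0 dimensions are read.

    from torch_em.data.datasets.util import convert_svs_to_array

    # read a 1024 x 1024 crop from the top-left corner of the highest-resolution level
    img = convert_svs_to_array("./slides/sample.svs", location=(0, 0), level=0, img_size=(1024, 1024))
    print(img.shape)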