1import inspect
  2import os
  3import hashlib
  4import zipfile
  5import numpy as np
  6from tqdm import tqdm
  7from warnings import warn
  8from xml.dom import minidom
  9from shutil import copyfileobj, which
 10from subprocess import run
 11from packaging import version
 13from skimage.draw import polygon
 15import torch
 16import torch_em
 17import requests
 20    import gdown
 21except ImportError:
 22    gdown = None
 26    "covid_if": "ilastik/covid_if_training_data",
 27    "cremi": "ilastik/cremi_training_data",
 28    "dsb": "ilastik/stardist_dsb_training_data",
 29    "hpa": "",  # not on bioimageio yet
 30    "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge",
 31    "kasthuri": "",  # not on bioimageio yet:
 32    "livecell": "ilastik/livecell_dataset",
 33    "lucchi": "",  # not on bioimageio yet:
 34    "mitoem": "ilastik/mitoem_segmentation_challenge",
 35    "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018",
 36    "ovules": "",  # not on bioimageio yet
 37    "plantseg_root": "ilastik/plantseg_root",
 38    "plantseg_ovules": "ilastik/plantseg_ovules",
 39    "platynereis": "ilastik/platynereis_em_training_data",
 40    "snemi": "",  # not on bioimagegio yet
 41    "uro_cell": "",  # not on bioimageio yet:
 42    "vnc": "ilastik/vnc",
 46def get_bioimageio_dataset_id(dataset_name):
 47    assert dataset_name in BIOIMAGEIO_IDS
 48    return BIOIMAGEIO_IDS[dataset_name]
 51def get_checksum(filename):
 52    with open(filename, "rb") as f:
 53        file_ =
 54        checksum = hashlib.sha256(file_).hexdigest()
 55    return checksum
 58def _check_checksum(path, checksum):
 59    if checksum is not None:
 60        this_checksum = get_checksum(path)
 61        if this_checksum != checksum:
 62            raise RuntimeError(
 63                "The checksum of the download does not match the expected checksum."
 64                f"Expected: {checksum}, got: {this_checksum}"
 65            )
 66        print("Download successful and checksums agree.")
 67    else:
 68        warn("The file was downloaded, but no checksum was provided, so the file may be corrupted.")
 71# this needs to be extended to support download from s3 via boto,
 72# if we get a resource that is available via s3 without support for http
 73def download_source(path, url, download, checksum=None, verify=True):
 74    if os.path.exists(path):
 75        return
 76    if not download:
 77        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")
 79    with requests.get(url, stream=True, allow_redirects=True, verify=verify) as r:
 80        r.raise_for_status()  # check for error
 81        file_size = int(r.headers.get("Content-Length", 0))
 82        desc = f"Download {url} to {path}"
 83        if file_size == 0:
 84            desc += " (unknown file size)"
 85        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
 86            copyfileobj(r_raw, f)
 88    _check_checksum(path, checksum)
 91def download_source_gdrive(path, url, download, checksum=None, download_type="zip", expected_samples=10000):
 92    if os.path.exists(path):
 93        return
 95    if not download:
 96        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")
 98    if gdown is None:
 99        raise RuntimeError(
100            "Need gdown library to download data from google drive."
101            "Please install gdown and then rerun."
102        )
104    print("Downloading the dataset. Might take a few minutes...")
106    if download_type == "zip":
107, path, quiet=False)
108        _check_checksum(path, checksum)
109    elif download_type == "folder":
110        assert version.parse(gdown.__version__) == version.parse("4.6.3"), "Please install `gdown==4.6.3`."
111        gdown.download_folder.__globals__["MAX_NUMBER_FILES"] = expected_samples
112        gdown.download_folder(url=url, output=path, quiet=True, remaining_ok=True)
113    else:
114        raise ValueError("`download_path` argument expects either `zip`/`folder`")
115    print("Download completed.")
def download_source_empiar(path, access_id, download):
    """Download an EMPIAR entry with the aspera command line client (`ascp`).

    Args:
        path: Folder into which the entry is downloaded.
        access_id: Identifier of the entry. It is used both as the source argument
            of `ascp` and as the name of the sub-folder of `path` holding the data.
        download: Whether downloading is allowed.

    Returns:
        The path `os.path.join(path, access_id)`.

    Raises:
        RuntimeError: If the data is missing and `download` is False, if `ascp` is
            not on the PATH, or if the aspera ssh key file cannot be found.
        subprocess.CalledProcessError: If `ascp` exits with a non-zero status.
    """
    download_path = os.path.join(path, access_id)

    if os.path.exists(download_path):
        return download_path
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False")

    if which("ascp") is None:
        raise RuntimeError(
            "Need aspera-cli to download data from empiar. "
            "You can install it via 'mamba install -c hcc aspera-cli'."
        )

    key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh")
    if not os.path.exists(key_file):
        # Fall back to the key shipped with a conda installation. Outside of a conda
        # environment CONDA_PREFIX is unset; report the missing key file below instead
        # of failing with a bare KeyError.
        conda_root = os.environ.get("CONDA_PREFIX")
        if conda_root is not None:
            key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh")

    if not os.path.exists(key_file):
        raise RuntimeError("Could not find the aspera ssh keyfile")

    # NOTE(review): `access_id` is passed to ascp as the transfer source without a
    # remote prefix -- confirm that callers pass the full source specification.
    cmd = [
        "ascp", "-QT", "-l", "200M", "-P33001",
        "-i", key_file, f"{access_id}", path
    ]
    # check=True: a failed transfer must not be reported as a finished download.
    run(cmd, check=True)

    return download_path
def download_source_kaggle(path, dataset_name, download):
    """Download a dataset from Kaggle with the Kaggle API.

    Note that this function does not check whether the data already exists at `path`.

    Args:
        path: Folder to which the dataset files are downloaded.
        dataset_name: Name of the dataset, passed as `dataset` to the Kaggle API.
        download: Whether downloading is allowed.

    Raises:
        RuntimeError: If `download` is False.
        ModuleNotFoundError: If the `kaggle` package is not installed.
    """
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
    except ModuleNotFoundError as err:
        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
        msg += "After you have installed kaggle, you would need an API token. "
        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
        raise ModuleNotFoundError(msg) from err

    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)
def update_kwargs(kwargs, key, value, msg=None):
    """Set `kwargs[key] = value`, warning if this replaces an existing entry.

    Args:
        kwargs: The dictionary to update. It is modified in place.
        key: The key to set.
        value: The new value for `key`.
        msg: Warning text used when `key` is already present.
            A default text is used if it is None.

    Returns:
        The same `kwargs` dictionary that was passed in.
    """
    is_override = key in kwargs
    if is_override:
        if msg is None:
            msg = f"{key} will be over-ridden in loader kwargs."
        warn(msg)

    kwargs[key] = value
    return kwargs
def unzip(zip_path, dst, remove=True):
    """Extract a zip archive and optionally delete it afterwards.

    Args:
        zip_path: Path of the zip archive.
        dst: Folder into which the content of the archive is extracted.
        remove: Whether to delete the archive after extraction.
    """
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(dst)

    if not remove:
        return
    os.remove(zip_path)
def split_kwargs(function, **kwargs):
    """Split keyword arguments by whether they name a parameter of `function`.

    Args:
        function: The callable whose signature is inspected.
        kwargs: The keyword arguments to split.

    Returns:
        Two dictionaries: the entries of `kwargs` whose key is a parameter name of
        `function`, and all remaining entries.
    """
    accepted_names = inspect.signature(function).parameters

    matching, remaining = {}, {}
    for name, value in kwargs.items():
        target = matching if name in accepted_names else remaining
        target[name] = value
    return matching, remaining
189# this adds the default transforms for 'raw_transform' and 'transform'
190# in case these were not specified in the kwargs
191# this is NOT necessary if 'default_segmentation_dataset' is used, only if a dataset class
192# is used directly, e.g. in the LiveCell Loader
def ensure_transforms(ndim, **kwargs):
    """Fill in default values for 'raw_transform' and 'transform' where they are missing.

    Args:
        ndim: Passed as `ndim` to `torch_em.transform.get_augmentations` when the
            default for 'transform' is created.
        kwargs: Keyword arguments that may already contain the transforms.

    Returns:
        The keyword arguments, with 'raw_transform' and 'transform' present.
    """
    # The factories are only called for keys that are actually missing.
    default_factories = (
        ("raw_transform", lambda: torch_em.transform.get_raw_transform()),
        ("transform", lambda: torch_em.transform.get_augmentations(ndim=ndim)),
    )
    for name, make_default in default_factories:
        if name not in kwargs:
            kwargs = update_kwargs(kwargs, name, make_default())
    return kwargs
def add_instance_label_transform(
    kwargs, add_binary_target, label_dtype=None, binary=False, boundaries=False, offsets=None, binary_is_exclusive=True,
):
    """Add the label transform selected by `offsets`, `boundaries` or `binary` to the kwargs.

    The first option that is set wins, in the order offsets, boundaries, binary:
    - offsets: an `AffinityTransform` is stored under 'label_transform2'.
    - boundaries: a `BoundaryTransform` is stored under 'label_transform'.
    - binary: `labels_to_binary` is stored under 'label_transform'.
    An existing entry under the same key is replaced and a warning is issued.
    If none of the options is set, `kwargs` and `label_dtype` are returned unchanged.

    Args:
        kwargs: The keyword arguments to update. They are modified in place.
        add_binary_target: Passed on to the affinity or boundary transform.
        label_dtype: The label data type returned if no transform is selected.
        binary: Whether to use the binary label transform.
        boundaries: Whether to use the boundary label transform.
        offsets: The offsets for the affinity transform, or None.
        binary_is_exclusive: Whether `binary` may not be combined with `offsets`
            or `boundaries`. If False, only `offsets` and `boundaries` exclude each other.

    Returns:
        The updated kwargs and the label data type, which is `torch.float32`
        whenever one of the transforms was added.

    Raises:
        AssertionError: If mutually exclusive options are combined.
    """
    if binary_is_exclusive:
        assert sum((offsets is not None, boundaries, binary)) <= 1
    else:
        assert sum((offsets is not None, boundaries)) <= 1

    if offsets is not None:
        label_transform2 = torch_em.transform.label.AffinityTransform(
            offsets=offsets, add_binary_target=add_binary_target, add_mask=True
        )
        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform2", label_transform2, msg=msg)
        label_dtype = torch.float32
    elif boundaries:
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=add_binary_target)
        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    elif binary:
        label_transform = torch_em.transform.label.labels_to_binary
        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
        kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
        label_dtype = torch.float32
    return kwargs, label_dtype
def generate_labeled_array_from_xml(shape, xml_file):
    """Create a label image from the region annotations stored in an xml file.

    Every 'Region' element of the file is rasterized as a filled polygon spanned by
    the 'X' / 'Y' attributes of its 'Vertex' elements. The regions are labeled with
    consecutive ids starting at 1 in document order, so that no region coincides with
    the background value 0. Where regions overlap, the later region wins.

    NOTE(review): the original docstring started with "Function taken from:" but the
    link is missing -- restore the attribution if it is known.

    Args:
        shape: The shape of the image for which the label image is created.
        xml_file: Path to the xml file with the annotations.

    Returns:
        A float32 array of the given shape that holds the id of the region
        covering each pixel, and 0 for pixels outside of all regions.
    """
    # DOM object created by the minidom parser
    xml_doc = minidom.parse(xml_file)

    # List of all Region tags
    regions = xml_doc.getElementsByTagName('Region')

    # One (n_vertices, 2) array of integer (x, y) coordinates per region
    contours = []
    for region in regions:
        vertices = region.getElementsByTagName('Vertex')

        coordinates = np.zeros((len(vertices), 2))
        for index, vertex in enumerate(vertices):
            coordinates[index][0] = float(vertex.getAttribute('X'))
            coordinates[index][1] = float(vertex.getAttribute('Y'))

        contours.append(np.int32(coordinates))

    # Start from an empty (all background) image
    mask = np.zeros(shape, np.float32)

    # Ids start at 1: with 0 the first region would be painted in the background value.
    for region_id, contour in enumerate(contours, start=1):
        # polygon expects row (y) coordinates first, then column (x) coordinates
        r, c = polygon(contour[:, 1], contour[:, 0], shape=shape)
        mask[r, c] = region_id
    return mask
def convert_svs_to_array(path, location=(0, 0), level=0, img_size=None):
    """Read a region of a .svs file into a numpy array.

    Args:
        path: [str] - Path to the svs file.
        (the arguments below are used for multi-resolution images)
        location: tuple[int, int] - Pixel location (x, y) in level 0 of the image (default: (0, 0)).
        level: [int] - Target level used to read the image (default: 0).
        img_size: tuple[int, int] - Expected size of the image
            (default: None -> uses the dimensions of the image at the requested level).

    Returns:
        The image as numpy array.

    Raises:
        AssertionError: If `path` does not end with ".svs".

    TODO: it can be extended to convert WSIs (or modalities with multiple resolutions)
    """
    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"

    from tiffslide import TiffSlide

    _slide = TiffSlide(path)

    if img_size is None:
        # The size handed to read_region is measured at the requested level,
        # so the default has to come from that level rather than always level 0.
        img_size = _slide.level_dimensions[level]

    img_arr = _slide.read_region(location=location, level=level, size=img_size, as_array=True)

    return img_arr
