torch_em.data.datasets.light_microscopy.medussa

The MeDuSSA dataset contains annotations for bacterial membrane instance segmentation in fluorescence microscopy images stained with FM 4-64.

The dataset provides 143 training images and 16 benchmarking images of membrane-stained bacteria (primarily Bacillus subtilis PY79) with corresponding instance segmentation masks annotated using JFilament in FIJI.

The dataset is located at https://www.ebi.ac.uk/biostudies/bioimages/studies/S-BIAD2350. This dataset is from the publication https://doi.org/10.1101/2025.10.26.684635. Please cite it if you use this dataset in your research.

  1"""The MeDuSSA dataset contains annotations for bacterial membrane
  2instance segmentation in fluorescence microscopy images stained with FM 4-64.
  3
  4The dataset provides 143 training images and 16 benchmarking images of
  5membrane-stained bacteria (primarily Bacillus subtilis PY79) with corresponding
  6instance segmentation masks annotated using JFilament in FIJI.
  7
  8The dataset is located at https://www.ebi.ac.uk/biostudies/bioimages/studies/S-BIAD2350.
  9This dataset is from the publication https://doi.org/10.1101/2025.10.26.684635.
 10Please cite it if you use this dataset in your research.
 11"""
 12
 13import os
 14import json
 15from glob import glob
 16from typing import Union, Tuple, List, Literal
 17
 18from torch.utils.data import Dataset, DataLoader
 19
 20import torch_em
 21
 22from .. import util
 23
 24
 25BASE_URL = "https://www.ebi.ac.uk/biostudies/files/S-BIAD2350"
 26
 27SPLIT_FILE_LISTS = {
 28    "train": {
 29        "images": "submission_segmentation_training_images_raw.json",
 30        "masks": "submission_segmentation_training_masks.json",
 31    },
 32    "test": {
 33        "images": "submission_segmentation_benchmarking_images_raw.json",
 34        "masks": "submission_segmentation_benchmarking_masks.json",
 35    },
 36}
 37
 38
 39def _download_file_lists(path, split):
 40    """Download and parse JSON file lists from BioStudies to get relative file paths."""
 41    file_list_dir = os.path.join(path, "file_lists")
 42    os.makedirs(file_list_dir, exist_ok=True)
 43
 44    result = {}
 45    for key in ("images", "masks"):
 46        json_fname = SPLIT_FILE_LISTS[split][key]
 47        json_path = os.path.join(file_list_dir, json_fname)
 48
 49        if not os.path.exists(json_path):
 50            url = f"{BASE_URL}/{json_fname}"
 51            util.download_source(path=json_path, url=url, download=True, checksum=None)
 52
 53        with open(json_path) as f:
 54            data = json.load(f)
 55
 56        result[key] = sorted([entry["path"] for entry in data])
 57
 58    return result["images"], result["masks"]
 59
 60
 61def _create_h5_data(path, split, image_paths_rel, mask_paths_rel):
 62    """Create h5 files with raw images and instance labels."""
 63    import h5py
 64    import imageio.v3 as imageio
 65    from tqdm import tqdm
 66
 67    h5_dir = os.path.join(path, "h5_data", split)
 68    os.makedirs(h5_dir, exist_ok=True)
 69
 70    assert len(image_paths_rel) == len(mask_paths_rel), \
 71        f"Mismatch: {len(image_paths_rel)} images vs {len(mask_paths_rel)} masks for split '{split}'"
 72
 73    for img_rel, mask_rel in tqdm(
 74        zip(image_paths_rel, mask_paths_rel),
 75        total=len(image_paths_rel),
 76        desc=f"Creating h5 files for '{split}'"
 77    ):
 78        fname = os.path.splitext(os.path.basename(img_rel))[0]
 79        h5_path = os.path.join(h5_dir, f"{fname}.h5")
 80
 81        if os.path.exists(h5_path):
 82            continue
 83
 84        raw = imageio.imread(os.path.join(path, img_rel))
 85        labels = imageio.imread(os.path.join(path, mask_rel))
 86
 87        # Handle potential multi-dimensional images (e.g. Z-stacks not fully max-projected).
 88        if raw.ndim > 2:
 89            raw = raw.max(axis=0)
 90
 91        if labels.ndim > 2:
 92            labels = labels.max(axis=0)
 93
 94        with h5py.File(h5_path, "w") as f:
 95            f.create_dataset("raw", data=raw, compression="gzip")
 96            f.create_dataset("labels", data=labels.astype("int64"), compression="gzip")
 97
 98    return h5_dir
 99
100
def get_medussa_data(
    path: Union[os.PathLike, str],
    split: Literal["train", "test"] = "train",
    download: bool = False,
) -> str:
    """Download the MeDuSSA dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split to use. One of 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the directory with the downloaded data.
    """
    assert split in ("train", "test"), f"'{split}' is not a valid split."

    # NOTE: the JSON file lists themselves are fetched unconditionally by the
    # helper; only the image/mask files below honor the 'download' flag.
    image_rel_paths, mask_rel_paths = _download_file_lists(path, split)

    # Fetch every listed file that is not yet present locally.
    missing = (p for p in image_rel_paths + mask_rel_paths if not os.path.exists(os.path.join(path, p)))
    for rel_path in missing:
        target = os.path.join(path, rel_path)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        util.download_source(path=target, url=f"{BASE_URL}/{rel_path}", download=download, checksum=None)

    return path
130
131
def get_medussa_paths(
    path: Union[os.PathLike, str],
    split: Literal["train", "test"] = "train",
    download: bool = False,
) -> List[str]:
    """Get paths to the MeDuSSA data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split to use. One of 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the h5 data.
    """
    from natsort import natsorted

    get_medussa_data(path, split, download)

    h5_dir = os.path.join(path, "h5_data", split)
    pattern = os.path.join(h5_dir, "*.h5")

    # Build the h5 files if the directory is missing or still empty.
    if not (os.path.exists(h5_dir) and len(glob(pattern)) > 0):
        image_rel_paths, mask_rel_paths = _download_file_lists(path, split)
        _create_h5_data(path, split, image_rel_paths, mask_rel_paths)

    h5_paths = natsorted(glob(pattern))
    assert len(h5_paths) > 0, f"No data found for split '{split}'"
    return h5_paths
160
161
def get_medussa_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "test"] = "train",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the MeDuSSA dataset for bacterial membrane segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The data split to use. One of 'train' or 'test'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    data_paths = get_medussa_paths(path, split, download)

    # Add a binary channel to the instance targets and ensure 2d transforms.
    kwargs, _ = util.add_instance_label_transform(kwargs, add_binary_target=True)
    kwargs = util.ensure_transforms(ndim=2, **kwargs)

    dataset_kwargs = dict(
        raw_paths=data_paths,
        raw_key="raw",
        label_paths=data_paths,
        label_key="labels",
        patch_shape=patch_shape,
        ndim=2,
    )
    return torch_em.default_segmentation_dataset(**dataset_kwargs, **kwargs)
197
198
def get_medussa_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "test"] = "train",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the MeDuSSA dataloader for bacterial membrane segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. One of 'train' or 'test'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route kwargs to either the dataset constructor or the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds = get_medussa_dataset(path=path, patch_shape=patch_shape, split=split, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset=ds, batch_size=batch_size, **loader_kwargs)
BASE_URL = 'https://www.ebi.ac.uk/biostudies/files/S-BIAD2350'
SPLIT_FILE_LISTS = {'train': {'images': 'submission_segmentation_training_images_raw.json', 'masks': 'submission_segmentation_training_masks.json'}, 'test': {'images': 'submission_segmentation_benchmarking_images_raw.json', 'masks': 'submission_segmentation_benchmarking_masks.json'}}
def get_medussa_data( path: Union[os.PathLike, str], split: Literal['train', 'test'] = 'train', download: bool = False) -> str:
102def get_medussa_data(
103    path: Union[os.PathLike, str],
104    split: Literal["train", "test"] = "train",
105    download: bool = False,
106) -> str:
107    """Download the MeDuSSA dataset.
108
109    Args:
110        path: Filepath to a folder where the downloaded data will be saved.
111        split: The data split to use. One of 'train' or 'test'.
112        download: Whether to download the data if it is not present.
113
114    Returns:
115        The filepath to the directory with the downloaded data.
116    """
117    assert split in ("train", "test"), f"'{split}' is not a valid split."
118
119    image_paths_rel, mask_paths_rel = _download_file_lists(path, split)
120
121    for rel_path in image_paths_rel + mask_paths_rel:
122        local_path = os.path.join(path, rel_path)
123        if os.path.exists(local_path):
124            continue
125
126        os.makedirs(os.path.dirname(local_path), exist_ok=True)
127        url = f"{BASE_URL}/{rel_path}"
128        util.download_source(path=local_path, url=url, download=download, checksum=None)
129
130    return path

Download the MeDuSSA dataset.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split to use. One of 'train' or 'test'.
  • download: Whether to download the data if it is not present.
Returns:
  The filepath to the directory with the downloaded data.

def get_medussa_paths( path: Union[os.PathLike, str], split: Literal['train', 'test'] = 'train', download: bool = False) -> List[str]:
133def get_medussa_paths(
134    path: Union[os.PathLike, str],
135    split: Literal["train", "test"] = "train",
136    download: bool = False,
137) -> List[str]:
138    """Get paths to the MeDuSSA data.
139
140    Args:
141        path: Filepath to a folder where the downloaded data will be saved.
142        split: The data split to use. One of 'train' or 'test'.
143        download: Whether to download the data if it is not present.
144
145    Returns:
146        List of filepaths for the h5 data.
147    """
148    from natsort import natsorted
149
150    get_medussa_data(path, split, download)
151
152    h5_dir = os.path.join(path, "h5_data", split)
153    if not os.path.exists(h5_dir) or len(glob(os.path.join(h5_dir, "*.h5"))) == 0:
154        image_paths_rel, mask_paths_rel = _download_file_lists(path, split)
155        _create_h5_data(path, split, image_paths_rel, mask_paths_rel)
156
157    h5_paths = natsorted(glob(os.path.join(h5_dir, "*.h5")))
158    assert len(h5_paths) > 0, f"No data found for split '{split}'"
159
160    return h5_paths

Get paths to the MeDuSSA data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split to use. One of 'train' or 'test'.
  • download: Whether to download the data if it is not present.
Returns:
  List of filepaths for the h5 data.

def get_medussa_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'test'] = 'train', download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
163def get_medussa_dataset(
164    path: Union[os.PathLike, str],
165    patch_shape: Tuple[int, int],
166    split: Literal["train", "test"] = "train",
167    download: bool = False,
168    **kwargs
169) -> Dataset:
170    """Get the MeDuSSA dataset for bacterial membrane segmentation.
171
172    Args:
173        path: Filepath to a folder where the downloaded data will be saved.
174        patch_shape: The patch shape to use for training.
175        split: The data split to use. One of 'train' or 'test'.
176        download: Whether to download the data if it is not present.
177        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
178
179    Returns:
180        The segmentation dataset.
181    """
182    h5_paths = get_medussa_paths(path, split, download)
183
184    kwargs, _ = util.add_instance_label_transform(
185        kwargs, add_binary_target=True,
186    )
187    kwargs = util.ensure_transforms(ndim=2, **kwargs)
188
189    return torch_em.default_segmentation_dataset(
190        raw_paths=h5_paths,
191        raw_key="raw",
192        label_paths=h5_paths,
193        label_key="labels",
194        patch_shape=patch_shape,
195        ndim=2,
196        **kwargs
197    )

Get the MeDuSSA dataset for bacterial membrane segmentation.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. One of 'train' or 'test'.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
  The segmentation dataset.

def get_medussa_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'test'] = 'train', download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
200def get_medussa_loader(
201    path: Union[os.PathLike, str],
202    batch_size: int,
203    patch_shape: Tuple[int, int],
204    split: Literal["train", "test"] = "train",
205    download: bool = False,
206    **kwargs
207) -> DataLoader:
208    """Get the MeDuSSA dataloader for bacterial membrane segmentation.
209
210    Args:
211        path: Filepath to a folder where the downloaded data will be saved.
212        batch_size: The batch size for training.
213        patch_shape: The patch shape to use for training.
214        split: The data split to use. One of 'train' or 'test'.
215        download: Whether to download the data if it is not present.
216        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
217
218    Returns:
219        The DataLoader.
220    """
221    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
222    dataset = get_medussa_dataset(
223        path=path,
224        patch_shape=patch_shape,
225        split=split,
226        download=download,
227        **ds_kwargs,
228    )
229    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)

Get the MeDuSSA dataloader for bacterial membrane segmentation.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. One of 'train' or 'test'.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
  The DataLoader.