torch_em.cli

import argparse
import json
import multiprocessing
import uuid

import imageio.v3 as imageio
import torch
import torch_em
from elf.io import open_file
from torch.utils.data import random_split
from torch_em.model.unet import AnisotropicUNet, UNet2d, UNet3d
from torch_em.util.prediction import predict_with_halo, predict_with_padding


#
# CLI for training
#


def _get_training_parser(description):
    parser = argparse.ArgumentParser(description=description)

    # paths and keys for the data
    # training inputs and labels are required
    parser.add_argument("-i", "--training_inputs",
                        help="The input file path(s). Supports common image formats (tif, png, etc.) "
                        "as well as container formats like hdf5 and zarr. For the latter, 'training_input_key' "
                        "also has to be provided. In case you have a folder with many images you should provide the "
                        "path to the folder instead of individual image paths; for this you then need to provide the "
                        "file pattern (e.g. '*.tif') to 'training_input_key'.",
                        required=True, type=str, nargs="+")
    parser.add_argument("-l", "--training_labels",
                        help="The label file path(s). See 'training_inputs' for details on the supported formats etc.",
                        required=True, type=str, nargs="+")
    parser.add_argument("-k", "--training_input_key",
                        help="The key (internal path) for the input data. Required for data formats like hdf5 or zarr.")
    parser.add_argument("--training_label_key", help="The key for the labels. See also 'training_input_key'.")

    # val inputs and labels are optional; if not given we split off parts of the training data
    parser.add_argument("--validation_inputs", type=str, nargs="+",
                        help="The input file path(s) for validation data. If this is not given, "
                        "a fraction of the training inputs will be used for validation.")
    parser.add_argument("--validation_labels", type=str, nargs="+",
                        help="The label file path(s) for validation. Must be given if 'validation_inputs' are given.")
    parser.add_argument("--validation_input_key", help="The key for the validation inputs.")
    parser.add_argument("--validation_label_key", help="The key for the validation labels.")

    # other options
    parser.add_argument("-b", "--batch_size", type=int, required=True, help="The batch size.")
    parser.add_argument("-p", "--patch_shape", type=int, nargs="+", required=True,
                        help="The training patch shape.")
    parser.add_argument("-n", "--n_iterations", type=int, default=25000,
                        help="The number of iterations to train for.")
    parser.add_argument("-m", "--label_mode",
                        help="The label mode determines the transformation applied to the "
                        "labels in order to obtain the targets for training. "
                        "This can be used to obtain suitable representations for training given "
                        "instance segmentation ground-truth. Currently supported: "
                        "'affinities', 'affinities_and_foreground', "
                        "'boundaries', 'boundaries_and_foreground', 'foreground'.")
    parser.add_argument("--name", help="The name of the trained model (checkpoint).")
    parser.add_argument("--train_fraction", type=float, default=0.8,
                        help="The fraction of the data that will be used for training. "
                        "The rest of the data will be used for validation. "
                        "This is only used if validation data is not provided, "
                        "otherwise all data will be used for training.")

    return parser


# TODO provide an option to override the offsets, e.g. via filepath to a json?
def _get_offsets(ndim, scale_factors):
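    # each offset is a pixel shift along the image axes ((y, x) in 2d, (z, y, x) in 3d);
    # every offset yields one affinity channel, covering short- and long-range neighborhoods.
    # for anisotropic 3d data (scale_factors given) smaller offsets are used along z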
    if ndim == 2:
        offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9], [-27, 0], [0, -27]]
    elif ndim == 3 and scale_factors is None:
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-3, 0, 0], [0, -3, 0], [0, 0, -3],
            [-9, 0, 0], [0, -9, 0], [0, 0, -9],
            [-27, 0, 0], [0, -27, 0], [0, 0, -27],
        ]
    else:
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-2, 0, 0], [0, -3, 0], [0, 0, -3],
            [-3, 0, 0], [0, -9, 0], [0, 0, -9],
            [-4, 0, 0], [0, -27, 0], [0, 0, -27],
        ]
    return offsets


# TODO this should be extended to all relevant things, generalized to more datasets and refactored to torch_em.util
def _random_split(ds, fractions):

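    # random_split returns torch.utils.data.Subset objects, which do not expose the
    # attributes of the wrapped dataset; copy over the ones needed by the data loader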
    def _get_attribute(dataset, attr_name):
        while isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            dataset = dataset.datasets[0]
        attr = getattr(dataset, attr_name)
        return attr

    ds_train, ds_val = random_split(ds, fractions)

    raw_transform = _get_attribute(ds, "raw_transform")
    ds_train.raw_transform = raw_transform
    ds_val.raw_transform = raw_transform

    ndim = _get_attribute(ds, "ndim")
    ds_train.ndim = ndim
    ds_val.ndim = ndim

    return ds_train, ds_val


def _get_loader(input_paths, input_key, label_paths, label_key, args, ndim, perform_split=False):
    label_transform, label_transform2 = None, None

    # figure out the label transformations
    label_modes = (
        "affinities", "affinities_and_foreground",
        "boundaries", "boundaries_and_foreground",
        "foreground",
    )
    if args.label_mode is None:
        pass
    elif args.label_mode == "affinities":
        # 'scale_factors' is only defined by the 3d training parser
        offsets = _get_offsets(ndim, getattr(args, "scale_factors", None))
        label_transform = torch_em.transform.label.AffinityTransform(
            offsets=offsets, add_binary_target=False, add_mask=True,
        )
    elif args.label_mode == "affinities_and_foreground":
        label_transform = torch_em.transform.label.AffinityTransform(
            offsets=_get_offsets(ndim, getattr(args, "scale_factors", None)), add_binary_target=True, add_mask=True,
        )
    elif args.label_mode == "boundaries":
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False)
    elif args.label_mode == "boundaries_and_foreground":
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
    elif args.label_mode == "foreground":
        label_transform = torch_em.transform.label.labels_to_binary
    else:
        raise ValueError(f"Unknown label mode {args.label_mode}, expect one of {label_modes}")

    # validate the patch shape
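    # 2d patch shapes may be given either as 2d, e.g. '256 256',
    # or as 3d with a singleton leading axis, e.g. '1 256 256'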
    patch_shape = args.patch_shape
    if ndim == 2:
        if len(patch_shape) != 2 and patch_shape[0] != 1:
            raise ValueError(f"Invalid patch_shape {patch_shape} for 2d data.")
    elif ndim == 3:
        if len(patch_shape) != 3:
            raise ValueError(f"Invalid patch_shape {patch_shape} for 3d data.")
    else:
        raise RuntimeError(f"Invalid ndim: {ndim}")

    # TODO figure out if the data has channels
    ds = torch_em.default_segmentation_dataset(
        input_paths, input_key, label_paths, label_key,
        patch_shape=patch_shape, ndim=ndim,
        label_transform=label_transform,
        label_transform2=label_transform2,
    )

    n_cpus = multiprocessing.cpu_count()
    if perform_split:
        fractions = [args.train_fraction, 1.0 - args.train_fraction]
        ds_train, ds_val = _random_split(ds, fractions)
        train_loader = torch_em.segmentation.get_data_loader(
            ds_train, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        val_loader = torch_em.segmentation.get_data_loader(
            ds_val, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        return train_loader, val_loader
    else:
        loader = torch_em.segmentation.get_data_loader(
            ds, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
    return loader


def _get_loaders(args, ndim):
    # if validation data is not passed we split a validation set off the training data
    if args.validation_inputs is None:
        print("You haven't provided validation data, so the validation set will be split off the training data.")
        print(f"A fraction of {args.train_fraction} will be used for training and {1 - args.train_fraction} for val.")
        train_loader, val_loader = _get_loader(
            args.training_inputs, args.training_input_key, args.training_labels, args.training_label_key,
            args=args, ndim=ndim, perform_split=True,
        )
    else:
        train_loader = _get_loader(
            args.training_inputs, args.training_input_key, args.training_labels, args.training_label_key,
            args=args, ndim=ndim,
        )
        val_loader = _get_loader(
            args.validation_inputs, args.validation_input_key, args.validation_labels, args.validation_label_key,
            args=args, ndim=ndim,
        )
    return train_loader, val_loader


def _determine_channels(train_loader, args):
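    # infer the channel counts from a single example batch;
    # inputs and targets are batched as (batch, channels, ...), so the channel axis is axis 1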
    x, y = next(iter(train_loader))
    in_channels = x.shape[1]
    out_channels = y.shape[1]
    return in_channels, out_channels


def train_2d_unet():
    parser = _get_training_parser("Train a 2D UNet.")
    args = parser.parse_args()

    train_loader, val_loader = _get_loaders(args, ndim=2)
    # TODO more unet settings
    # create the 2d unet
    in_channels, out_channels = _determine_channels(train_loader, args)
    model = UNet2d(in_channels, out_channels)

    # note that label_mode is optional and may be None
    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # generate a random id for the training
    name = f"2d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 2d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)
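
# A hypothetical example invocation, assuming this function is exposed as the
# 'torch_em.train_2d_unet' console script and that matching image / label tifs exist:
#   torch_em.train_2d_unet -i images/ -k "*.tif" -l labels/ --training_label_key "*.tif" \
#       -b 4 -p 256 256 -m boundaries_and_foreground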


def train_3d_unet():
    parser = _get_training_parser("Train a 3D UNet.")
    parser.add_argument("-s", "--scale_factors", type=str,
                        help="The scale factors for the downsampling layers of the 3D U-Net. "
                        "Can be used to set anisotropic scaling of the U-Net. "
                        "Needs to be json encoded, e.g. '[[1,2,2],[2,2,2],[2,2,2]]' to set "
                        "anisotropic scaling in the first layer and isotropic scaling in the other two. "
                        "If not passed an isotropic 3D U-Net will be used.")
    args = parser.parse_args()

    scale_factors = None if args.scale_factors is None else json.loads(args.scale_factors)
    train_loader, val_loader = _get_loaders(args, ndim=3)

    # TODO more unet settings
    # create the 3d unet
    in_channels, out_channels = _determine_channels(train_loader, args)
    if scale_factors is None:
        model = UNet3d(in_channels, out_channels)
    else:
        model = AnisotropicUNet(in_channels, out_channels, scale_factors)

    # note that label_mode is optional and may be None
    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # generate a random id for the training
    name = f"3d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 3d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)
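
# A hypothetical example invocation, assuming the 'torch_em.train_3d_unet' console
# script and hdf5 training data with datasets 'raw' and 'labels':
#   torch_em.train_3d_unet -i train.h5 -k raw -l train.h5 --training_label_key labels \
#       -b 1 -p 32 256 256 -m affinities -s '[[1,2,2],[2,2,2],[2,2,2]]'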


#
# CLI for prediction
#


def _get_prediction_parser(description):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-c", "--checkpoint", required=True, help="The model checkpoint to use for prediction.")
    parser.add_argument("-i", "--input_path", required=True,
                        help="The input path. Supports common image formats (tif, png, etc.) "
                        "as well as container formats like hdf5 and zarr. For the latter, 'input_key' is also required.")
    parser.add_argument("-k", "--input_key", help="The key (path in file) of the input data. "
                        "Required if the input data is a container file format (e.g. hdf5).")
    parser.add_argument("-o", "--output_path", required=True, help="The path where to save the prediction.")
    parser.add_argument("--output_key", help="The key for saving the output. Required for container file formats.")
    parser.add_argument("-p", "--preprocess", default="standardize",
                        help="The preprocessing applied to the input; must be a function name from torch_em.transform.raw.")
    parser.add_argument("--chunks", nargs="+", type=int,
                        help="The chunks for the serialized prediction. Only relevant for container file formats.")
    parser.add_argument("--compression", help="The compression to use when saving the prediction.")
    return parser


def _prediction(args, predict, device):
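    # run the given prediction function with the model loaded from the checkpoint and
    # save the result, either as an image file or as a dataset in a container file (e.g. hdf5, zarr)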
    model = torch_em.util.get_trainer(args.checkpoint, device=device).model

    if args.input_key is None:
        input_ = imageio.imread(args.input_path)
        pred = predict(model, input_)
    else:
        with open_file(args.input_path, "r") as f:
            input_ = f[args.input_key]
            pred = predict(model, input_)

    output_key = args.output_key
    if output_key is None:
        imageio.imwrite(args.output_path, pred)
    else:
        kwargs = {}
        if args.chunks is not None:
            assert len(args.chunks) == pred.ndim
            kwargs["chunks"] = args.chunks
        if args.compression is not None:
            kwargs["compression"] = args.compression
        with open_file(args.output_path, "a") as f:
            ds = f.require_dataset(
                output_key, shape=pred.shape, dtype=pred.dtype, **kwargs
            )
            ds.n_threads = multiprocessing.cpu_count()
            ds[:] = pred


def predict():
    parser = _get_prediction_parser("Run prediction (with padding if necessary).")
    parser.add_argument("--min_divisible", nargs="+", type=int,
                        help="The minimal divisible factors for the input shape of the model. "
                        "If given, the input will be padded to be divisible by these factors.")
    parser.add_argument("-d", "--device",
                        help="The device (e.g. 'cuda' or 'cpu') to use for prediction. "
                        "By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # TODO enable prediction with channels
    def predict(model, input_):
        if args.min_divisible is None:
            with torch.no_grad():
                input_ = preprocess(input_)
                input_ = torch.from_numpy(input_[:][None, None]).to(device)
                pred = model(input_)
            pred = pred.cpu().numpy().squeeze()
        else:
            input_ = preprocess(input_[:])
            pred = predict_with_padding(input_, model, args.min_divisible, device)
        return pred

    _prediction(args, predict, device)
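
# A hypothetical example invocation, assuming the 'torch_em.predict' console script
# and a 2d model trained via 'torch_em.train_2d_unet':
#   torch_em.predict -c checkpoints/2d-unet-training -i image.tif -o prediction.tif \
#       --min_divisible 16 16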


def _pred_2d(model, input_):
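    # the tiled blocks are passed as (batch, channel, z, y, x) with a singleton z axis;
    # drop the z axis for the 2d model and restore it so the output can be stitched into the volume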
    assert input_.shape[2] == 1
    pred = model(input_[:, :, 0])
    return pred[:, :, None]


def predict_with_tiling():
    parser = _get_prediction_parser("Run prediction over tiled input.")
    parser.add_argument("-b", "--block_shape", nargs="+", required=True, type=int,
                        help="The shape of the blocks that will be used to tile the input. "
                        "The model will be applied to each block individually and the results will be stitched.")
    parser.add_argument("--halo", nargs="+", type=int,
                        help="The overlap of the tiles / blocks used during prediction. By default no overlap is used.")
    parser.add_argument("-d", "--devices", nargs="+",
                        help="The devices used for prediction. Can either be the cpu, a gpu, or multiple gpus. "
                        "By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    block_shape = args.block_shape
    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.halo is None:
        halo = [0] * len(block_shape)
    else:
        halo = args.halo
    assert len(halo) == len(block_shape)

    if args.devices is None:
        devices = ["cuda"] if torch.cuda.is_available() else ["cpu"]
    else:
        devices = args.devices

    # if the block shape is singleton in the first axis we assume that this is a 2d model
    if block_shape[0] == 1:
        pred_function = _pred_2d
    else:
        pred_function = None

    # TODO enable prediction with channels
    def predict(model, input_):
        pred = predict_with_halo(
            input_, model, gpu_ids=devices, block_shape=block_shape, halo=halo,
            prediction_function=pred_function, preprocess=preprocess
        )
        return pred

    _prediction(args, predict, devices[0])
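
# A hypothetical example invocation, assuming the 'torch_em.predict_with_tiling' console
# script and a 3d model trained via 'torch_em.train_3d_unet':
#   torch_em.predict_with_tiling -c checkpoints/3d-unet-training -i volume.h5 -k raw \
#       -o prediction.h5 --output_key prediction -b 32 256 256 --halo 8 32 32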