torch_em.cli

@private

  1"""@private
  2"""
  3import argparse
  4import json
  5import multiprocessing
  6import uuid
  7
  8import imageio.v3 as imageio
  9import torch
 10import torch_em
 11from elf.io import open_file
 12from torch.utils.data import random_split
 13from torch_em.model.unet import AnisotropicUNet, UNet2d, UNet3d
 14from torch_em.util.prediction import predict_with_halo, predict_with_padding
 15
 16
 17#
 18# CLI for training
 19#
 20
 21
def _get_training_parser(description):
    parser = argparse.ArgumentParser(description=description)

    # paths and keys for the data
    # training inputs and labels are required
    parser.add_argument("-i", "--training_inputs",
                        help="The input file path(s). Supports common image formats (tif, png, etc) "
                        "as well as container formats like hdf5 and zarr. For the latter, 'training_input_key' "
                        "also has to be provided. In case you have a folder with many images you should provide the "
                        "path to the folder instead of individual image paths; for this you then need to provide the "
                        "file pattern (e.g. '*.tif') to 'training_input_key'.",
                        required=True, type=str, nargs="+")
    parser.add_argument("-l", "--training_labels",
                        help="The label file path(s). See 'training_inputs' for details on the supported formats etc.",
                        required=True, type=str, nargs="+")
    parser.add_argument("-k", "--training_input_key",
                        help="The key (internal path) for the input data. Required for data formats like hdf5 or zarr.")
    parser.add_argument("--training_label_key", help="The key for the labels. See also 'training_input_key'.")

    # val inputs and labels are optional; if not given we split off parts of the training data
    parser.add_argument("--validation_inputs", type=str, nargs="+",
                        help="The input file path(s) for validation data. If this is not given, "
                        "a fraction of the training inputs will be used for validation.")
    parser.add_argument("--validation_labels", type=str, nargs="+",
                        help="The label file path(s) for validation. Must be given if 'validation_inputs' are given.")
    parser.add_argument("--validation_input_key", help="The key for the validation inputs.")
    parser.add_argument("--validation_label_key", help="The key for the validation labels.")

    # other options
    parser.add_argument("-b", "--batch_size", type=int, required=True, help="The batch size.")
    parser.add_argument("-p", "--patch_shape", type=int, nargs="+", required=True,
                        help="The training patch shape.")
    parser.add_argument("-n", "--n_iterations", type=int, default=25000,
                        help="The number of iterations to train for.")
    parser.add_argument("-m", "--label_mode",
                        help="The label mode determines the transformation applied to the "
                        "labels in order to obtain the targets for training. "
                        "This can be used to obtain suitable representations for training given "
                        "instance segmentation ground-truth. Currently supported: "
                        "'affinities', 'affinities_and_foreground', "
                        "'boundaries', 'boundaries_and_foreground', 'foreground'.")
    parser.add_argument("--name", help="The name of the trained model (checkpoint).")
    parser.add_argument("--train_fraction", type=float, default=0.8,
                        help="The fraction of the data that will be used for training. "
                        "The rest of the data will be used for validation. "
                        "This is only used if validation data is not provided, "
                        "otherwise all data will be used for training.")

    return parser

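# For reference, the path / key combinations accepted above map to the supported
# data formats like this (hypothetical paths; this mirrors the help texts):
#   -i image.tif                  -> single image file, no key needed
#   -i data.h5 -k raw             -> container format, key selects the internal dataset
#   -i image_folder/ -k "*.tif"   -> folder of images, key is the file pattern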

# TODO provide an option to over-ride the offsets, e.g. via filepath to a json?
def _get_offsets(ndim, scale_factors):
    # default affinity offsets: direct neighbors plus longer-range offsets
    # at distances of 3, 9 and 27 pixels along each axis
    if ndim == 2:
        offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9], [-27, 0], [0, -27]]
    elif ndim == 3 and scale_factors is None:
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-3, 0, 0], [0, -3, 0], [0, 0, -3],
            [-9, 0, 0], [0, -9, 0], [0, 0, -9],
            [-27, 0, 0], [0, -27, 0], [0, 0, -27],
        ]
    else:
        # for anisotropic 3d data use shorter offsets along the first (z) axis
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-2, 0, 0], [0, -3, 0], [0, 0, -3],
            [-3, 0, 0], [0, -9, 0], [0, 0, -9],
            [-4, 0, 0], [0, -27, 0], [0, 0, -27],
        ]
    return offsets


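# For reference, these offsets parametrize the AffinityTransform used in _get_loader
# below: each offset adds one affinity channel that compares every pixel to its
# neighbor at that offset. A sketch, assuming add_mask=True concatenates the mask
# channels to the affinity channels:
#   trafo = torch_em.transform.label.AffinityTransform(offsets=_get_offsets(2, None), add_mask=True)
#   target = trafo(labels)  # 8 offsets -> 8 affinity channels + 8 mask channels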
# TODO this should be extended to all relevant attributes, generalized to more datasets and refactored to torch_em.util
def _random_split(ds, fractions):

    def _get_attribute(dataset, attr_name):
        # for concatenated datasets, read the attribute off the first wrapped dataset
        while isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            dataset = dataset.datasets[0]
        attr = getattr(dataset, attr_name)
        return attr

    ds_train, ds_val = random_split(ds, fractions)

    # random_split returns Subset objects, which do not forward the dataset
    # attributes, so re-attach the ones that are needed downstream
    raw_transform = _get_attribute(ds, "raw_transform")
    ds_train.raw_transform = raw_transform
    ds_val.raw_transform = raw_transform

    ndim = _get_attribute(ds, "ndim")
    ds_train.ndim = ndim
    ds_val.ndim = ndim

    return ds_train, ds_val


def _get_loader(input_paths, input_key, label_paths, label_key, args, ndim, perform_split=False):
    label_transform, label_transform2 = None, None

    # figure out the label transformations
    label_modes = (
        "affinities", "affinities_and_foreground",
        "boundaries", "boundaries_and_foreground",
        "foreground",
    )
    # 'scale_factors' is only defined by the 3d training CLI, hence the getattr
    scale_factors = getattr(args, "scale_factors", None)
    if args.label_mode is None:
        pass
    elif args.label_mode == "affinities":
        offsets = _get_offsets(ndim, scale_factors)
        label_transform = torch_em.transform.label.AffinityTransform(
            offsets=offsets, add_binary_target=False, add_mask=True,
        )
    elif args.label_mode == "affinities_and_foreground":
        label_transform = torch_em.transform.label.AffinityTransform(
            offsets=_get_offsets(ndim, scale_factors), add_binary_target=True, add_mask=True,
        )
    elif args.label_mode == "boundaries":
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False)
    elif args.label_mode == "boundaries_and_foreground":
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
    elif args.label_mode == "foreground":
        label_transform = torch_em.transform.label.labels_to_binary
    else:
        raise ValueError(f"Unknown label mode {args.label_mode}, expect one of {label_modes}")

    # validate the patch shape
    patch_shape = args.patch_shape
    if ndim == 2:
        # 2d patches are either given as (height, width) or as (1, height, width)
        if len(patch_shape) != 2 and patch_shape[0] != 1:
            raise ValueError(f"Invalid patch_shape {patch_shape} for 2d data.")
    elif ndim == 3:
        if len(patch_shape) != 3:
            raise ValueError(f"Invalid patch_shape {patch_shape} for 3d data.")
    else:
        raise RuntimeError(f"Invalid ndim: {ndim}")

    # TODO figure out if with channels
    ds = torch_em.default_segmentation_dataset(
        input_paths, input_key, label_paths, label_key,
        patch_shape=patch_shape, ndim=ndim,
        label_transform=label_transform,
        label_transform2=label_transform2,
    )

    n_cpus = multiprocessing.cpu_count()
    if perform_split:
        fractions = [args.train_fraction, 1.0 - args.train_fraction]
        ds_train, ds_val = _random_split(ds, fractions)
        train_loader = torch_em.segmentation.get_data_loader(
            ds_train, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        val_loader = torch_em.segmentation.get_data_loader(
            ds_val, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        return train_loader, val_loader
    else:
        loader = torch_em.segmentation.get_data_loader(
            ds, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        return loader


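# For reference, the loader construction above corresponds to this direct torch_em
# usage (a sketch with hypothetical paths and keys):
#   ds = torch_em.default_segmentation_dataset(
#       "data.h5", "raw", "data.h5", "labels", patch_shape=(1, 256, 256), ndim=2
#   )
#   loader = torch_em.segmentation.get_data_loader(ds, batch_size=4, shuffle=True)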
def _get_loaders(args, ndim):
    # if validation data is not passed we split the training data into train and val loaders
    if args.validation_inputs is None:
        print("You haven't provided validation data so the validation set will be split off the input data.")
        print(f"A fraction of {args.train_fraction} will be used for training and {1 - args.train_fraction} for val.")
        train_loader, val_loader = _get_loader(
            args.training_inputs, args.training_input_key, args.training_labels, args.training_label_key,
            args=args, ndim=ndim, perform_split=True,
        )
    else:
        train_loader = _get_loader(
            args.training_inputs, args.training_input_key, args.training_labels, args.training_label_key,
            args=args, ndim=ndim,
        )
        val_loader = _get_loader(
            args.validation_inputs, args.validation_input_key, args.validation_labels, args.validation_label_key,
            args=args, ndim=ndim,
        )
    return train_loader, val_loader


def _determine_channels(train_loader, args):
    # infer the number of input and output channels from a sample batch,
    # which has the shape (batch, channels, *spatial)
    x, y = next(iter(train_loader))
    in_channels = x.shape[1]
    out_channels = y.shape[1]
    return in_channels, out_channels


def train_2d_unet():
    """@private
    """
    parser = _get_training_parser("Train a 2D UNet.")
    args = parser.parse_args()

    train_loader, val_loader = _get_loaders(args, ndim=2)
    # TODO more unet settings
    # create the 2d unet
    in_channels, out_channels = _determine_channels(train_loader, args)
    model = UNet2d(in_channels, out_channels)

    # for affinity training the target contains mask channels,
    # so the loss needs to be wrapped to apply and remove the mask
    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # generate a random id for the training if no name was given
    name = f"2d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 2d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)

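# A minimal hypothetical invocation, assuming this function is installed as a
# console script named 'train_2d_unet' (paths are made up):
#   train_2d_unet -i images/ -k "*.tif" -l labels/ --training_label_key "*.tif" \
#       -b 4 -p 256 256 -m boundaries_and_foreground --name 2d-unet-example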

def train_3d_unet():
    """@private
    """
    parser = _get_training_parser("Train a 3D UNet.")
    parser.add_argument("-s", "--scale_factors", type=str,
                        help="The scale factors for the downsampling of the 3D U-Net. "
                        "Can be used to set anisotropic scaling of the U-Net. "
                        "Needs to be json encoded, e.g. '[[1,2,2],[2,2,2],[2,2,2]]' to set "
                        "anisotropic scaling in the first layer and isotropic scaling in the other two. "
                        "If not passed an isotropic 3D U-Net will be used.")
    args = parser.parse_args()

    scale_factors = None if args.scale_factors is None else json.loads(args.scale_factors)
    train_loader, val_loader = _get_loaders(args, ndim=3)

    # TODO more unet settings
    # create the 3d unet
    in_channels, out_channels = _determine_channels(train_loader, args)
    if scale_factors is None:
        model = UNet3d(in_channels, out_channels)
    else:
        model = AnisotropicUNet(in_channels, out_channels, scale_factors)

    # for affinity training the target contains mask channels,
    # so the loss needs to be wrapped to apply and remove the mask
    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # generate a random id for the training if no name was given
    name = f"3d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 3d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)

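# A minimal hypothetical invocation for anisotropic 3d data stored in hdf5, assuming
# a console script named 'train_3d_unet' (paths and keys are made up):
#   train_3d_unet -i volume.h5 -k raw -l volume.h5 --training_label_key labels \
#       -b 1 -p 32 256 256 -m affinities -s "[[1,2,2],[2,2,2],[2,2,2]]"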

#
# CLI for prediction
#


def _get_prediction_parser(description):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-c", "--checkpoint", required=True, help="The model checkpoint to use for prediction.")
    parser.add_argument("-i", "--input_path", required=True,
                        help="The input path. Supports common image formats (tif, png, etc) "
                        "as well as container formats like hdf5 and zarr. For the latter 'input_key' is also required.")
    parser.add_argument("-k", "--input_key", help="The key (path in file) of the input data. "
                        "Required if the input data is a container file format (e.g. hdf5).")
    parser.add_argument("-o", "--output_path", required=True, help="The path where to save the prediction.")
    parser.add_argument("--output_key", help="The key for saving the output. Required for container file formats.")
    parser.add_argument("-p", "--preprocess", default="standardize",
                        help="The preprocessing applied to the input data; must be the name of a function "
                        "from 'torch_em.transform.raw'.")
    parser.add_argument("--chunks", nargs="+", type=int,
                        help="The chunks for the serialized prediction. Only relevant for container file formats.")
    parser.add_argument("--compression", help="The compression to use when saving the prediction.")
    return parser


def _prediction(args, predict, device):
    model = torch_em.util.get_trainer(args.checkpoint, device=device).model

    # load the input data; container formats are opened lazily and
    # the dataset handle is passed on to the prediction function
    if args.input_key is None:
        input_ = imageio.imread(args.input_path)
        pred = predict(model, input_)
    else:
        with open_file(args.input_path, "r") as f:
            input_ = f[args.input_key]
            pred = predict(model, input_)

    # save the prediction, either as image or to a container format
    output_key = args.output_key
    if output_key is None:
        imageio.imwrite(args.output_path, pred)
    else:
        kwargs = {}
        if args.chunks is not None:
            assert len(args.chunks) == pred.ndim
            kwargs["chunks"] = args.chunks
        if args.compression is not None:
            kwargs["compression"] = args.compression
        with open_file(args.output_path, "a") as f:
            ds = f.require_dataset(
                output_key, shape=pred.shape, dtype=pred.dtype, **kwargs
            )
            ds.n_threads = multiprocessing.cpu_count()
            ds[:] = pred


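# Note: 'require_dataset' creates the output dataset if it does not exist yet and
# otherwise re-opens it (checking that shape and dtype match), so re-running a
# prediction with the same output path and key overwrites the previous result;
# this follows the h5py-style API exposed via elf.io.open_file.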
def predict():
    """@private
    """
    parser = _get_prediction_parser("Run prediction (with padding if necessary).")
    parser.add_argument("--min_divisible", nargs="+", type=int,
                        help="The minimal divisible factors for the input shape of the model. "
                        "If given, the input will be padded to be divisible by these factors.")
    parser.add_argument("-d", "--device",
                        help="The device (gpu, cpu) to use for prediction. "
                        "By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # TODO enable prediction with channels
    def predict(model, input_):
        if args.min_divisible is None:
            # predict the input in a single forward pass;
            # load the data and add batch and channel dimensions
            with torch.no_grad():
                input_ = preprocess(input_[:])
                input_ = torch.from_numpy(input_[None, None]).to(device)
                pred = model(input_)
            pred = pred.cpu().numpy().squeeze()
        else:
            # pad the input so that its shape is divisible by 'min_divisible',
            # predict, and crop the prediction back to the original shape
            input_ = preprocess(input_[:])
            pred = predict_with_padding(input_, model, args.min_divisible, device)
        return pred

    _prediction(args, predict, device)

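# A minimal hypothetical invocation, assuming a console script named 'predict'
# (made-up paths), padding the input to be divisible by 16 along each axis:
#   predict -c checkpoints/2d-unet-example -i image.tif -o prediction.tif --min_divisible 16 16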

def _pred_2d(model, input_):
    # apply a 2d model to a block with a singleton z-axis:
    # drop the z-axis for prediction and re-attach it afterwards
    assert input_.shape[2] == 1
    pred = model(input_[:, :, 0])
    return pred[:, :, None]


def predict_with_tiling():
    """@private
    """
    parser = _get_prediction_parser("Run prediction over tiled input.")
    parser.add_argument("-b", "--block_shape", nargs="+", required=True, type=int,
                        help="The shape of the blocks that will be used to tile the input. "
                        "The model will be applied to each block individually and the results will be stitched.")
    parser.add_argument("--halo", nargs="+", type=int,
                        help="The overlap of the tiles / blocks used during prediction. By default no overlap is used.")
    parser.add_argument("-d", "--devices", nargs="+",
                        help="The devices used for prediction. Can either be the cpu, a gpu, or multiple gpus. "
                        "By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    block_shape = args.block_shape
    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.halo is None:
        halo = [0] * len(block_shape)
    else:
        halo = args.halo
    assert len(halo) == len(block_shape)

    if args.devices is None:
        devices = ["cuda"] if torch.cuda.is_available() else ["cpu"]
    else:
        devices = args.devices

    # if the block shape is singleton in the first axis we assume that this is a 2d model
    if block_shape[0] == 1:
        pred_function = _pred_2d
    else:
        pred_function = None

    # TODO enable prediction with channels
    def predict(model, input_):
        pred = predict_with_halo(
            input_, model, gpu_ids=devices, block_shape=block_shape, halo=halo,
            prediction_function=pred_function, preprocess=preprocess
        )
        return pred

    _prediction(args, predict, devices[0])
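

# A minimal hypothetical invocation, assuming a console script named
# 'predict_with_tiling' (made-up paths and keys), tiling a 3d volume into
# blocks of 32 x 256 x 256 with an overlap of 8 x 32 x 32:
#   predict_with_tiling -c checkpoints/3d-unet-example -i volume.h5 -k raw \
#       -o prediction.h5 --output_key prediction -b 32 256 256 --halo 8 32 32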