torch_em.cli
"""@private
"""
import argparse
import json
import multiprocessing
import uuid

import imageio.v3 as imageio
import torch
import torch_em
from elf.io import open_file
from torch.utils.data import random_split
from torch_em.model.unet import AnisotropicUNet, UNet2d, UNet3d
from torch_em.util.prediction import predict_with_halo, predict_with_padding


#
# CLI for training
#


def _get_training_parser(description):
    parser = argparse.ArgumentParser(description=description)

    # Paths and keys for the data; training inputs and labels are required.
    parser.add_argument("-i", "--training_inputs",
                        help="The input file path(s). Supports common image formats (tif, png, etc.) "
                             "as well as container formats like hdf5 and zarr. For the latter, 'training_input_key' "
                             "also has to be provided. In case you have a folder with many images, you should provide "
                             "the path to the folder instead of individual image paths; you then need to pass the "
                             "file pattern (e.g. '*.tif') as 'training_input_key'.",
                        required=True, type=str, nargs="+")
    parser.add_argument("-l", "--training_labels",
                        help="The label file path(s). See 'training_inputs' for details on the supported formats etc.",
                        required=True, type=str, nargs="+")
    parser.add_argument("-k", "--training_input_key",
                        help="The key (internal path) for the input data. Required for data formats like hdf5 or zarr.")
    parser.add_argument("--training_label_key", help="The key for the labels. See also 'training_input_key'.")

    # Validation inputs and labels are optional; if not given, we split off part of the training data.
    parser.add_argument("--validation_inputs", type=str, nargs="+",
                        help="The input file path(s) for validation data. If this is not given, "
                             "a fraction of the training inputs will be used for validation.")
    parser.add_argument("--validation_labels", type=str, nargs="+",
                        help="The label file path(s) for validation. Must be given if 'validation_inputs' are given.")
    parser.add_argument("--validation_input_key", help="The key for the validation inputs.")
    parser.add_argument("--validation_label_key", help="The key for the validation labels.")

    # Other options.
    parser.add_argument("-b", "--batch_size", type=int, required=True, help="The batch size.")
    parser.add_argument("-p", "--patch_shape", type=int, nargs="+", required=True,
                        help="The training patch shape.")
    parser.add_argument("-n", "--n_iterations", type=int, default=25000,
                        help="The number of iterations to train for.")
    parser.add_argument("-m", "--label_mode",
                        help="The label mode determines the transformation applied to the "
                             "labels in order to obtain the targets for training. "
                             "This can be used to obtain suitable representations for training given "
                             "instance segmentation ground-truth. Currently supported: "
                             "'affinities', 'affinities_and_foreground', "
                             "'boundaries', 'boundaries_and_foreground', 'foreground'.")
    parser.add_argument("--name", help="The name of the trained model (checkpoint).")
    parser.add_argument("--train_fraction", type=float, default=0.8,
                        help="The fraction of the data that will be used for training. "
                             "The rest of the data will be used for validation. "
                             "This is only used if validation data is not provided; "
                             "otherwise all data will be used for training.")

    return parser
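# Usage sketch for the parser above (hypothetical paths; this assumes the training
# functions below are registered as console scripts, the exact entry-point names may differ):
#   torch_em.train_2d_unet -i images/ -k "*.tif" -l labels/ --training_label_key "*.tif" \
#       -b 4 -p 256 256 -m boundaries_and_foreground --name my-2d-unet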
# TODO provide an option to over-ride the offsets, e.g. via a filepath to a json?
def _get_offsets(ndim, scale_factors):
    # Affinity offsets at increasing distances (1, 3, 9, 27 pixels) along each axis.
    if ndim == 2:
        offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9], [-27, 0], [0, -27]]
    elif ndim == 3 and scale_factors is None:
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-3, 0, 0], [0, -3, 0], [0, 0, -3],
            [-9, 0, 0], [0, -9, 0], [0, 0, -9],
            [-27, 0, 0], [0, -27, 0], [0, 0, -27],
        ]
    else:
        # For anisotropic 3d data we use shorter-range offsets along the first (low-resolution) axis.
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-2, 0, 0], [0, -3, 0], [0, 0, -3],
            [-3, 0, 0], [0, -9, 0], [0, 0, -9],
            [-4, 0, 0], [0, -27, 0], [0, 0, -27],
        ]
    return offsets


# TODO this should be extended to all relevant attributes, generalized to more datasets and refactored to torch_em.util
def _random_split(ds, fractions):

    def _get_attribute(dataset, attr_name):
        while isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            dataset = dataset.datasets[0]
        attr = getattr(dataset, attr_name)
        return attr

    ds_train, ds_val = random_split(ds, fractions)

    # The subsets returned by random_split don't inherit the dataset attributes, so set them explicitly.
    raw_transform = _get_attribute(ds, "raw_transform")
    ds_train.raw_transform = raw_transform
    ds_val.raw_transform = raw_transform

    ndim = _get_attribute(ds, "ndim")
    ds_train.ndim = ndim
    ds_val.ndim = ndim

    return ds_train, ds_val


def _get_loader(input_paths, input_key, label_paths, label_key, args, ndim, perform_split=False):
    label_transform, label_transform2 = None, None

    # Figure out the label transformation.
    label_modes = (
        "affinities", "affinities_and_foreground",
        "boundaries", "boundaries_and_foreground",
        "foreground",
    )
    if args.label_mode is None:
        pass
    elif args.label_mode == "affinities":
        # The 2d parser does not define 'scale_factors', hence the getattr with default None.
        offsets = _get_offsets(ndim, getattr(args, "scale_factors", None))
        label_transform = torch_em.transform.label.AffinityTransform(
            offsets=offsets, add_binary_target=False, add_mask=True,
        )
    elif args.label_mode == "affinities_and_foreground":
        label_transform = torch_em.transform.label.AffinityTransform(
            offsets=_get_offsets(ndim, getattr(args, "scale_factors", None)),
            add_binary_target=True, add_mask=True,
        )
    elif args.label_mode == "boundaries":
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False)
    elif args.label_mode == "boundaries_and_foreground":
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
    elif args.label_mode == "foreground":
        label_transform = torch_em.transform.label.labels_to_binary
    else:
        raise ValueError(f"Unknown label mode {args.label_mode}, expect one of {label_modes}")

    # Validate the patch shape: for 2d data it must have length 2, or length 3 with a singleton first axis.
    patch_shape = args.patch_shape
    if ndim == 2:
        if len(patch_shape) != 2 and not (len(patch_shape) == 3 and patch_shape[0] == 1):
            raise ValueError(f"Invalid patch_shape {patch_shape} for 2d data.")
    elif ndim == 3:
        if len(patch_shape) != 3:
            raise ValueError(f"Invalid patch_shape {patch_shape} for 3d data.")
    else:
        raise RuntimeError(f"Invalid ndim: {ndim}")

    # TODO figure out if with channels
    ds = torch_em.default_segmentation_dataset(
        input_paths, input_key, label_paths, label_key,
        patch_shape=patch_shape, ndim=ndim,
        label_transform=label_transform,
        label_transform2=label_transform2,
    )

    n_cpus = multiprocessing.cpu_count()
    if perform_split:
        fractions = [args.train_fraction, 1.0 - args.train_fraction]
        ds_train, ds_val = _random_split(ds, fractions)
        train_loader = torch_em.segmentation.get_data_loader(
            ds_train, batch_size=args.batch_size, shuffle=True,
            num_workers=n_cpus
        )
        val_loader = torch_em.segmentation.get_data_loader(
            ds_val, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        return train_loader, val_loader
    else:
        loader = torch_em.segmentation.get_data_loader(
            ds, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        return loader


def _get_loaders(args, ndim):
    # If validation data is not passed, we split the validation set off the training data.
    if args.validation_inputs is None:
        print("You haven't provided validation data, so the validation set will be split off the input data.")
        print(f"A fraction of {args.train_fraction} will be used for training and {1 - args.train_fraction} for val.")
        train_loader, val_loader = _get_loader(
            args.training_inputs, args.training_input_key, args.training_labels, args.training_label_key,
            args=args, ndim=ndim, perform_split=True,
        )
    else:
        train_loader = _get_loader(
            args.training_inputs, args.training_input_key, args.training_labels, args.training_label_key,
            args=args, ndim=ndim,
        )
        val_loader = _get_loader(
            args.validation_inputs, args.validation_input_key, args.validation_labels, args.validation_label_key,
            args=args, ndim=ndim,
        )
    return train_loader, val_loader


def _determine_channels(train_loader, args):
    # Derive the number of input and output channels from a sample batch.
    x, y = next(iter(train_loader))
    in_channels = x.shape[1]
    out_channels = y.shape[1]
    return in_channels, out_channels


def train_2d_unet():
    """@private
    """
    parser = _get_training_parser("Train a 2D UNet.")
    args = parser.parse_args()

    train_loader, val_loader = _get_loaders(args, ndim=2)
    # TODO expose more unet settings
    # Create the 2d unet.
    in_channels, out_channels = _determine_channels(train_loader, args)
    model = UNet2d(in_channels, out_channels)

    # For affinity training the loss has to be masked with the affinity mask channels.
    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # Generate a random id for the training if no name is given.
    name = f"2d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 2d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)


def train_3d_unet():
    """@private
    """
    parser = _get_training_parser("Train a 3D UNet.")
    parser.add_argument("-s", "--scale_factors", type=str,
                        help="The scale factors for the downsampling levels of the 3D U-Net. "
                             "Can be used to set anisotropic scaling of the U-Net. "
                             "Needs to be json encoded, e.g. '[[1,2,2],[2,2,2],[2,2,2]]' to set "
                             "anisotropic scaling in the first layer and isotropic scaling in the other two. "
                             "If not given, an isotropic 3D U-Net will be used.")
    args = parser.parse_args()

    scale_factors = None if args.scale_factors is None else json.loads(args.scale_factors)
    train_loader, val_loader = _get_loaders(args, ndim=3)

    # TODO expose more unet settings
    # Create the 3d unet.
    in_channels, out_channels = _determine_channels(train_loader, args)
    if scale_factors is None:
        model = UNet3d(in_channels, out_channels)
    else:
        model = AnisotropicUNet(in_channels, out_channels, scale_factors)

    # For affinity training the loss has to be masked with the affinity mask channels.
    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # Generate a random id for the training if no name is given.
    name = f"3d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 3d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)
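# Usage sketch for anisotropic scale factors (hypothetical paths and entry-point name):
# the first two levels below downsample only in-plane, which matches volumes with
# coarse z-resolution:
#   torch_em.train_3d_unet -i train.h5 -k raw -l train.h5 --training_label_key labels \
#       -b 1 -p 32 256 256 -m affinities -s "[[1,2,2],[1,2,2],[2,2,2]]"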
#
# CLI for prediction
#


def _get_prediction_parser(description):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-c", "--checkpoint", required=True, help="The model checkpoint to use for prediction.")
    parser.add_argument("-i", "--input_path", required=True,
                        help="The input path. Supports common image formats (tif, png, etc.) "
                             "as well as container formats like hdf5 and zarr. For the latter, 'input_key' is also required.")
    parser.add_argument("-k", "--input_key", help="The key (path in file) of the input data. "
                        "Required if the input data is stored in a container file format (e.g. hdf5).")
    parser.add_argument("-o", "--output_path", required=True, help="The path where to save the prediction.")
    parser.add_argument("--output_key", help="The key (path in file) for saving the prediction. "
                        "Required for container file formats.")
    parser.add_argument("-p", "--preprocess", default="standardize",
                        help="The preprocessing function to apply to the input data; "
                             "must be the name of a function from 'torch_em.transform.raw'.")
    parser.add_argument("--chunks", nargs="+", type=int,
                        help="The chunks for the serialized prediction. Only relevant for container file formats.")
    parser.add_argument("--compression", help="The compression to use when saving the prediction.")
    return parser


def _prediction(args, predict, device):
    model = torch_em.util.get_trainer(args.checkpoint, device=device).model

    # Load the input data and run the prediction.
    if args.input_key is None:
        input_ = imageio.imread(args.input_path)
        pred = predict(model, input_)
    else:
        with open_file(args.input_path, "r") as f:
            input_ = f[args.input_key]
            pred = predict(model, input_)

    # Save the prediction, either with imageio or to a container file format.
    output_key = args.output_key
    if output_key is None:
        imageio.imwrite(args.output_path, pred)
    else:
        kwargs = {}
        if args.chunks is not None:
            assert len(args.chunks) == pred.ndim
            kwargs["chunks"] = args.chunks
        if args.compression is not None:
            kwargs["compression"] = args.compression
        with open_file(args.output_path, "a") as f:
            ds = f.require_dataset(
                output_key, shape=pred.shape, dtype=pred.dtype, **kwargs
            )
            ds.n_threads = multiprocessing.cpu_count()
            ds[:] = pred


def predict():
    """@private
    """
    parser = _get_prediction_parser("Run prediction (with padding if necessary).")
    parser.add_argument("--min_divisible", nargs="+", type=int,
                        help="The minimal divisible factors for the input shape of the model. "
                             "If given, the input will be padded to be divisible by these factors.")
    parser.add_argument("-d", "--device",
                        help="The device (gpu, cpu) to use for prediction. "
                             "By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # TODO enable prediction with channels
    def predict(model, input_):
        if args.min_divisible is None:
            with torch.no_grad():
                input_ = preprocess(input_)
                # Add batch and channel dimensions before applying the model.
                input_ = torch.from_numpy(input_[:][None, None]).to(device)
                pred = model(input_)
                pred = pred.cpu().numpy().squeeze()
        else:
            input_ = preprocess(input_[:])
            pred = predict_with_padding(input_, model, args.min_divisible, device)
        return pred

    _prediction(args, predict, device)


def _pred_2d(model, input_):
    # Apply a 2d model to a 3d block with a singleton depth axis.
    assert input_.shape[2] == 1
    pred = model(input_[:, :, 0])
    return pred[:, :, None]


def predict_with_tiling():
    """@private
    """
    parser = _get_prediction_parser("Run prediction over tiled input.")
    parser.add_argument("-b", "--block_shape", nargs="+", required=True, type=int,
                        help="The shape of the blocks that will be used to tile the input. "
                             "The model will be applied to each block individually and the results will be stitched.")
    parser.add_argument("--halo", nargs="+", type=int,
                        help="The overlap of the tiles / blocks used during prediction. By default no overlap is used.")
    parser.add_argument("-d", "--devices", nargs="+",
                        help="The devices used for prediction. Can either be the cpu, a gpu, or multiple gpus. "
                             "By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    block_shape = args.block_shape
    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.halo is None:
        halo = [0] * len(block_shape)
    else:
        halo = args.halo
        assert len(halo) == len(block_shape)

    if args.devices is None:
        devices = ["cuda"] if torch.cuda.is_available() else ["cpu"]
    else:
        devices = args.devices

    # If the block shape is singleton in the first axis, we assume that this is a 2d model.
    if block_shape[0] == 1:
        pred_function = _pred_2d
    else:
        pred_function = None

    # TODO enable prediction with channels
    def predict(model, input_):
        pred = predict_with_halo(
            input_, model, gpu_ids=devices, block_shape=block_shape, halo=halo,
            prediction_function=pred_function, preprocess=preprocess
        )
        return pred

    _prediction(args, predict, devices[0])
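# Usage sketch for tiled prediction (hypothetical paths and entry-point name):
#   torch_em.predict_with_tiling -c checkpoints/my-3d-unet -i vol.h5 -k raw \
#       -o pred.h5 --output_key prediction -b 32 256 256 --halo 8 32 32
# Each block is grown by the halo before prediction and the halo is discarded again
# during stitching, which reduces artifacts at the block boundaries.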