torch_em.cli
import argparse
import json
import multiprocessing
import uuid

import imageio.v3 as imageio
import torch
import torch_em
from elf.io import open_file
from torch.utils.data import random_split
from torch_em.model.unet import AnisotropicUNet, UNet2d, UNet3d
from torch_em.util.prediction import predict_with_halo, predict_with_padding


#
# CLI for training
#


def _get_training_parser(description):
    parser = argparse.ArgumentParser(description=description)

    # paths and keys for the data
    # training inputs and labels are required
    parser.add_argument("-i", "--training_inputs",
                        help="The input file path(s). Supports common image formats (tif, png, etc)"
                             " as well as container formats like hdf5 and zarr. For the latter 'training_input_key'"
                             " also has to be provided. In case you have a folder with many images you should provide the"
                             " path to the folder instead of individual image paths; for this you then need to provide the"
                             " file pattern (e.g. '*.tif') to 'training_input_key'.",
                        required=True, type=str, nargs="+")
    parser.add_argument("-l", "--training_labels",
                        help="The label file path(s). See 'training_inputs' for details on the supported formats etc.",
                        required=True, type=str, nargs="+")
    parser.add_argument("-k", "--training_input_key",
                        help="The key (internal path) for the input data. Required for data formats like hdf5 or zarr.")
    parser.add_argument("--training_label_key", help="The key for the labels. See also 'training_input_key'.")

    # val inputs and labels are optional; if not given we split off parts of the training data
    parser.add_argument("--validation_inputs", type=str, nargs="+",
                        help="The input file path(s) for validation data. If this is not given,"
                             " a fraction of the training inputs will be used for validation.")
    parser.add_argument("--validation_labels", type=str, nargs="+",
                        help="The label file path(s) for validation. Must be given if 'validation_inputs' are given.")
    parser.add_argument("--validation_input_key", help="The key for the validation inputs.")
    parser.add_argument("--validation_label_key", help="The key for the validation labels.")

    # other options
    parser.add_argument("-b", "--batch_size", type=int, required=True, help="The batch size.")
    parser.add_argument("-p", "--patch_shape", type=int, nargs="+", required=True,
                        help="The training patch shape.")
    parser.add_argument("-n", "--n_iterations", type=int, default=25000,
                        help="The number of iterations to train for.")
    parser.add_argument("-m", "--label_mode",
                        help="The label mode determines the transformation applied to the"
                             " labels in order to obtain the targets for training."
                             " This can be used to obtain suitable representations for training given"
                             " instance segmentation ground-truth. Currently supported:"
                             " 'affinities', 'affinities_and_foreground',"
                             " 'boundaries', 'boundaries_and_foreground', 'foreground'.")
    parser.add_argument("--name", help="The name of the trained model (checkpoint).")
    parser.add_argument("--train_fraction", type=float, default=0.8,
                        help="The fraction of the data that will be used for training."
                             " The rest of the data will be used for validation."
                             " This is only used if validation data is not provided,"
                             " otherwise all data will be used for training.")

    return parser


# TODO provide an option to override the offsets, e.g. via filepath to a json?
def _get_offsets(ndim, scale_factors):
    if ndim == 2:
        offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9], [-27, 0], [0, -27]]
    elif ndim == 3 and scale_factors is None:
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-3, 0, 0], [0, -3, 0], [0, 0, -3],
            [-9, 0, 0], [0, -9, 0], [0, 0, -9],
            [-27, 0, 0], [0, -27, 0], [0, 0, -27],
        ]
    else:
        offsets = [
            [-1, 0, 0], [0, -1, 0], [0, 0, -1],
            [-2, 0, 0], [0, -3, 0], [0, 0, -3],
            [-3, 0, 0], [0, -9, 0], [0, 0, -9],
            [-4, 0, 0], [0, -27, 0], [0, 0, -27],
        ]
    return offsets


# TODO this should be extended to all relevant things, generalized to more datasets and refactored to torch_em.util
def _random_split(ds, fractions):

    def _get_attribute(dataset, attr_name):
        while isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            dataset = dataset.datasets[0]
        attr = getattr(dataset, attr_name)
        return attr

    ds_train, ds_val = random_split(ds, fractions)

    raw_transform = _get_attribute(ds, "raw_transform")
    ds_train.raw_transform = raw_transform
    ds_val.raw_transform = raw_transform

    ndim = _get_attribute(ds, "ndim")
    ds_train.ndim = ndim
    ds_val.ndim = ndim

    return ds_train, ds_val


def _get_loader(input_paths, input_key, label_paths, label_key, args, ndim, perform_split=False):
    label_transform, label_transform2 = None, None

    # figure out the label transformations
    label_modes = (
        "affinities", "affinities_and_foreground",
        "boundaries", "boundaries_and_foreground",
        "foreground",
    )
    if args.label_mode is None:
        pass
    elif args.label_mode == "affinities":
        # scale_factors is only defined by the 3d training parser, hence getattr
        offsets = _get_offsets(ndim, getattr(args, "scale_factors", None))
        label_transform = torch_em.transform.label.AffinityTransform(
            offsets=offsets, add_binary_target=False, add_mask=True,
        )
    elif args.label_mode == "affinities_and_foreground":
        label_transform = torch_em.transform.label.AffinityTransform(
            offsets=_get_offsets(ndim, getattr(args, "scale_factors", None)), add_binary_target=True, add_mask=True,
        )
    elif args.label_mode == "boundaries":
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False)
    elif args.label_mode == "boundaries_and_foreground":
        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
    elif args.label_mode == "foreground":
        label_transform = torch_em.transform.label.labels_to_binary
    else:
        raise ValueError(f"Unknown label mode {args.label_mode}, expect one of {label_modes}")

    # validate the patch shape
    patch_shape = args.patch_shape
    if ndim == 2:
        if len(patch_shape) != 2 and patch_shape[0] != 1:
            raise ValueError(f"Invalid patch_shape {patch_shape} for 2d data.")
    elif ndim == 3:
        if len(patch_shape) != 3:
            raise ValueError(f"Invalid patch_shape {patch_shape} for 3d data.")
    else:
        raise RuntimeError(f"Invalid ndim: {ndim}")

    # TODO figure out if with channels
    ds = torch_em.default_segmentation_dataset(
        input_paths, input_key, label_paths, label_key,
        patch_shape=patch_shape, ndim=ndim,
        label_transform=label_transform,
        label_transform2=label_transform2,
    )

    n_cpus = multiprocessing.cpu_count()
    if perform_split:
        fractions = [args.train_fraction, 1.0 - args.train_fraction]
        ds_train, ds_val = _random_split(ds, fractions)
        train_loader = torch_em.segmentation.get_data_loader(
            ds_train, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        val_loader = torch_em.segmentation.get_data_loader(
            ds_val, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        return train_loader, val_loader
    else:
        loader = torch_em.segmentation.get_data_loader(
            ds, batch_size=args.batch_size, shuffle=True, num_workers=n_cpus
        )
        return loader


def _get_loaders(args, ndim):
    # if validation data is not passed we split the loader
    if args.validation_inputs is None:
        print("You haven't provided validation data so the validation set will be split off the input data.")
        print(f"A fraction of {args.train_fraction} will be used for training and {1 - args.train_fraction} for val.")
        train_loader, val_loader = _get_loader(
            args.training_inputs, args.training_input_key, args.training_labels, args.training_label_key,
            args=args, ndim=ndim, perform_split=True,
        )
    else:
        train_loader = _get_loader(
            args.training_inputs, args.training_input_key, args.training_labels, args.training_label_key,
            args=args, ndim=ndim,
        )
        val_loader = _get_loader(
            args.validation_inputs, args.validation_input_key, args.validation_labels, args.validation_label_key,
            args=args, ndim=ndim,
        )
    return train_loader, val_loader


def _determine_channels(train_loader, args):
    x, y = next(iter(train_loader))
    in_channels = x.shape[1]
    out_channels = y.shape[1]
    return in_channels, out_channels


def train_2d_unet():
    parser = _get_training_parser("Train a 2D UNet.")
    args = parser.parse_args()

    train_loader, val_loader = _get_loaders(args, ndim=2)
    # TODO more unet settings
    # create the 2d unet
    in_channels, out_channels = _determine_channels(train_loader, args)
    model = UNet2d(in_channels, out_channels)

    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # generate a random id for the training
    name = f"2d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 2d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)


def train_3d_unet():
    parser = _get_training_parser("Train a 3D UNet.")
    parser.add_argument("-s", "--scale_factors", type=str,
                        help="The scale factors for the downsampling of the 3D U-Net."
                             " Can be used to set anisotropic scaling of the U-Net."
                             " Needs to be json encoded, e.g. '[[1,2,2],[2,2,2],[2,2,2]]' to set"
                             " anisotropic scaling in the first layer and isotropic scaling in the other two."
                             " If not passed an isotropic 3D U-Net will be used.")
    args = parser.parse_args()

    scale_factors = None if args.scale_factors is None else json.loads(args.scale_factors)
    train_loader, val_loader = _get_loaders(args, ndim=3)

    # TODO more unet settings
    # create the 3d unet
    in_channels, out_channels = _determine_channels(train_loader, args)
    if scale_factors is None:
        model = UNet3d(in_channels, out_channels)
    else:
        model = AnisotropicUNet(in_channels, out_channels, scale_factors)

    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # generate a random id for the training
    name = f"3d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 3d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)


#
# CLI for prediction
#


def _get_prediction_parser(description):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-c", "--checkpoint", required=True, help="The model checkpoint to use for prediction.")
    parser.add_argument("-i", "--input_path", required=True,
                        help="The input path. Supports common image formats (tif, png, etc)"
                             " as well as container formats like hdf5 and zarr. For the latter 'input_key' is also required.")
    parser.add_argument("-k", "--input_key", help="The key (path in file) of the input data."
                        " Required if the input data is a container file format (e.g. hdf5).")
    parser.add_argument("-o", "--output_path", required=True, help="The path where the prediction will be saved.")
    parser.add_argument("--output_key", help="The key for saving the output. Required for container file formats.")
    parser.add_argument("-p", "--preprocess", default="standardize",
                        help="The preprocessing applied to the input data."
                             " Must be the name of a function from torch_em.transform.raw.")
    parser.add_argument("--chunks", nargs="+", type=int,
                        help="The chunks for the serialized prediction. Only relevant for container file formats.")
    parser.add_argument("--compression", help="The compression to use when saving the prediction.")
    return parser


def _prediction(args, predict, device):
    model = torch_em.util.get_trainer(args.checkpoint, device=device).model

    if args.input_key is None:
        input_ = imageio.imread(args.input_path)
        pred = predict(model, input_)
    else:
        with open_file(args.input_path, "r") as f:
            input_ = f[args.input_key]
            pred = predict(model, input_)

    output_key = args.output_key
    if output_key is None:
        imageio.imwrite(args.output_path, pred)
    else:
        kwargs = {}
        if args.chunks is not None:
            assert len(args.chunks) == pred.ndim
            kwargs["chunks"] = args.chunks
        if args.compression is not None:
            kwargs["compression"] = args.compression
        with open_file(args.output_path, "a") as f:
            ds = f.require_dataset(
                output_key, shape=pred.shape, dtype=pred.dtype, **kwargs
            )
            ds.n_threads = multiprocessing.cpu_count()
            ds[:] = pred


def predict():
    parser = _get_prediction_parser("Run prediction (with padding if necessary).")
    parser.add_argument("--min_divisible", nargs="+", type=int,
                        help="The minimal divisible factors for the input shape of the model."
                             " If given the input will be padded to be divisible by these factors.")
    parser.add_argument("-d", "--device",
                        help="The device (gpu, cpu) to use for prediction."
                             " By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # TODO enable prediction with channels
    def predict(model, input_):
        if args.min_divisible is None:
            with torch.no_grad():
                input_ = preprocess(input_)
                input_ = torch.from_numpy(input_[:][None, None]).to(device)
                pred = model(input_)
                pred = pred.cpu().numpy().squeeze()
        else:
            input_ = preprocess(input_[:])
            pred = predict_with_padding(input_, model, args.min_divisible, device)
        return pred

    _prediction(args, predict, device)


def _pred_2d(model, input_):
    assert input_.shape[2] == 1
    pred = model(input_[:, :, 0])
    return pred[:, :, None]


def predict_with_tiling():
    parser = _get_prediction_parser("Run prediction over tiled input.")
    parser.add_argument("-b", "--block_shape", nargs="+", required=True, type=int,
                        help="The shape of the blocks that will be used to tile the input."
                             " The model will be applied to each block individually and the results will be stitched.")
    parser.add_argument("--halo", nargs="+", type=int,
                        help="The overlap of the tiles / blocks used during prediction. By default no overlap is used.")
    parser.add_argument("-d", "--devices", nargs="+",
                        help="The devices used for prediction. Can either be the cpu, a gpu, or multiple gpus."
                             " By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    block_shape = args.block_shape
    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.halo is None:
        halo = [0] * len(block_shape)
    else:
        halo = args.halo
        assert len(halo) == len(block_shape)

    if args.devices is None:
        devices = ["cuda"] if torch.cuda.is_available() else ["cpu"]
    else:
        devices = args.devices

    # if the block shape is singleton in the first axis we assume that this is a 2d model
    if block_shape[0] == 1:
        pred_function = _pred_2d
    else:
        pred_function = None

    # TODO enable prediction with channels
    def predict(model, input_):
        pred = predict_with_halo(
            input_, model, gpu_ids=devices, block_shape=block_shape, halo=halo,
            prediction_function=pred_function, preprocess=preprocess
        )
        return pred

    _prediction(args, predict, devices[0])
def train_2d_unet():

def train_2d_unet():
    parser = _get_training_parser("Train a 2D UNet.")
    args = parser.parse_args()

    train_loader, val_loader = _get_loaders(args, ndim=2)
    # TODO more unet settings
    # create the 2d unet
    in_channels, out_channels = _determine_channels(train_loader, args)
    model = UNet2d(in_channels, out_channels)

    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # generate a random id for the training
    name = f"2d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 2d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)
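
A usage sketch for this entry point: it reads its arguments from the command line, so it can also be driven programmatically by setting sys.argv before the call. All data paths, the file pattern and the checkpoint name below are hypothetical; if your installation registers a console script for this function, the same arguments can be passed on the command line instead.

import sys
from torch_em.cli import train_2d_unet

# hypothetical layout: a folder of training images and a folder of matching label images
sys.argv = [
    "train_2d_unet",
    "-i", "images/", "-k", "*.tif",
    "-l", "labels/", "--training_label_key", "*.tif",
    "-b", "4",                           # batch size
    "-p", "256", "256",                  # 2d patch shape
    "-m", "boundaries_and_foreground",   # train on boundaries and foreground derived from instance labels
    "-n", "10000",                       # number of training iterations
    "--name", "my-2d-unet",
]
train_2d_unet()
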
def train_3d_unet():

def train_3d_unet():
    parser = _get_training_parser("Train a 3D UNet.")
    parser.add_argument("-s", "--scale_factors", type=str,
                        help="The scale factors for the downsampling of the 3D U-Net."
                             " Can be used to set anisotropic scaling of the U-Net."
                             " Needs to be json encoded, e.g. '[[1,2,2],[2,2,2],[2,2,2]]' to set"
                             " anisotropic scaling in the first layer and isotropic scaling in the other two."
                             " If not passed an isotropic 3D U-Net will be used.")
    args = parser.parse_args()

    scale_factors = None if args.scale_factors is None else json.loads(args.scale_factors)
    train_loader, val_loader = _get_loaders(args, ndim=3)

    # TODO more unet settings
    # create the 3d unet
    in_channels, out_channels = _determine_channels(train_loader, args)
    if scale_factors is None:
        model = UNet3d(in_channels, out_channels)
    else:
        model = AnisotropicUNet(in_channels, out_channels, scale_factors)

    if args.label_mode is not None and "affinities" in args.label_mode:
        loss = torch_em.loss.LossWrapper(
            torch_em.loss.DiceLoss(),
            transform=torch_em.loss.ApplyAndRemoveMask(masking_method="multiply")
        )
    else:
        loss = torch_em.loss.DiceLoss()

    # generate a random id for the training
    name = f"3d-unet-training-{uuid.uuid1()}" if args.name is None else args.name
    print("Start 3d unet training for", name)
    trainer = torch_em.default_segmentation_trainer(
        name=name, model=model, train_loader=train_loader, val_loader=val_loader,
        loss=loss, metric=loss, compile_model=False,
    )
    trainer.fit(args.n_iterations)
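
A usage sketch along the same lines, here for anisotropic volume data stored in an hdf5 container; the file name, the internal keys and the checkpoint name are hypothetical. The scale factors are passed as a json-encoded string, exactly as described in the help text above.

import sys
from torch_em.cli import train_3d_unet

sys.argv = [
    "train_3d_unet",
    "-i", "volume.h5", "-k", "raw",                        # raw data inside the container
    "-l", "volume.h5", "--training_label_key", "labels",   # instance labels inside the container
    "-b", "1",
    "-p", "32", "256", "256",                              # anisotropic 3d patch shape
    "-m", "affinities_and_foreground",
    "-s", "[[1,2,2],[2,2,2],[2,2,2]]",                     # anisotropic scaling in the first level
    "--name", "my-3d-unet",
]
train_3d_unet()
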
def predict():

def predict():
    parser = _get_prediction_parser("Run prediction (with padding if necessary).")
    parser.add_argument("--min_divisible", nargs="+", type=int,
                        help="The minimal divisible factors for the input shape of the model."
                             " If given the input will be padded to be divisible by these factors.")
    parser.add_argument("-d", "--device",
                        help="The device (gpu, cpu) to use for prediction."
                             " By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # TODO enable prediction with channels
    def predict(model, input_):
        if args.min_divisible is None:
            with torch.no_grad():
                input_ = preprocess(input_)
                input_ = torch.from_numpy(input_[:][None, None]).to(device)
                pred = model(input_)
                pred = pred.cpu().numpy().squeeze()
        else:
            input_ = preprocess(input_[:])
            pred = predict_with_padding(input_, model, args.min_divisible, device)
        return pred

    _prediction(args, predict, device)
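
A usage sketch for padded prediction with a model trained via the CLI above; the checkpoint location, input image and output file are hypothetical (torch_em typically saves checkpoints under './checkpoints/<name>', but check where your training run stored them).

import sys
from torch_em.cli import predict

sys.argv = [
    "predict",
    "-c", "checkpoints/my-2d-unet",        # checkpoint of the trained model (hypothetical path)
    "-i", "test-image.tif",                # plain image input, so no --input_key is needed
    "-o", "prediction.h5", "--output_key", "prediction",
    "--min_divisible", "16", "16",         # pad the input so each axis is divisible by 16
]
predict()
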
def predict_with_tiling():

def predict_with_tiling():
    parser = _get_prediction_parser("Run prediction over tiled input.")
    parser.add_argument("-b", "--block_shape", nargs="+", required=True, type=int,
                        help="The shape of the blocks that will be used to tile the input."
                             " The model will be applied to each block individually and the results will be stitched.")
    parser.add_argument("--halo", nargs="+", type=int,
                        help="The overlap of the tiles / blocks used during prediction. By default no overlap is used.")
    parser.add_argument("-d", "--devices", nargs="+",
                        help="The devices used for prediction. Can either be the cpu, a gpu, or multiple gpus."
                             " By default a gpu will be used if available, otherwise the cpu will be used.")
    args = parser.parse_args()

    block_shape = args.block_shape
    preprocess = getattr(torch_em.transform.raw, args.preprocess)
    if args.halo is None:
        halo = [0] * len(block_shape)
    else:
        halo = args.halo
        assert len(halo) == len(block_shape)

    if args.devices is None:
        devices = ["cuda"] if torch.cuda.is_available() else ["cpu"]
    else:
        devices = args.devices

    # if the block shape is singleton in the first axis we assume that this is a 2d model
    if block_shape[0] == 1:
        pred_function = _pred_2d
    else:
        pred_function = None

    # TODO enable prediction with channels
    def predict(model, input_):
        pred = predict_with_halo(
            input_, model, gpu_ids=devices, block_shape=block_shape, halo=halo,
            prediction_function=pred_function, preprocess=preprocess
        )
        return pred

    _prediction(args, predict, devices[0])
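
A usage sketch for tiled prediction on a large volume stored as zarr; the file names, keys and checkpoint path are hypothetical. Block shape and halo need one entry per axis, and the halo adds overlap between neighboring blocks to reduce stitching artifacts.

import sys
from torch_em.cli import predict_with_tiling

sys.argv = [
    "predict_with_tiling",
    "-c", "checkpoints/my-3d-unet",          # checkpoint of the trained model (hypothetical path)
    "-i", "big-volume.zarr", "-k", "raw",
    "-o", "prediction.zarr", "--output_key", "prediction",
    "-b", "32", "256", "256",                # block shape used to tile the volume
    "--halo", "8", "32", "32",               # overlap added around each block
]
predict_with_tiling()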