"""torch_em.util.modelzoo: export torch_em models to, and import them from, the BioImage.IO model format."""
import argparse
import functools
import json
import os
import pickle
import platform
import subprocess
import tempfile

from glob import glob
from pathlib import Path
from warnings import warn

import imageio
import numpy as np
import torch
import torch_em

import bioimageio.core as core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
from bioimageio.spec import save_bioimageio_package

from elf.io import open_file
from .util import get_trainer, get_normalizer


#
# General Purpose Functionality
#


def normalize_with_batch(data, normalizer):
    """Apply ``normalizer`` separately to each entry along the first (batch) axis of ``data``.

    Args:
        data: Array whose first axis is iterated over; every entry is normalized on its own.
        normalizer: Callable applied to each entry, or None to skip normalization.

    Returns:
        ``data`` unchanged if ``normalizer`` is None, otherwise the normalized entries
        stacked back together along the first axis.
    """
    if normalizer is None:
        return data
    return np.concatenate([normalizer(entry)[None] for entry in data], axis=0)


#
# Utility Functions for Model Export.
#


def get_default_citations(model=None, model_output=None):
    """Return the default citation entries for a model trained with torch_em.

    The torch_em training library is always cited. An architecture citation is added if
    ``model`` is recognized as a 2d/3d/anisotropic UNet, and a segmentation algorithm
    citation is added for the outputs "affinities" and "boundaries". A warning is emitted
    for unrecognized architectures or outputs.

    Args:
        model: The model instance or its class name as a string; None to skip.
        model_output: The output type of the model; None to skip.

    Returns:
        List of dicts with the keys "text" and "doi".
    """
    citations = [
        {"text": "training library", "doi": "10.5281/zenodo.5108853"}
    ]

    # Try to derive the network citation from the model (class) name.
    if model is not None:
        model_name = model if isinstance(model, str) else str(model.__class__.__name__)
        lowered = model_name.lower()
        if lowered in ("unet2d", "unet_2d", "unet"):
            citations.append({"text": "architecture", "doi": "10.1007/978-3-319-24574-4_28"})
        elif lowered in ("unet3d", "unet_3d", "anisotropicunet"):
            citations.append({"text": "architecture", "doi": "10.1007/978-3-319-46723-8_49"})
        else:
            warn(f"No citation for architecture {model_name} found.")

    # Try to derive the segmentation algorithm citation from the model output type.
    if model_output is not None:
        if model_output == "affinities":
            citations.append({"text": "segmentation algorithm", "doi": "10.1109/TPAMI.2020.2980827"})
        elif model_output == "boundaries":
            citations.append({"text": "segmentation algorithm", "doi": "10.1038/nmeth.4151"})
        else:
            warn(
                f"No segmentation algorithm for output {model_output} known. "
                "'affinities' and 'boundaries' are supported."
            )

    return citations


def _get_model(trainer, postprocessing):
    """Return the trainer's model in eval mode together with its filtered init kwargs.

    Init kwargs whose values are classes are dropped (with a warning). NOTE(review):
    presumably because they cannot be serialized into the model description.
    If ``postprocessing`` is given, the model is re-instantiated with it and the trained
    weights are copied over.
    """
    model = trainer.model
    model.eval()

    model_kwargs, stripped = {}, []
    for key, value in model.init_kwargs.items():
        if isinstance(value, type):
            stripped.append(key)
        else:
            model_kwargs[key] = value
    if stripped:
        warn(f"The model kwargs {stripped} have class values and are not exported.")

    # Set the in-model postprocessing if given.
    if postprocessing is not None:
        assert "postprocessing" in model_kwargs, "The model does not support in-model postprocessing."
        model_kwargs["postprocessing"] = postprocessing
        state = model.state_dict()
        model = model.__class__(**model_kwargs)
        model.load_state_dict(state)
        model.eval()

    return model, model_kwargs


def _get_ndim(trainer):
    """Return the ``ndim`` attribute of the trainer's training dataset.

    Handles datasets wrapped in ``torch.utils.data.Subset`` and datasets that hold a
    ``datasets`` list (the first entry is used).
    """
    dataset = trainer.train_loader.dataset
    try:
        if isinstance(dataset, torch.utils.data.dataset.Subset):
            return dataset.dataset.ndim
        return dataset.ndim
    except AttributeError:
        return dataset.datasets[0].ndim


def _pad(input_data, trainer):
    """Prepend singleton axes to ``input_data`` until it has ``ndim + 2`` dimensions.

    The two extra leading axes are the batch and channel axes of the model input.
    """
    target_dims = _get_ndim(trainer) + 2
    for _ in range(target_dims - input_data.ndim):
        input_data = np.expand_dims(input_data, axis=0)
    return input_data


def _write_data(input_data, model, trainer, export_folder):
    """Run the model on the test input(s) and save inputs and outputs as ``.npy`` files.

    The inputs are padded to the expected dimensionality and saved WITHOUT normalization.
    Normalization (the trainer's normalizer) is only applied to the tensors fed to the model.

    Returns:
        Tuple of (list of test input paths, list of test output paths).
    """
    if isinstance(input_data, np.ndarray):
        input_data = [input_data]

    normalizer = get_normalizer(trainer)

    # NOTE: we have to save the padded data, but without normalization.
    test_inputs = [_pad(input_, trainer) for input_ in input_data]
    normalized = [normalize_with_batch(input_, normalizer) for input_ in test_inputs]

    with torch.no_grad():
        test_tensors = [torch.from_numpy(norm).to(trainer.device) for norm in normalized]
        test_outputs = model(*test_tensors)
    if torch.is_tensor(test_outputs):
        test_outputs = [test_outputs]
    test_outputs = [out.cpu().numpy() for out in test_outputs]

    test_in_paths = []
    for i, input_ in enumerate(test_inputs):
        test_in_path = os.path.join(export_folder, f"test_input_{i}.npy")
        np.save(test_in_path, input_)
        test_in_paths.append(test_in_path)

    test_out_paths = []
    for i, out in enumerate(test_outputs):
        test_out_path = os.path.join(export_folder, f"test_output_{i}.npy")
        np.save(test_out_path, out)
        test_out_paths.append(test_out_path)

    return test_in_paths, test_out_paths


def _create_weight_description(model, export_folder, model_kwargs):
    """Save the model state dict to ``export_folder`` and describe it as pytorch_state_dict weights.

    torch_em UNets reference their architecture via this package's ``model/unet.py`` source file.
    All other models reference it via the import path of the model class.
    """
    module = str(model.__class__.__module__)
    cls_name = str(model.__class__.__name__)

    if module == "torch_em.model.unet":
        source_path = os.path.join(os.path.split(__file__)[0], "../model/unet.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(source_path),
            callable=cls_name,
            kwargs=model_kwargs,
        )
    else:
        architecture = spec.ArchitectureFromLibraryDescr(
            import_from=module,
            callable=cls_name,
            kwargs=model_kwargs,
        )

    checkpoint_path = os.path.join(export_folder, "state_dict.pt")
    torch.save(model.state_dict(), checkpoint_path)

    return spec.WeightsDescr(
        pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
            source=Path(checkpoint_path),
            architecture=architecture,
            pytorch_version=spec.Version(torch.__version__),
        )
    )


def _git_user_name():
    """Return ``git config user.name``, or None if git is unavailable or the name is empty."""
    try:
        call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
        user_name = call_res.stdout.decode("utf8").rstrip("\n")
    except (OSError, UnicodeDecodeError):
        return None
    return user_name if user_name else None


def _get_kwargs(
    trainer, name, description, authors, tags, license, documentation,
    git_repo, cite, maintainers, export_folder, input_optional_parameters
):
    """Collect the metadata fields for the model description.

    Missing (None) values are either queried interactively, with a default offered, when
    ``input_optional_parameters`` is True, or filled with the default otherwise. The
    documentation text is written to ``documentation.md`` in ``export_folder``, and that
    file path is returned in its place. List-valued fields entered as strings are parsed as JSON.
    """
    if input_optional_parameters:
        print("Enter values for the optional parameters.")
        print("If the default value in [] is satisfactory, press enter without additional input.")
        print("Please enter lists using json syntax.")

    def _get_kwarg(kwarg_name, val, default, is_list=False, fname=None):
        # Without a value we either ask the user for input (offering the default)
        # or just use the default if input_optional_parameters is False.
        if val is None and input_optional_parameters:
            default_val = default()
            choice = input(f"{kwarg_name} [{default_val}]: ")
            val = choice if choice else default_val
        elif val is None:
            val = default()

        if fname is not None:
            save_path = os.path.join(export_folder, fname)
            with open(save_path, "w", encoding="utf-8") as f:
                f.write(val)
            return save_path

        if is_list and isinstance(val, str):
            val = val.replace("'", '"')  # Enable single quotes in the JSON input.
            val = json.loads(val)
        if is_list:
            assert isinstance(val, (list, tuple)), type(val)
        return val

    def _default_authors():
        # Prefer the git user name, fall back to the network name of this machine.
        author = _git_user_name()
        if author is None:
            author = platform.node()
        return [{"name": author}]

    def _default_repo():
        # No repository is derived automatically; detection via `git remote -v` is disabled.
        # NOTE(review): presumably because remote URLs (e.g. ssh) are not always valid links.
        return None

    def _default_maintainers():
        # Prefer the git user name, fall back to the network name of this machine.
        maintainer = _git_user_name()
        if maintainer is None:
            maintainer = platform.node()
        return [{"github_user": maintainer}]

    # TODO derive better default values:
    # - description: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    # - tags: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    # - documentation: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    return {
        "name": _get_kwarg("name", name, lambda: trainer.name),
        "description": _get_kwarg("description", description, lambda: trainer.name),
        "authors": _get_kwarg("authors", authors, _default_authors, is_list=True),
        "tags": _get_kwarg("tags", tags, lambda: [trainer.name], is_list=True),
        "license": _get_kwarg("license", license, lambda: "MIT"),
        "documentation": _get_kwarg(
            "documentation", documentation, lambda: trainer.name, fname="documentation.md"
        ),
        "git_repo": _get_kwarg("git_repo", git_repo, _default_repo),
        "cite": _get_kwarg("cite", cite, get_default_citations, is_list=True),
        "maintainers": _get_kwarg("maintainers", maintainers, _default_maintainers, is_list=True),
    }


def _get_preprocessing(trainer):
    """Translate the trainer's normalization function into bioimageio preprocessing descriptions.

    Supported are torch_em's ``normalize``, ``standardize`` and ``normalize_percentile``
    (optionally wrapped in ``functools.partial``).

    Returns:
        A list with one preprocessing description. The list is empty if the trainer has no
        normalizer or the normalizer cannot be translated; a warning is emitted in the latter case.
    """
    ndim = _get_ndim(trainer)
    normalizer = get_normalizer(trainer)
    if normalizer is None:
        return []

    if isinstance(normalizer, functools.partial):
        kwargs = normalizer.keywords
        normalizer = normalizer.func
    else:
        kwargs = {}

    def _get_axes(axis):
        all_axes = ["channel", "y", "x"] if ndim == 2 else ["channel", "z", "y", "x"]
        if axis is None:
            return all_axes
        if isinstance(axis, int):
            axis = (axis,)
        # NOTE(review): position 0 is mapped to "channel" here, while _load_normalizer drops the
        # channel axis for single-channel inputs. Confirm the convention for torch_em's `axis`.
        return [all_axes[i] for i in axis]

    module_name = getattr(normalizer, "__module__", "")
    func_name = getattr(normalizer, "__name__", type(normalizer).__name__)
    name = f"{module_name}.{func_name}"
    axes = _get_axes(kwargs.get("axis", None))

    # NOTE(review): some kwargs (e.g. "mode") follow the older spec naming;
    # confirm they validate against the v0_5 kwargs models.
    if name == "torch_em.transform.raw.normalize":
        min_, max_ = kwargs.get("minval", None), kwargs.get("maxval", None)
        assert (min_ is None) == (max_ is None)

        if min_ is None:
            spec_name = "scale_range"
            spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": 0.0, "max_percentile": 100.0}
        else:
            spec_name = "scale_linear"
            # NOTE(review): _load_normalizer inverts this as minval = -offset, maxval = 1 / gain;
            # confirm that gain * x + offset matches torch_em's normalize for these values.
            spec_kwargs = {"gain": 1.0 / max_, "offset": -min_, "axes": axes}

    elif name == "torch_em.transform.raw.standardize":
        spec_kwargs = {"axes": axes}
        mean, std = kwargs.get("mean", None), kwargs.get("std", None)
        if mean is None and std is None:
            spec_name = "zero_mean_unit_variance"
        else:
            spec_name = "fixed_zero_mean_unit_variance"
            spec_kwargs.update({"mean": mean, "std": std})

    elif name == "torch_em.transform.raw.normalize_percentile":
        lower, upper = kwargs.get("lower", 1.0), kwargs.get("upper", 99.0)
        spec_name = "scale_range"
        spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": lower, "max_percentile": upper}

    else:
        warn(f"Could not parse the normalization function {name}, 'preprocessing' field will be empty.")
        return []

    name_to_cls = {
        "scale_linear": spec.ScaleLinearDescr,
        "scale_range": spec.ScaleRangeDescr,
        "zero_mean_unit_variance": spec.ZeroMeanUnitVarianceDescr,
        "fixed_zero_mean_unit_variance": spec.FixedZeroMeanUnitVarianceDescr,
    }
    return [name_to_cls[spec_name](kwargs=spec_kwargs)]


def _get_inout_descriptions(trainer, model, model_kwargs, input_tensors, output_tensors, min_shape, halo):
    """Create the input and output tensor descriptions for a torch_em UNet.

    Args:
        trainer: The trainer (used to derive the preprocessing).
        model: The model to export; must come from ``torch_em.model.unet``.
        model_kwargs: The init kwargs of the model.
        input_tensors, output_tensors: Paths to the saved test input / output (one each).
        min_shape: Minimal spatial input shape, or None to derive it from the architecture.
        halo: Output halo per spatial axis, or None for half of the size step. Shorter
            sequences are left-padded with zeros.

    Returns:
        Tuple of (input descriptions, output descriptions, notebook link).

    Raises:
        NotImplementedError: If the model is not from ``torch_em.model.unet``.
        RuntimeError: If the UNet class is not UNet2d, UNet3d or AnisotropicUNet.
    """
    module = str(model.__class__.__module__)
    name = str(model.__class__.__name__)

    # We can derive tensor kwargs only for known torch_em models (only unet for now).
    if module != "torch_em.model.unet":
        raise NotImplementedError("Model export currently only works for torch_em.model.unet.")

    assert len(input_tensors) == len(output_tensors) == 1
    inc, outc = model_kwargs["in_channels"], model_kwargs["out_channels"]

    # In-model postprocessing can change the number of output channels.
    postprocessing = model_kwargs.get("postprocessing", None)
    if isinstance(postprocessing, str) and postprocessing.startswith("affinities_to_boundaries"):
        outc = 1
    elif isinstance(postprocessing, str) and postprocessing.startswith("affinities_with_foreground_to_boundaries"):
        outc = 2
    elif postprocessing is not None:
        warn(f"The model has the post-processing {postprocessing} which cannot be interpreted")

    if name == "UNet2d":
        depth = model_kwargs["depth"]
        step = [2 ** depth] * 2
        if min_shape is None:
            min_shape = [2 ** (depth + 1)] * 2
        else:
            assert len(min_shape) == 2
            min_shape = list(min_shape)
        notebook_link = "ilastik/torch-em-2d-unet-notebook"

    elif name == "UNet3d":
        depth = model_kwargs["depth"]
        step = [2 ** depth] * 3
        if min_shape is None:
            min_shape = [2 ** (depth + 1)] * 3
        else:
            assert len(min_shape) == 3
            min_shape = list(min_shape)
        notebook_link = "ilastik/torch-em-3d-unet-notebook"

    elif name == "AnisotropicUNet":
        scale_factors = model_kwargs["scale_factors"]
        # The total downscaling per spatial axis is the product over all levels.
        scale_prod = [
            int(np.prod([factors[d] for factors in scale_factors]))
            for d in range(3)
        ]
        step = scale_prod
        if min_shape is None:
            min_shape = [2 * sp for sp in scale_prod]
        else:
            assert len(min_shape) == 3
            min_shape = list(min_shape)
        notebook_link = "ilastik/torch-em-3d-unet-notebook"

    else:
        raise RuntimeError(f"Cannot derive tensor parameters for {module}.{name}")

    if halo is None:  # Default halo = step // 2.
        halo = [st // 2 for st in step]
    else:  # Pad the passed halo (list or tuple) with leading zeros to the length of step.
        halo = list(halo)
        halo = [0] * (len(step) - len(halo)) + halo
    assert len(halo) == len(step), f"{len(halo)}, {len(step)}"

    # Create the input axis description.
    input_axes = [
        spec.BatchAxis(),
        spec.ChannelAxis(channel_names=[spec.Identifier(f"channel_{i}") for i in range(inc)]),
    ]
    input_ndim = np.load(input_tensors[0]).ndim
    assert input_ndim in (4, 5)
    axis_names = "zyx" if input_ndim == 5 else "yx"
    assert len(axis_names) == len(min_shape) == len(step)
    input_axes += [
        spec.SpaceInputAxis(id=spec.AxisId(ax_name), size=spec.ParameterizedSize(min=ax_min, step=ax_step))
        for ax_name, ax_min, ax_step in zip(axis_names, min_shape, step)
    ]

    # Create the rest of the input description.
    preprocessing = _get_preprocessing(trainer)
    input_description = [spec.InputTensorDescr(
        id=spec.TensorId("image"),
        axes=input_axes,
        test_tensor=spec.FileDescr(source=Path(input_tensors[0])),
        preprocessing=preprocessing,
    )]

    # Create the output axis description; spatial sizes reference the input axes.
    output_axes = [
        spec.BatchAxis(),
        spec.ChannelAxis(channel_names=[spec.Identifier(f"out_channel_{i}") for i in range(outc)]),
    ]
    output_ndim = np.load(output_tensors[0]).ndim
    assert output_ndim in (4, 5)
    axis_names = "zyx" if output_ndim == 5 else "yx"
    assert len(axis_names) == len(halo)
    output_axes += [
        spec.SpaceOutputAxisWithHalo(
            id=spec.AxisId(ax_name),
            size=spec.SizeReference(
                tensor_id=spec.TensorId("image"), axis_id=spec.AxisId(ax_name)
            ),
            halo=halo_val,
        ) for ax_name, halo_val in zip(axis_names, halo)
    ]

    # Create the rest of the output description.
    output_description = [spec.OutputTensorDescr(
        id=spec.TensorId("prediction"),
        axes=output_axes,
        test_tensor=spec.FileDescr(source=Path(output_tensors[0]))
    )]

    return input_description, output_description, notebook_link


def _validate_model(spec_path):
    """Check that the exported model reproduces its own test outputs.

    Returns:
        False if ``spec_path`` does not exist, if the number of outputs differs from the
        number of expected outputs, or if any output is not ``np.allclose`` to its
        expected value. True otherwise.
    """
    if not os.path.exists(spec_path):
        return False

    model, normalizer, model_spec = import_bioimageio_model(spec_path, return_spec=True)
    root = model_spec.root

    input_paths = [os.path.join(root, ipt.test_tensor.source.path) for ipt in model_spec.inputs]
    inputs = [normalize_with_batch(np.load(path), normalizer) for path in input_paths]

    expected_paths = [os.path.join(root, opt.test_tensor.source.path) for opt in model_spec.outputs]
    expected = [np.load(path) for path in expected_paths]

    with torch.no_grad():
        outputs = model(*[torch.from_numpy(input_) for input_ in inputs])
    if torch.is_tensor(outputs):
        outputs = [outputs]
    outputs = [out.cpu().numpy() for out in outputs]

    # zip would silently ignore missing or extra outputs, so compare the counts first.
    if len(outputs) != len(expected):
        return False
    return all(np.allclose(out, exp) for out, exp in zip(outputs, expected))


#
# Model Export Functionality
#

def _get_input_data(trainer):
    """Return the first element of the first validation batch (presumably the inputs) as numpy."""
    batch = next(iter(trainer.val_loader))
    return batch[0].numpy()
# TODO config: training details derived from loss and optimizer, custom params, e.g. offsets for mws
def export_bioimageio_model(
    checkpoint,
    output_path,
    input_data=None,
    name=None,
    description=None,
    authors=None,
    tags=None,
    license=None,
    documentation=None,
    covers=None,
    git_repo=None,
    cite=None,
    input_optional_parameters=True,
    model_postprocessing=None,
    for_deepimagej=False,
    links=None,
    maintainers=None,
    min_shape=None,
    halo=None,
    checkpoint_name="best",
    training_data=None,
    config=None,
):
    """Export model to bioimage.io model format.

    The checkpoint is loaded on the CPU. Test input/output tensors are computed with the
    model, and the model description is saved as a bioimage.io package at ``output_path``.
    The saved package is then re-imported and its test outputs recomputed to validate the export.

    Args:
        checkpoint: The checkpoint passed to ``get_trainer``.
        output_path: Where the bioimage.io package is saved.
        input_data: Test input (array or list of arrays). Defaults to the first element
            of the first batch of the trainer's validation loader.
        name, description, authors, tags, license, documentation, git_repo, cite, maintainers:
            Metadata of the model description. Missing values are queried interactively
            if ``input_optional_parameters`` is True, otherwise defaults are used.
        covers: Cover images of the model description (optional).
        input_optional_parameters: Whether to ask for missing metadata via ``input()``.
        model_postprocessing: In-model postprocessing to set before the export (optional).
        for_deepimagej: Currently unused.
        links: Additional links. "ilastik/ilastik" and the notebook link for the
            architecture are added. The passed sequence is not modified.
        min_shape: Minimal spatial input shape; derived from the architecture if None.
        halo: Output halo per spatial axis; defaults to half of the input size step.
        checkpoint_name: Name of the checkpoint to load, passed to ``get_trainer``.
        training_data: Currently unused.
        config: The ``config`` field of the model description; defaults to an empty dict.

    Returns:
        True if the exported model reproduces its test outputs, False otherwise.
    """
    if config is None:
        config = {}

    # Load the trainer and model.
    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
    model, model_kwargs = _get_model(trainer, model_postprocessing)

    # Get input data from the trainer if it is not given.
    if input_data is None:
        input_data = _get_input_data(trainer)

    with tempfile.TemporaryDirectory() as export_folder:

        # Create the weight description.
        weight_description = _create_weight_description(model, export_folder, model_kwargs)

        # Create the test input/output files.
        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
        # Get the descriptions for inputs, outputs and notebook links.
        input_description, output_description, notebook_link = _get_inout_descriptions(
            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
        )

        # Get the additional kwargs.
        kwargs = _get_kwargs(
            trainer, name, description,
            authors, tags, license, documentation,
            git_repo, cite, maintainers,
            export_folder, input_optional_parameters
        )

        # TODO double check the current link policy
        # The apps to link with this model: ilastik by default, plus the notebook if available.
        # Work on a copy so that the caller's list is not modified.
        links = [] if links is None else list(links)
        for link in ("ilastik/ilastik", notebook_link):
            if link is not None and link not in links:
                links.append(link)
        kwargs["links"] = links

        if covers is not None:
            kwargs["covers"] = covers

        model_description = spec.ModelDescr(
            inputs=input_description,
            outputs=output_description,
            weights=weight_description,
            config=config,
            **kwargs,
        )

        save_bioimageio_package(model_description, output_path=output_path)

    # Validate the model.
    val_success = _validate_model(output_path)
    if val_success:
        print(f"The model was successfully exported to '{output_path}'.")
    else:
        warn(f"Validation of the bioimageio model exported to '{output_path}' has failed. " +
             "You can use this model, but it will probably yield incorrect results.")
    return val_success


# TODO support bounding boxes
def _load_data(path, key):
    """Load test data from ``path``.

    Without ``key``, ``.npy`` files are loaded with numpy and everything else with imageio.
    With ``key``, the dataset ``key`` is read completely from the container file (e.g. hdf5
    or zarr), and the file is closed afterwards.
    """
    if key is None:
        if os.path.splitext(path)[-1] == ".npy":
            return np.load(path)
        return imageio.imread(path)
    with open_file(path, mode="r") as f:
        return f[key][:]


def main():
    """Command line entry point: export a torch_em checkpoint to the BioImage.IO model format."""
    import argparse
    parser = argparse.ArgumentParser(
        description=(
            "Export model trained with torch_em to the BioImage.IO model format. "
            "The exported model can be run in any tool supporting BioImage.IO. "
            "For details check out https://bioimage.io/#/."
        )
    )
    parser.add_argument("-p", "--path", required=True,
                        help="Path to the model checkpoint to export to the BioImage.IO model format.")
    parser.add_argument("-d", "--data", required=True,
                        help="Path to the test data to use for creating the exported model.")
    parser.add_argument("-f", "--export_folder", required=True,
                        help="Where to save the exported model. The model is written as a zip archive to this path.")
    parser.add_argument("-k", "--key",
                        help="The key for the test data. Required for container data formats like hdf5 or zarr.")
    parser.add_argument("-n", "--name", help="The name of the exported model.")

    args = parser.parse_args()
    name = os.path.basename(args.path) if args.name is None else args.name
    export_bioimageio_model(args.path, args.export_folder, _load_data(args.data, args.key), name=name)


#
# model import functionality
#

def _load_model(model_spec, device):
    """Instantiate the network of a model description and load its pytorch_state_dict weights.

    Weight paths that do not exist as given are resolved relative to ``model_spec.root``.
    """
    weight_spec = model_spec.weights.pytorch_state_dict
    model = PytorchModelAdapter.get_network(weight_spec)
    weight_file = weight_spec.source.path
    if not os.path.exists(weight_file):
        weight_file = os.path.join(model_spec.root, weight_file)
    assert os.path.exists(weight_file), weight_file
    state = torch.load(weight_file, map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model


def _load_normalizer(model_spec):
    """Translate the preprocessing of the first model input into a torch_em normalizer.

    "ensure_dtype" steps are ignored, and only the first remaining step is translated
    (a warning is emitted if there are more).

    Returns:
        A ``functools.partial`` of a torch_em normalization function, or None if there is
        no preprocessing to apply.

    Raises:
        RuntimeError: If the preprocessing function is not supported.
    """
    inputs = model_spec.inputs[0]

    # Filter out ensure dtype.
    preprocessing = [
        preproc for preproc in (inputs.preprocessing or []) if preproc.id != "ensure_dtype"
    ]
    if len(preprocessing) == 0:
        return None
    if len(preprocessing) > 1:
        warn(f"Only the first of {len(preprocessing)} preprocessing steps is used for the normalizer.")

    ndim = len(inputs.axes) - 2
    shape = inputs.shape
    if hasattr(shape, "min"):
        shape = shape.min

    conf = preprocessing[0]
    name = conf.id
    spec_kwargs = conf.kwargs

    def _get_axis(axes):
        # NOTE(review): assumes that missing axes mean normalization over all axes,
        # which corresponds to torch_em's axis=None.
        if axes is None:
            return None

        label_to_id = {"channel": 0, "z": 1, "y": 2, "x": 3} if ndim == 3 else\
            {"channel": 0, "y": 1, "x": 2}
        axis = tuple(label_to_id[ax] for ax in axes)

        # Is the axis full? Then we don't need to specify it.
        if len(axis) == ndim + 1:
            return None

        # Drop the channel axis if we have only a single channel,
        # because torch_em squeezes the channel axis in this case.
        if shape[1] == 1:
            axis = tuple(ax - 1 for ax in axis if ax > 0)
        return axis

    axis = _get_axis(spec_kwargs.get("axes", None))
    if name == "zero_mean_unit_variance":
        kwargs = {"axis": axis}
        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)

    elif name == "fixed_zero_mean_unit_variance":
        kwargs = {"axis": axis, "mean": spec_kwargs["mean"], "std": spec_kwargs["std"]}
        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)

    elif name == "scale_linear":
        min_ = -spec_kwargs["offset"]
        max_ = 1. / spec_kwargs["gain"]
        kwargs = {"axis": axis, "minval": min_, "maxval": max_}
        normalizer = functools.partial(torch_em.transform.raw.normalize, **kwargs)

    elif name == "scale_range":
        assert spec_kwargs["mode"] == "per_sample"  # Can't parse the other modes right now.
        lower, upper = spec_kwargs["min_percentile"], spec_kwargs["max_percentile"]
        if np.isclose(lower, 0.0) and np.isclose(upper, 100.0):
            normalizer = functools.partial(torch_em.transform.raw.normalize, axis=axis)
        else:
            kwargs = {"axis": axis, "lower": lower, "upper": upper}
            normalizer = functools.partial(torch_em.transform.raw.normalize_percentile, **kwargs)

    else:
        msg = f"torch_em does not support the use of the bioimageio preprocessing function {name}."
        raise RuntimeError(msg)

    return normalizer


def import_bioimageio_model(spec_path, return_spec=False, device="cpu"):
    """Load a model and its input normalizer from a BioImage.IO model description.

    Args:
        spec_path: Source of the model description, passed to ``bioimageio.core.load_description``.
        return_spec: Whether to also return the loaded model description.
        device: Device to map the weights to.

    Returns:
        ``(model, normalizer)``, or ``(model, normalizer, model_spec)`` if ``return_spec`` is True.
    """
    model_spec = core.load_description(spec_path)

    model = _load_model(model_spec, device=device)
    normalizer = _load_normalizer(model_spec)

    if return_spec:
        return model, normalizer, model_spec
    return model, normalizer


def import_trainer_from_bioimageio_model(spec_path):
    """Create a trainer from a BioImage.IO model (not implemented yet).

    Raises:
        NotImplementedError: Always.
    """
    # TODO implement
    raise NotImplementedError("Importing a trainer from a bioimageio model is not supported yet.")


# TODO: the weight conversion needs to be updated once the
# corresponding functionality in bioimageio.core is updated
#
# Weight Conversion
#


def _convert_impl(spec_path, weight_name, converter, weight_type, **kwargs):
    """Convert the weights of the model at ``spec_path`` with ``converter`` and re-save the package.

    The converted weights are written to a temporary file named ``weight_name`` and stored
    under the ``weight_type`` attribute of the weights description.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        weight_path = os.path.join(tmp_dir, weight_name)
        model_spec = core.load_description(spec_path)
        weight_descr = converter(model_spec, weight_path, **kwargs)
        # TODO double check
        setattr(model_spec.weights, weight_type, weight_descr)
        save_bioimageio_package(model_spec, output_path=spec_path)


def convert_to_onnx(spec_path, opset_version=12):
    """Add onnx weights to a BioImage.IO model (not implemented yet).

    Raises:
        NotImplementedError: Always, until the weight conversion is updated for the
            current bioimageio.core version.
    """
    # TODO: re-implement with bioimageio.core's onnx converter via
    # _convert_impl(spec_path, "weights.onnx", converter, "onnx", opset_version=opset_version).
    raise NotImplementedError("Conversion to onnx weights is not supported yet.")


def convert_to_torchscript(model_path):
    """Add torchscript weights to a BioImage.IO model (not implemented yet).

    Raises:
        NotImplementedError: Always, until the weight conversion is updated for the
            current bioimageio.core version.
    """
    # TODO: re-implement with bioimageio.core's torchscript converter via _convert_impl
    # (weight name "weights-torchscript.pt"), then check that torch.jit.load can read the result.
    raise NotImplementedError("Conversion to torchscript weights is not supported yet.")


def add_weight_formats(model_path, additional_formats):
    """Add weights in each of ``additional_formats`` ("onnx" or "torchscript") to the model.

    All formats are checked before any conversion starts. ``None`` or an empty sequence is a no-op.

    Raises:
        ValueError: If an unsupported format is requested.
    """
    if not additional_formats:
        return

    converters = {"onnx": convert_to_onnx, "torchscript": convert_to_torchscript}
    unsupported = [fmt for fmt in additional_formats if fmt not in converters]
    if unsupported:
        raise ValueError(f"Unsupported weight formats {unsupported}; expected 'onnx' or 'torchscript'.")

    for add_format in additional_formats:
        ret = converters[add_format](model_path)
        if ret is None:
            print("Successfully added", add_format, "weights")
        else:
            warn(f"Added {add_format} weights, but got exception {ret} when loading the weights again.")


def convert_main():
    """Command line entry point: convert the weights of an exported model to onnx or torchscript."""
    import argparse
    parser = argparse.ArgumentParser(
        description="Convert weights from native pytorch format to onnx or torchscript"
    )
    parser.add_argument("-f", "--model_folder", required=True,
                        help="The exported BioImage.IO model whose weights are converted.")
    parser.add_argument("-w", "--weight_format", required=True, choices=("onnx", "torchscript"),
                        help="The weight format to convert to.")
    args = parser.parse_args()
    if args.weight_format == "onnx":
        convert_to_onnx(args.model_folder)
    else:
        convert_to_torchscript(args.model_folder)


#
# Misc Functionality
#

def export_parser_helper():
    """Return an argument parser with the common options of export scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--checkpoint", required=True)
    parser.add_argument("-i", "--input", required=True)
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-a", "--affs_to_bd", default=0, type=int)
    parser.add_argument("-f", "--additional_formats", type=str, nargs="+")
    return parser


def get_mws_config(offsets, config=None):
    """Store the mutex watershed ``offsets`` under ``config["mws"]``.

    Args:
        offsets: The offsets to store.
        config: Existing config dict to extend in place, or None to create a new one.

    Returns:
        The config dict.
    """
    mws_config = {"offsets": offsets}
    if config is None:
        return {"mws": mws_config}
    assert isinstance(config, dict)
    config["mws"] = mws_config
    return config


def get_shallow2deep_config(rf_path, config=None):
    """Store the feature settings of a pickled random forest under ``config["shallow2deep"]``.

    Args:
        rf_path: Path to a pickled random forest, or to a directory containing ``*.pkl``
            files (the alphabetically first one is used).
        config: Existing config dict to extend in place, or None to create a new one.

    Returns:
        The config dict, with the forest's ``feature_ndim`` and ``feature_config``.

    Raises:
        FileNotFoundError: If ``rf_path`` is a directory without ``*.pkl`` files.
    """
    if os.path.isdir(rf_path):
        candidates = sorted(glob(os.path.join(rf_path, "*.pkl")))
        if not candidates:
            raise FileNotFoundError(f"No *.pkl file found in {rf_path}")
        rf_path = candidates[0]
    assert os.path.exists(rf_path), rf_path
    # NOTE: unpickling can execute arbitrary code; only load random forests from trusted sources.
    with open(rf_path, "rb") as f:
        rf = pickle.load(f)
    shallow2deep_config = {
        "ndim": rf.feature_ndim,
        "features": rf.feature_config,
    }
    if config is None:
        return {"shallow2deep": shallow2deep_config}
    assert isinstance(config, dict)
    config["shallow2deep"] = shallow2deep_config
    return config
# def normalize_with_batch(data, normalizer):
def get_default_citations(model=None, model_output=None):
    """Return the default citation entries for a model trained with torch_em.

    The torch_em training library is always cited. An architecture citation is added if
    ``model`` is recognized as a 2d/3d/anisotropic UNet, and a segmentation algorithm
    citation is added for the outputs "affinities" and "boundaries". A warning is emitted
    for unrecognized architectures or outputs.

    Args:
        model: The model instance or its class name as a string; None to skip.
        model_output: The output type of the model; None to skip.

    Returns:
        List of dicts with the keys "text" and "doi".
    """
    citations = [
        {"text": "training library", "doi": "10.5281/zenodo.5108853"}
    ]

    # Try to derive the network citation from the model (class) name.
    if model is not None:
        model_name = model if isinstance(model, str) else str(model.__class__.__name__)
        lowered = model_name.lower()
        if lowered in ("unet2d", "unet_2d", "unet"):
            citations.append({"text": "architecture", "doi": "10.1007/978-3-319-24574-4_28"})
        elif lowered in ("unet3d", "unet_3d", "anisotropicunet"):
            citations.append({"text": "architecture", "doi": "10.1007/978-3-319-46723-8_49"})
        else:
            # The f-prefix was missing, so the model name was never interpolated.
            warn(f"No citation for architecture {model_name} found.")

    # Try to derive the segmentation algorithm citation from the model output type.
    if model_output is not None:
        if model_output == "affinities":
            citations.append({"text": "segmentation algorithm", "doi": "10.1109/TPAMI.2020.2980827"})
        elif model_output == "boundaries":
            citations.append({"text": "segmentation algorithm", "doi": "10.1038/nmeth.4151"})
        else:
            warn(
                f"No segmentation algorithm for output {model_output} known. "
                "'affinities' and 'boundaries' are supported."
            )

    return citations
def export_bioimageio_model(
    checkpoint,
    output_path,
    input_data=None,
    name=None,
    description=None,
    authors=None,
    tags=None,
    license=None,
    documentation=None,
    covers=None,
    git_repo=None,
    cite=None,
    input_optional_parameters=True,
    model_postprocessing=None,
    for_deepimagej=False,
    links=None,
    maintainers=None,
    min_shape=None,
    halo=None,
    checkpoint_name="best",
    training_data=None,
    config=None,
):
    """Export model to bioimage.io model format.

    The checkpoint is loaded on the CPU. Test input/output tensors are computed with the
    model, and the model description is saved as a bioimage.io package at ``output_path``.
    The saved package is then re-imported and its test outputs recomputed to validate the export.

    Args:
        checkpoint: The checkpoint passed to ``get_trainer``.
        output_path: Where the bioimage.io package is saved.
        input_data: Test input (array or list of arrays). Defaults to the first element
            of the first batch of the trainer's validation loader.
        name, description, authors, tags, license, documentation, git_repo, cite, maintainers:
            Metadata of the model description. Missing values are queried interactively
            if ``input_optional_parameters`` is True, otherwise defaults are used.
        covers: Cover images of the model description (optional).
        input_optional_parameters: Whether to ask for missing metadata via ``input()``.
        model_postprocessing: In-model postprocessing to set before the export (optional).
        for_deepimagej: Currently unused.
        links: Additional links. "ilastik/ilastik" and the notebook link for the
            architecture are added. The passed sequence is not modified.
        min_shape: Minimal spatial input shape; derived from the architecture if None.
        halo: Output halo per spatial axis; defaults to half of the input size step.
        checkpoint_name: Name of the checkpoint to load, passed to ``get_trainer``.
        training_data: Currently unused.
        config: The ``config`` field of the model description; defaults to an empty dict.

    Returns:
        True if the exported model reproduces its test outputs, False otherwise.
    """
    if config is None:
        config = {}

    # Load the trainer and model.
    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
    model, model_kwargs = _get_model(trainer, model_postprocessing)

    # Get input data from the trainer if it is not given.
    if input_data is None:
        input_data = _get_input_data(trainer)

    with tempfile.TemporaryDirectory() as export_folder:

        # Create the weight description.
        weight_description = _create_weight_description(model, export_folder, model_kwargs)

        # Create the test input/output files.
        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
        # Get the descriptions for inputs, outputs and notebook links.
        input_description, output_description, notebook_link = _get_inout_descriptions(
            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
        )

        # Get the additional kwargs.
        kwargs = _get_kwargs(
            trainer, name, description,
            authors, tags, license, documentation,
            git_repo, cite, maintainers,
            export_folder, input_optional_parameters
        )

        # TODO double check the current link policy
        # The apps to link with this model: ilastik by default, plus the notebook if available.
        # Work on a copy so that the caller's list is not modified.
        links = [] if links is None else list(links)
        for link in ("ilastik/ilastik", notebook_link):
            if link is not None and link not in links:
                links.append(link)
        kwargs["links"] = links

        if covers is not None:
            kwargs["covers"] = covers

        model_description = spec.ModelDescr(
            inputs=input_description,
            outputs=output_description,
            weights=weight_description,
            config=config,
            **kwargs,
        )

        save_bioimageio_package(model_description, output_path=output_path)

    # Validate the model.
    val_success = _validate_model(output_path)
    if val_success:
        print(f"The model was successfully exported to '{output_path}'.")
    else:
        warn(f"Validation of the bioimageio model exported to '{output_path}' has failed. " +
             "You can use this model, but it will probably yield incorrect results.")
    return val_success
def main():
    """Command line entry point: export a torch_em checkpoint to the BioImage.IO model format.

    Reads the checkpoint path (-p), test data (-d, with optional key -k), output path (-f)
    and optional model name (-n) from the command line. If no name is given, the base name
    of the checkpoint path is used.
    """
    # NOTE: the text must be passed as `description`; the first positional argument
    # of ArgumentParser is `prog`, which would otherwise replace the program name.
    parser = argparse.ArgumentParser(
        description=(
            "Export model trained with torch_em to the BioImage.IO model format. "
            "The exported model can be run in any tool supporting BioImage.IO. "
            "For details check out https://bioimage.io/#/."
        )
    )
    parser.add_argument("-p", "--path", required=True,
                        help="Path to the model checkpoint to export to the BioImage.IO model format.")
    parser.add_argument("-d", "--data", required=True,
                        help="Path to the test data to use for creating the exported model.")
    parser.add_argument("-f", "--export_folder", required=True,
                        help="Where to save the exported model. The model is saved as a zipped package at this path.")
    parser.add_argument("-k", "--key",
                        help="The key for the test data. Required for container data formats like hdf5 or zarr.")
    parser.add_argument("-n", "--name", help="The name of the exported model.")

    args = parser.parse_args()
    # normpath drops a trailing separator, so 'checkpoints/my_model/' yields 'my_model' instead of ''.
    name = os.path.basename(os.path.normpath(args.path)) if args.name is None else args.name
    export_bioimageio_model(args.path, args.export_folder, _load_data(args.data, args.key), name=name)
def import_bioimageio_model(spec_path, return_spec=False, device="cpu"):
    """Load the torch model and its input normalizer from a BioImage.IO model.

    Args:
        spec_path: Source of the model description, passed to `bioimageio.core.load_description`.
        return_spec: If True, also return the loaded model description.
        device: Device passed on when loading the model.

    Returns:
        The tuple (model, normalizer), extended by the model description if `return_spec` is set.
    """
    description = core.load_description(spec_path)

    loaded_model = _load_model(description, device=device)
    input_normalizer = _load_normalizer(description)

    result = (loaded_model, input_normalizer)
    if return_spec:
        result += (description,)
    return result
def
import_trainer_from_bioimageio_model(spec_path):
def
convert_to_onnx(spec_path, opset_version=12):
def convert_to_torchscript(model_path):
    """Add torchscript weights to a BioImage.IO model.

    The conversion is currently not available.

    Args:
        model_path: Path to the BioImage.IO model to convert.

    Raises:
        NotImplementedError: Always.
    """
    # TODO: a previous implementation used
    # `bioimageio.core.weight_converter.torch._torchscript.convert_weights_to_torchscript`
    # via `_convert_impl`, then checked that the new weights load with `torch.jit.load`.
    # To match `add_weight_formats`, it should return None on success, or the exception
    # raised when loading the converted weights.
    raise NotImplementedError(
        "Conversion to torchscript weights is currently not supported."
    )
def add_weight_formats(model_path, additional_formats):
    """Add weights in additional formats to a BioImage.IO model.

    Args:
        model_path: Path to the model, passed on to the conversion functions.
        additional_formats: Weight format names, each "onnx" or "torchscript".
            A single string is treated as one format; None or empty does nothing.

    Raises:
        ValueError: If an unsupported format is requested. Raised before any
            conversion is started.
    """
    if additional_formats is None:
        return
    # Guard against iterating a single format name character by character.
    if isinstance(additional_formats, str):
        additional_formats = [additional_formats]
    additional_formats = list(additional_formats)

    supported = ("onnx", "torchscript")
    unsupported = [fmt for fmt in additional_formats if fmt not in supported]
    if unsupported:
        raise ValueError(f"Unsupported weight formats {unsupported}, expected any of {supported}.")

    for add_format in additional_formats:
        # Each converter returns None on success or the exception raised when re-loading the weights.
        if add_format == "onnx":
            ret = convert_to_onnx(model_path)
        else:
            ret = convert_to_torchscript(model_path)

        if ret is None:
            print("Successfully added", add_format, "weights")
        else:
            warn(f"Added {add_format} weights, but got exception {ret} when loading the weights again.")
def convert_main():
    """Command line entry point: convert the weights of a BioImage.IO model.

    Reads the model path (-f) and the target weight format (-w, "onnx" or "torchscript")
    from the command line and calls the matching conversion function.
    """
    # Pass the text as `description`; the first positional argument of ArgumentParser is `prog`.
    parser = argparse.ArgumentParser(
        description="Convert weights from native pytorch format to onnx or torchscript"
    )
    parser.add_argument("-f", "--model_folder", required=True,
                        help="Path to the BioImage.IO model whose weights should be converted.")
    # `choices` makes argparse reject invalid formats with a usage error instead of an assert.
    parser.add_argument("-w", "--weight_format", required=True, choices=("onnx", "torchscript"),
                        help="The weight format to convert to.")
    args = parser.parse_args()
    if args.weight_format == "onnx":
        convert_to_onnx(args.model_folder)
    else:
        convert_to_torchscript(args.model_folder)
def export_parser_helper():
    """Create the argument parser shared by model export scripts.

    Returns:
        An ArgumentParser with the required options -c/--checkpoint, -i/--input and
        -o/--output, an integer option -a/--affs_to_bd (default 0), and
        -f/--additional_formats taking one or more strings.
    """
    arg_parser = argparse.ArgumentParser()
    required_options = (("-c", "--checkpoint"), ("-i", "--input"), ("-o", "--output"))
    for short_flag, long_flag in required_options:
        arg_parser.add_argument(short_flag, long_flag, required=True)
    arg_parser.add_argument("-a", "--affs_to_bd", type=int, default=0)
    arg_parser.add_argument("-f", "--additional_formats", nargs="+", type=str)
    return arg_parser
def
get_mws_config(offsets, config=None):
def get_shallow2deep_config(rf_path, config=None):
    """Build the "shallow2deep" config entry from a pickled random forest.

    Args:
        rf_path: Path to a pickled random forest, or to a directory containing
            "*.pkl" files. For a directory, the first file in sorted order is used.
        config: Optional existing config dict. If given, it is updated in place and returned.

    Returns:
        The config dict with a "shallow2deep" entry that holds the forest's
        `feature_ndim` (as "ndim") and `feature_config` (as "features").

    Raises:
        FileNotFoundError: If `rf_path` does not exist, or the directory contains no "*.pkl" file.
        TypeError: If `config` is neither None nor a dict.
    """
    if config is not None and not isinstance(config, dict):
        raise TypeError(f"Expected config to be a dict or None, got {type(config).__name__}.")

    if os.path.isdir(rf_path):
        # Sort so that the chosen file is deterministic if the folder holds several pickles.
        candidates = sorted(glob(os.path.join(rf_path, "*.pkl")))
        if not candidates:
            raise FileNotFoundError(f"No '*.pkl' file found in the directory {rf_path}.")
        rf_path = candidates[0]
    if not os.path.exists(rf_path):
        raise FileNotFoundError(rf_path)

    # NOTE: pickle.load can execute arbitrary code; only load random forests from trusted sources.
    with open(rf_path, "rb") as f:
        rf = pickle.load(f)

    shallow2deep_config = {
        "ndim": rf.feature_ndim,
        "features": rf.feature_config,
    }
    if config is None:
        config = {"shallow2deep": shallow2deep_config}
    else:
        config["shallow2deep"] = shallow2deep_config
    return config