torch_em.util.modelzoo
import argparse
import functools
import json
import os
import pickle
import subprocess
import tempfile

from glob import glob
from pathlib import Path
from shutil import unpack_archive
from typing import Callable, Dict, List, Optional, Tuple
from warnings import warn

import imageio
import numpy as np
import torch
import torch_em

import bioimageio.core as core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.spec import save_bioimageio_package
from bioimageio.core.backends.pytorch_backend import load_torch_model

from elf.io import open_file
from .util import get_trainer, get_normalizer


#
# General Purpose Functionality
#


def normalize_with_batch(data, normalizer):
    """@private
    """
    if normalizer is None:
        return data
    normalized = np.concatenate([normalizer(da)[None] for da in data], axis=0)
    return normalized


#
# Utility Functions for Model Export.
#


def get_default_citations(model=None, model_output=None):
    """@private
    """
    citations = [{"text": "training library", "doi": "10.5281/zenodo.5108853"}]

    # Try to derive the correct network citation from the model class.
    if model is not None:
        if isinstance(model, str):
            model_name = model
        else:
            model_name = str(model.__class__.__name__)

        if model_name.lower() in ("unet2d", "unet_2d", "unet"):
            citations.append(
                {"text": "architecture", "doi": "10.1007/978-3-319-24574-4_28"}
            )
        elif model_name.lower() in ("unet3d", "unet_3d", "anisotropicunet"):
            citations.append(
                {"text": "architecture", "doi": "10.1007/978-3-319-46723-8_49"}
            )
        else:
            warn(f"No citation for architecture {model_name} found.")

    # Try to derive the correct segmentation algorithm citation from the model output type.
    if model_output is not None:
        msg = f"No segmentation algorithm for output {model_output} known. 'affinities' and 'boundaries' are supported."
        if model_output == "affinities":
            citations.append(
                {"text": "segmentation algorithm", "doi": "10.1109/TPAMI.2020.2980827"}
            )
        elif model_output == "boundaries":
            citations.append(
                {"text": "segmentation algorithm", "doi": "10.1038/nmeth.4151"}
            )
        else:
            warn(msg)

    return citations


def _get_model(trainer, postprocessing):
    model = trainer.model
    model.eval()
    model_kwargs = model.init_kwargs
    # Clear the kwargs of non-builtins.
    # TODO warn if we strip any non-standard arguments
    model_kwargs = {k: v for k, v in model_kwargs.items() if not isinstance(v, type)}

    # Set the in-model postprocessing if given.
    if postprocessing is not None:
        assert "postprocessing" in model_kwargs
        model_kwargs["postprocessing"] = postprocessing
        state = model.state_dict()
        model = model.__class__(**model_kwargs)
        model.load_state_dict(state)
        model.eval()

    return model, model_kwargs


def _pad(input_data, trainer):
    try:
        if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset):
            ndim = trainer.train_loader.dataset.dataset.ndim
        else:
            ndim = trainer.train_loader.dataset.ndim
    except AttributeError:
        ndim = trainer.train_loader.dataset.datasets[0].ndim
    target_dims = ndim + 2
    for _ in range(target_dims - input_data.ndim):
        input_data = np.expand_dims(input_data, axis=0)
    return input_data


def _write_data(input_data, model, trainer, export_folder):
    # if input_data is None:
    #     gen = SampleGenerator(trainer, 1, False, 1)
    #     input_data = next(gen)
    if isinstance(input_data, np.ndarray):
        input_data = [input_data]

    # Normalize the input data if we have a normalization function.
    normalizer = get_normalizer(trainer)

    # Pad to 4d/5d and normalize the input data.
    # NOTE: we have to save the padded data, but without normalization.
    test_inputs = [_pad(input_, trainer) for input_ in input_data]
    normalized = [normalize_with_batch(input_, normalizer) for input_ in test_inputs]

    # Run prediction.
    with torch.no_grad():
        test_tensors = [torch.from_numpy(norm).to(trainer.device) for norm in normalized]
        test_outputs = model(*test_tensors)
        if torch.is_tensor(test_outputs):
            test_outputs = [test_outputs]
        test_outputs = [out.cpu().numpy() for out in test_outputs]

    # Save the input / output.
    test_in_paths, test_out_paths = [], []
    for i, input_ in enumerate(test_inputs):
        test_in_path = os.path.join(export_folder, f"test_input_{i}.npy")
        np.save(test_in_path, input_)
        test_in_paths.append(test_in_path)
    for i, out in enumerate(test_outputs):
        test_out_path = os.path.join(export_folder, f"test_output_{i}.npy")
        np.save(test_out_path, out)
        test_out_paths.append(test_out_path)
    return test_in_paths, test_out_paths


def _create_weight_description(model, export_folder, model_kwargs):
    module = str(model.__class__.__module__)
    cls_name = str(model.__class__.__name__)

    if module == "torch_em.model.unet":
        source_path = os.path.join(os.path.split(__file__)[0], "../model/unet.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(source_path),
            callable=cls_name,
            kwargs=model_kwargs,
        )
    else:
        architecture = spec.ArchitectureFromLibraryDescr(
            import_from=module,
            callable=cls_name,
            kwargs=model_kwargs,
        )

    checkpoint_path = os.path.join(export_folder, "state_dict.pt")
    torch.save(model.state_dict(), checkpoint_path)

    weight_description = spec.WeightsDescr(
        pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
            source=Path(checkpoint_path),
            architecture=architecture,
            pytorch_version=spec.Version(torch.__version__),
        )
    )
    return weight_description


def _get_kwargs(
    trainer, name, description, authors, tags, license, documentation,
    git_repo, cite, maintainers, export_folder, input_optional_parameters
):
    if input_optional_parameters:
        print("Enter values for the optional parameters.")
        print("If the default value in [] is satisfactory, press enter without additional input.")
        print("Please enter lists using json syntax.")

    def _get_kwarg(kwarg_name, val, default, is_list=False, fname=None):
        # If we don't have a value, we either ask the user for input (offering the default)
        # or just use the default if input_optional_parameters is False.
        if val is None and input_optional_parameters:
            default_val = default()
            choice = input(f"{kwarg_name} [{default_val}]: ")
            val = choice if choice else default_val
        elif val is None:
            val = default()

        if fname is not None:
            save_path = os.path.join(export_folder, fname)
            with open(save_path, "w") as f:
                f.write(val)
            return save_path

        if is_list and isinstance(val, str):
            val = val.replace("'", '"')  # enable single quotes
            val = json.loads(val)
        if is_list:
            assert isinstance(val, (list, tuple)), type(val)
        return val

    def _default_authors():
        # First try to derive the author name from git.
        try:
            call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
            author = call_res.stdout.decode("utf8").rstrip("\n")
            author = author if author else None  # in case there was no error, but the output is empty
        except Exception:
            author = None

        # Otherwise fall back to the hostname.
        if author is None:
            author = os.uname()[1]

        return [{"name": author}]

    def _default_repo():
        # NOTE: the git-based repository detection below is disabled by this early return.
        return None
        try:
            call_res = subprocess.run(["git", "remote", "-v"], capture_output=True)
            repo = call_res.stdout.decode("utf8").split("\n")[0].split()[1]
            repo = repo if repo else None
        except Exception:
            repo = None
        return repo

    def _default_maintainers():
        # First try to derive the maintainer name from git.
        try:
            call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
            maintainer = call_res.stdout.decode("utf8").rstrip("\n")
            maintainer = maintainer if maintainer else None  # in case there was no error, but the output is empty
        except Exception:
            maintainer = None

        # Otherwise fall back to the hostname.
        if maintainer is None:
            maintainer = os.uname()[1]

        return [{"github_user": maintainer}]

    # TODO derive better default values:
    # - description: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    # - tags: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    # - documentation: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    kwargs = {
        "name": _get_kwarg("name", name, lambda: trainer.name),
        "description": _get_kwarg("description", description, lambda: trainer.name),
        "authors": _get_kwarg("authors", authors, _default_authors, is_list=True),
        "tags": _get_kwarg("tags", tags, lambda: [trainer.name], is_list=True),
        "license": _get_kwarg("license", license, lambda: "MIT"),
        "documentation": _get_kwarg(
            "documentation", documentation, lambda: trainer.name, fname="documentation.md"
        ),
        "git_repo": _get_kwarg("git_repo", git_repo, _default_repo),
        "cite": _get_kwarg("cite", cite, get_default_citations),
        "maintainers": _get_kwarg("maintainers", maintainers, _default_maintainers, is_list=True),
    }

    return kwargs


def _get_preprocessing(trainer):
    try:
        if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset):
            ndim = trainer.train_loader.dataset.dataset.ndim
        else:
            ndim = trainer.train_loader.dataset.ndim
    except AttributeError:
        ndim = trainer.train_loader.dataset.datasets[0].ndim
    normalizer = get_normalizer(trainer)

    if isinstance(normalizer, functools.partial):
        kwargs = normalizer.keywords
        normalizer = normalizer.func
    else:
        kwargs = {}

    def _get_axes(axis):
        all_axes = ["channel", "y", "x"] if ndim == 2 else ["channel", "z", "y", "x"]
        if axis is None:
            axes = all_axes
        else:
            axes = [all_axes[i] for i in axis]
        return axes

    name = f"{normalizer.__module__}.{normalizer.__name__}"
    axes = _get_axes(kwargs.get("axis", None))

    if name == "torch_em.transform.raw.normalize":

        min_, max_ = kwargs.get("minval", None), kwargs.get("maxval", None)
        assert (min_ is None) == (max_ is None)

        if min_ is None:
            spec_name = "scale_range"
            spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": 0.0, "max_percentile": 100.0}
        else:
            spec_name = "scale_linear"
            spec_kwargs = {"gain": 1.0 / max_, "offset": -min_, "axes": axes}

    elif name == "torch_em.transform.raw.standardize":
        spec_kwargs = {"axes": axes}
        mean, std = kwargs.get("mean", None), kwargs.get("std", None)
        if (mean is None) and (std is None):
            spec_name = "zero_mean_unit_variance"
        else:
            spec_name = "fixed_zero_mean_unit_variance"
            spec_kwargs.update({"mean": mean, "std": std})

    elif name == "torch_em.transform.raw.normalize_percentile":
        lower, upper = kwargs.get("lower", 1.0), kwargs.get("upper", 99.0)
        spec_name = "scale_range"
        spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": lower, "max_percentile": upper}

    else:
        warn(f"Could not parse the normalization function {name}, 'preprocessing' field will be empty.")
        return None

    name_to_cls = {
        "scale_linear": spec.ScaleLinearDescr,
        "scale_range": spec.ScaleRangeDescr,
        "zero_mean_unit_variance": spec.ZeroMeanUnitVarianceDescr,
        "fixed_zero_mean_unit_variance": spec.FixedZeroMeanUnitVarianceDescr,
    }
    preprocessing = name_to_cls[spec_name](kwargs=spec_kwargs)

    return [preprocessing]


def _get_inout_descriptions(trainer, model, model_kwargs, input_tensors, output_tensors, min_shape, halo):

    notebook_link = None
    module = str(model.__class__.__module__)
    name = str(model.__class__.__name__)

    # Tensor kwargs can only be derived for known torch_em models (only unet for now).
    if module == "torch_em.model.unet":
        assert len(input_tensors) == len(output_tensors) == 1
        inc, outc = model_kwargs["in_channels"], model_kwargs["out_channels"]

        postprocessing = model_kwargs.get("postprocessing", None)
        if isinstance(postprocessing, str) and postprocessing.startswith("affinities_to_boundaries"):
            outc = 1
        elif isinstance(postprocessing, str) and postprocessing.startswith("affinities_with_foreground_to_boundaries"):
            outc = 2
        elif postprocessing is not None:
            warn(f"The model has the post-processing {postprocessing}, which cannot be interpreted.")

        if name == "UNet2d":
            depth = model_kwargs["depth"]
            step = [2 ** depth] * 2
            if min_shape is None:
                min_shape = [2 ** (depth + 1)] * 2
            else:
                assert len(min_shape) == 2
                min_shape = list(min_shape)
            notebook_link = "ilastik/torch-em-2d-unet-notebook"

        elif name == "UNet3d":
            depth = model_kwargs["depth"]
            step = [2 ** depth] * 3
            if min_shape is None:
                min_shape = [2 ** (depth + 1)] * 3
            else:
                assert len(min_shape) == 3
                min_shape = list(min_shape)
            notebook_link = "ilastik/torch-em-3d-unet-notebook"

        elif name == "AnisotropicUNet":
            scale_factors = model_kwargs["scale_factors"]
            scale_prod = [
                int(np.prod([scale_factors[i][d] for i in range(len(scale_factors))]))
                for d in range(3)
            ]
            assert len(scale_prod) == 3
            step = scale_prod
            if min_shape is None:
                min_shape = [2 * sp for sp in scale_prod]
            else:
                min_shape = list(min_shape)
            notebook_link = "ilastik/torch-em-3d-unet-notebook"

        else:
            raise RuntimeError(f"Cannot derive tensor parameters for {module}.{name}")

        if halo is None:  # default halo = step // 2
            halo = [st // 2 for st in step]
        else:  # make sure the passed halo has the same length as step, by padding with zeros
            halo = [0] * (len(step) - len(halo)) + halo
        assert len(halo) == len(step), f"{len(halo)}, {len(step)}"

        # Create the input axis description.
        input_axes = [
            spec.BatchAxis(),
            spec.ChannelAxis(channel_names=[spec.Identifier(f"channel_{i}") for i in range(inc)]),
        ]
        input_ndim = np.load(input_tensors[0]).ndim
        assert input_ndim in (4, 5)
        axis_names = "zyx" if input_ndim == 5 else "yx"
        assert len(axis_names) == len(min_shape) == len(step)
        input_axes += [
            spec.SpaceInputAxis(id=spec.AxisId(ax_name), size=spec.ParameterizedSize(min=ax_min, step=ax_step))
            for ax_name, ax_min, ax_step in zip(axis_names, min_shape, step)
        ]

        # Create the rest of the input description.
        preprocessing = _get_preprocessing(trainer)
        input_description = [spec.InputTensorDescr(
            id=spec.TensorId("image"),
            axes=input_axes,
            test_tensor=spec.FileDescr(source=Path(input_tensors[0])),
            preprocessing=preprocessing,
        )]

        # Create the output axis description.
        output_axes = [
            spec.BatchAxis(),
            spec.ChannelAxis(channel_names=[spec.Identifier(f"out_channel_{i}") for i in range(outc)]),
        ]
        output_ndim = np.load(output_tensors[0]).ndim
        assert output_ndim in (4, 5)
        axis_names = "zyx" if output_ndim == 5 else "yx"
        assert len(axis_names) == len(halo)
        output_axes += [
            spec.SpaceOutputAxisWithHalo(
                id=spec.AxisId(ax_name),
                size=spec.SizeReference(
                    tensor_id=spec.TensorId("image"), axis_id=spec.AxisId(ax_name)
                ),
                halo=halo_val,
            ) for ax_name, halo_val in zip(axis_names, halo)
        ]

        # Create the rest of the output description.
        output_description = [spec.OutputTensorDescr(
            id=spec.TensorId("prediction"),
            axes=output_axes,
            test_tensor=spec.FileDescr(source=Path(output_tensors[0]))
        )]

    else:
        raise NotImplementedError("Model export currently only works for torch_em.model.unet.")

    return input_description, output_description, notebook_link


def _validate_model(spec_path, output_path):
    if not os.path.exists(spec_path):
        return False

    try:
        model, normalizer, model_spec = import_bioimageio_model(spec_path, return_spec=True, output_path=output_path)
        root = output_path

        input_paths = [os.path.join(root, ipt.test_tensor.source.path) for ipt in model_spec.inputs]
        inputs = [normalize_with_batch(np.load(ipt), normalizer) for ipt in input_paths]

        expected_paths = [os.path.join(root, opt.test_tensor.source.path) for opt in model_spec.outputs]
        expected = [np.load(opt) for opt in expected_paths]

        with torch.no_grad():
            inputs = [torch.from_numpy(input_) for input_ in inputs]
            outputs = model(*inputs)
            if torch.is_tensor(outputs):
                outputs = [outputs]
            outputs = [out.numpy() for out in outputs]

        for out, exp in zip(outputs, expected):
            if not np.allclose(out, exp):
                return False

    except Exception as e:
        print("Model validation failed with the following exception:")
        print(e)
        return False

    return True


#
# Model Export Functionality
#

def _get_input_data(trainer):
    loader = trainer.val_loader
    x = next(iter(loader))[0].numpy()
    return x


def export_bioimageio_model(
    checkpoint: str,
    output_path: str,
    input_data: Optional[np.ndarray] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    authors: Optional[List[Dict[str, str]]] = None,
    tags: Optional[List[str]] = None,
    license: Optional[str] = None,
    documentation: Optional[str] = None,
    covers: Optional[str] = None,
    git_repo: Optional[str] = None,
    cite: Optional[List[Dict[str, str]]] = None,
    input_optional_parameters: bool = True,
    model_postprocessing: Optional[str] = None,
    for_deepimagej: bool = False,
    links: Optional[List[str]] = None,
    maintainers: Optional[List[Dict[str, str]]] = None,
    min_shape: Tuple[int, ...] = None,
    halo: Tuple[int, ...] = None,
    checkpoint_name: str = "best",
    config: Dict = {},
) -> bool:
    """Export model to bioimage.io model format.

    Args:
        checkpoint: The path to the checkpoint with the model to export.
        output_path: The output path for saving the model.
        input_data: The input data for creating model test data.
        name: The export name of the model.
        description: The description of the model.
        authors: The authors that created this model.
        tags: List of tags for this model.
        license: The license under which to publish the model.
        documentation: The documentation of the model.
        covers: The covers to show when displaying the model.
        git_repo: A github repository associated with this model.
        cite: References to cite for this model.
        input_optional_parameters: Whether to input optional parameters via the command line.
        model_postprocessing: Postprocessing to apply to the model outputs.
        for_deepimagej: Whether this model can be run in DeepImageJ.
        links: Linked modelzoo apps or software for this model.
        maintainers: The maintainers of this model.
        min_shape: The minimal valid input shape for the model.
        halo: The halo to cut away from model outputs.
        checkpoint_name: The name of the model checkpoint to load for the export.
        config: Dictionary with additional configuration for this model.

    Returns:
        Whether the exported model was successfully validated.
    """
    # Load the trainer and model.
    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
    model, model_kwargs = _get_model(trainer, model_postprocessing)

    # Get input data from the trainer if it is not given.
    if input_data is None:
        input_data = _get_input_data(trainer)

    with tempfile.TemporaryDirectory() as export_folder:

        # Create the weight description.
        weight_description = _create_weight_description(model, export_folder, model_kwargs)

        # Create the test input/output files.
        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
        # Get the descriptions for inputs, outputs and notebook links.
        input_description, output_description, notebook_link = _get_inout_descriptions(
            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
        )

        # Get the additional kwargs.
        kwargs = _get_kwargs(
            trainer, name, description,
            authors, tags, license, documentation,
            git_repo, cite, maintainers,
            export_folder, input_optional_parameters
        )

        # TODO double check the current link policy
        # The apps to link with this model, by default ilastik.
        if links is None:
            links = []
        links.append("ilastik/ilastik")
        # Add the notebook link, if available.
        if notebook_link is not None:
            links.append(notebook_link)
        kwargs.update({"links": links})

        if covers is not None:
            kwargs["covers"] = covers

        model_description = spec.ModelDescr(
            inputs=input_description,
            outputs=output_description,
            weights=weight_description,
            config=config,
            **kwargs,
        )

        save_bioimageio_package(model_description, output_path=output_path)

    # Validate the model.
    with tempfile.TemporaryDirectory() as tmp_dir:
        val_success = _validate_model(output_path, tmp_dir)
    if val_success:
        print(f"The model was successfully exported to '{output_path}'.")
    else:
        warn(f"Validation of the bioimageio model exported to '{output_path}' has failed. " +
             "You can use this model, but it will probably yield incorrect results.")
    return val_success


# TODO support bounding boxes
def _load_data(path, key):
    if key is None:
        ext = os.path.splitext(path)[-1]
        if ext == ".npy":
            return np.load(path)
        else:
            return imageio.imread(path)
    else:
        return open_file(path, mode="r")[key][:]


def main():
    """@private
    """
    parser = argparse.ArgumentParser(
        "Export model trained with torch_em to the BioImage.IO model format. "
        "The exported model can be run in any tool supporting BioImage.IO. "
        "For details check out https://bioimage.io/#/."
    )
    parser.add_argument("-p", "--path", required=True,
                        help="Path to the model checkpoint to export to the BioImage.IO model format.")
    parser.add_argument("-d", "--data", required=True,
                        help="Path to the test data to use for creating the exported model.")
    parser.add_argument("-f", "--export_folder", required=True,
                        help="Where to save the exported model. The exported model is stored as a zip in the folder.")
    parser.add_argument("-k", "--key",
                        help="The key for the test data. Required for container data formats like hdf5 or zarr.")
    parser.add_argument("-n", "--name", help="The name of the exported model.")

    args = parser.parse_args()
    name = os.path.basename(args.path) if args.name is None else args.name
    export_bioimageio_model(args.path, args.export_folder, _load_data(args.data, args.key), name=name)


#
# Model Import Functionality
#


def _unzip_and_load_model(model_spec, device, spec_path, output_path):
    unpack_archive(str(spec_path), str(output_path))
    weight_spec = model_spec.weights.pytorch_state_dict
    model = load_torch_model(weight_spec, devices=[device])
    model.eval()
    return model


def _load_model(model_spec, device, spec_path, output_path):
    if output_path is None:
        with tempfile.TemporaryDirectory() as tmp_dir:
            return _unzip_and_load_model(model_spec, device, spec_path, tmp_dir)
    else:
        return _unzip_and_load_model(model_spec, device, spec_path, output_path)


def _load_normalizer(model_spec):
    inputs = model_spec.inputs[0]
    preprocessing = inputs.preprocessing

    # Filter out ensure_dtype.
    preprocessing = [preproc for preproc in preprocessing if preproc.id != "ensure_dtype"]
    if len(preprocessing) == 0:
        return None

    ndim = len(inputs.axes) - 2
    shape = inputs.shape
    if hasattr(shape, "min"):
        shape = shape.min

    conf = preprocessing[0]
    name = conf.id
    spec_kwargs = conf.kwargs

    def _get_axis(axes):
        label_to_id = {"channel": 0, "z": 1, "y": 2, "x": 3} if ndim == 3 else \
            {"channel": 0, "y": 1, "x": 2}
        axis = tuple(label_to_id[ax] for ax in axes)

        # Is the axis full? Then we don't need to specify it.
        if len(axis) == ndim + 1:
            return None

        # Drop the channel axis if we have only a single channel,
        # because torch_em squeezes the channel axis in this case.
        if shape[1] == 1:
            axis = tuple(ax - 1 for ax in axis if ax > 0)
        return axis

    axis = _get_axis(spec_kwargs.get("axes", None))
    if name == "zero_mean_unit_variance":
        kwargs = {"axis": axis}
        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)

    elif name == "fixed_zero_mean_unit_variance":
        kwargs = {"axis": axis, "mean": spec_kwargs["mean"], "std": spec_kwargs["std"]}
        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)

    elif name == "scale_linear":
        min_ = -spec_kwargs["offset"]
        max_ = 1. / spec_kwargs["gain"]
        kwargs = {"axis": axis, "minval": min_, "maxval": max_}
        normalizer = functools.partial(torch_em.transform.raw.normalize, **kwargs)

    elif name == "scale_range":
        assert spec_kwargs["mode"] == "per_sample"  # Can't parse the other modes right now.
        lower, upper = spec_kwargs["min_percentile"], spec_kwargs["max_percentile"]
        if np.isclose(lower, 0.0) and np.isclose(upper, 100.0):
            normalizer = functools.partial(torch_em.transform.raw.normalize, axis=axis)
        else:
            kwargs = {"axis": axis, "lower": lower, "upper": upper}
            normalizer = functools.partial(torch_em.transform.raw.normalize_percentile, **kwargs)

    else:
        msg = f"torch_em does not support the use of the bioimageio preprocessing function {name}."
        raise RuntimeError(msg)

    return normalizer


def import_bioimageio_model(
    spec_path: str,
    return_spec: bool = False,
    device: Optional[str] = None,
    output_path: Optional[str] = None,
) -> Tuple[torch.nn.Module, Callable]:
    """Import a pytorch model from a bioimageio model.

    Args:
        spec_path: The path to the bioimageio model. Expects a zipped model file.
        return_spec: Whether to return the deserialized model description in addition to the model.
        device: The device to use for loading the model.
        output_path: The output path for deserializing model files.
            By default the files will be deserialized to a temporary path.

    Returns:
        The model loaded as torch.nn.Module.
        The preprocessing function.
    """
    model_spec = core.load_description(spec_path)

    # Use the given device, falling back to cuda if available.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(model_spec, device=device, spec_path=spec_path, output_path=output_path)
    normalizer = _load_normalizer(model_spec)

    if return_spec:
        return model, normalizer, model_spec
    else:
        return model, normalizer


# TODO: the weight conversion needs to be updated once the
# corresponding functionality in bioimageio.core is updated
#
# Weight Conversion
#


def _convert_impl(spec_path, weight_name, converter, weight_type, **kwargs):
    with tempfile.TemporaryDirectory() as tmp_dir:
        weight_path = os.path.join(tmp_dir, weight_name)
        model_spec = core.load_description(spec_path)
        weight_descr = converter(model_spec, weight_path, **kwargs)
        # TODO double check
        setattr(model_spec.weights, weight_type, weight_descr)
        save_bioimageio_package(model_spec, output_path=spec_path)


def convert_to_onnx(spec_path, opset_version=12):
    """@private
    """
    raise NotImplementedError
    # converter = weight_converter.convert_weights_to_onnx
    # _convert_impl(spec_path, "weights.onnx", converter, "onnx", opset_version=opset_version)
    # return None


def convert_to_torchscript(model_path):
    """@private
    """
    raise NotImplementedError
    # from bioimageio.core.weight_converter.torch._torchscript import convert_weights_to_torchscript

    # weight_name = "weights-torchscript.pt"
    # _convert_impl(model_path, weight_name, convert_weights_to_torchscript, "torchscript")

    # # Check that we can load the converted weights.
    # model_spec = core.load_description(model_path)
    # weight_path = model_spec.weights.torchscript.weights
    # try:
    #     torch.jit.load(weight_path)
    #     return None
    # except Exception as e:
    #     return e


def add_weight_formats(model_path, additional_formats):
    """@private
    """
    for add_format in additional_formats:

        if add_format == "onnx":
            ret = convert_to_onnx(model_path)
        elif add_format == "torchscript":
            ret = convert_to_torchscript(model_path)

        if ret is None:
            print("Successfully added", add_format, "weights")
        else:
            warn(f"Added {add_format} weights, but got exception {ret} when loading the weights again.")


def convert_main():
    """@private
    """
    parser = argparse.ArgumentParser("Convert weights from native pytorch format to onnx or torchscript")
    parser.add_argument("-f", "--model_folder", required=True, help="")
    parser.add_argument("-w", "--weight_format", required=True, help="")
    args = parser.parse_args()
    weight_format = args.weight_format
    assert weight_format in ("onnx", "torchscript")
    if weight_format == "onnx":
        convert_to_onnx(args.model_folder)
    else:
        convert_to_torchscript(args.model_folder)


#
# Misc Functionality
#

def export_parser_helper():
    """@private
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--checkpoint", required=True)
    parser.add_argument("-i", "--input", required=True)
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-a", "--affs_to_bd", default=0, type=int)
    parser.add_argument("-f", "--additional_formats", type=str, nargs="+")
    return parser


def get_mws_config(offsets, config=None):
    """@private
    """
    mws_config = {"offsets": offsets}
    if config is None:
        config = {"mws": mws_config}
    else:
        assert isinstance(config, dict)
        config["mws"] = mws_config
    return config


def get_shallow2deep_config(rf_path, config=None):
    """@private
    """
    if os.path.isdir(rf_path):
        rf_path = glob(os.path.join(rf_path, "*.pkl"))[0]
    assert os.path.exists(rf_path), rf_path
    with open(rf_path, "rb") as f:
        rf = pickle.load(f)
    shallow2deep_config = {"ndim": rf.feature_ndim, "features": rf.feature_config}
    if config is None:
        config = {"shallow2deep": shallow2deep_config}
    else:
        assert isinstance(config, dict)
        config["shallow2deep"] = shallow2deep_config
    return config
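
The config helpers at the end of the module, get_mws_config and get_shallow2deep_config, assemble the extra configuration dictionary that export_bioimageio_model stores in the model description. A minimal sketch of combining them with an export (the checkpoint path, output path, and offset values are hypothetical):

    from torch_em.util.modelzoo import export_bioimageio_model, get_mws_config

    # Affinity offsets for mutex watershed postprocessing; stored under config["mws"].
    offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]
    config = get_mws_config(offsets)

    export_bioimageio_model(
        "./checkpoints/my-affinity-model",  # hypothetical checkpoint folder
        "./my-affinity-model.zip",
        input_optional_parameters=False,  # use defaults instead of interactive prompts
        config=config,
    )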
def export_bioimageio_model(
    checkpoint: str,
    output_path: str,
    input_data: Optional[numpy.ndarray] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    authors: Optional[List[Dict[str, str]]] = None,
    tags: Optional[List[str]] = None,
    license: Optional[str] = None,
    documentation: Optional[str] = None,
    covers: Optional[str] = None,
    git_repo: Optional[str] = None,
    cite: Optional[List[Dict[str, str]]] = None,
    input_optional_parameters: bool = True,
    model_postprocessing: Optional[str] = None,
    for_deepimagej: bool = False,
    links: Optional[List[str]] = None,
    maintainers: Optional[List[Dict[str, str]]] = None,
    min_shape: Tuple[int, ...] = None,
    halo: Tuple[int, ...] = None,
    checkpoint_name: str = 'best',
    config: Dict = {}
) -> bool:
Export model to bioimage.io model format.
Arguments:
- checkpoint: The path to the checkpoint with the model to export.
- output_path: The output path for saving the model.
- input_data: The input data for creating model test data.
- name: The export name of the model.
- description: The description of the model.
- authors: The authors that created this model.
- tags: List of tags for this model.
- license: The license under which to publish the model.
- documentation: The documentation of the model.
- covers: The covers to show when displaying the model.
- git_repo: A github repository associated with this model.
- cite: References to cite for this model.
- input_optional_parameters: Whether to input optional parameters via the command line.
- model_postprocessing: Postprocessing to apply to the model outputs.
- for_deepimagej: Whether this model can be run in DeepImageJ.
- links: Linked modelzoo apps or software for this model.
- maintainers: The maintainers of this model.
- min_shape: The minimal valid input shape for the model.
- halo: The halo to cut away from model outputs.
- checkpoint_name: The name of the model checkpoint to load for the export.
- config: Dictionary with additional configuration for this model.
Returns:
Whether the exported model was successfully validated.
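
A minimal usage sketch: exporting a trained checkpoint programmatically, with the interactive prompts disabled. The checkpoint and output paths and the author name are hypothetical; since input_data is not passed, the test data is taken from the trainer's validation loader.

    from torch_em.util.modelzoo import export_bioimageio_model

    success = export_bioimageio_model(
        checkpoint="./checkpoints/my-model",  # hypothetical checkpoint folder
        output_path="./my-model.zip",
        name="my-model",
        authors=[{"name": "Jane Doe"}],  # placeholder author
        license="MIT",
        input_optional_parameters=False,  # use defaults instead of interactive prompts
    )
    print("Validated:", success)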
def import_bioimageio_model(
    spec_path: str,
    return_spec: bool = False,
    device: Optional[str] = None,
    output_path: Optional[str] = None
) -> Tuple[torch.nn.modules.module.Module, Callable]:
Import a pytorch model from a bioimageio model.
Arguments:
- spec_path: The path to the bioimageio model. Expects a zipped model file.
- return_spec: Whether to return the deserialized model description in addition to the model.
- device: The device to use for loading the model.
- output_path: The output path for deserializing model files. By default the files will be deserialized to a temporary path.
Returns:
The model loaded as torch.nn.Module.
The preprocessing function.
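
A minimal usage sketch: loading an exported model and running inference on a batched input. The model zip path and the input shape are hypothetical; as in the module's own validation logic, the normalizer is applied per sample before prediction.

    import numpy as np
    import torch
    from torch_em.util.modelzoo import import_bioimageio_model

    model, normalizer = import_bioimageio_model("./my-model.zip", device="cpu")

    # Input with explicit batch and channel axes: (batch, channel, y, x).
    data = np.random.rand(1, 1, 256, 256).astype("float32")
    if normalizer is not None:
        data = np.concatenate([normalizer(sample)[None] for sample in data], axis=0)

    with torch.no_grad():
        prediction = model(torch.from_numpy(data)).cpu().numpy()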