torch_em.util.modelzoo
import argparse
import functools
import getpass
import json
import os
import pickle
import platform
import subprocess
import tempfile

from glob import glob
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from warnings import warn

import imageio
import numpy as np
import torch
import torch_em

import bioimageio.core as core
import bioimageio.spec.model.v0_5 as spec
from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
from bioimageio.spec import save_bioimageio_package

from elf.io import open_file
from .util import get_trainer, get_normalizer


#
# General Purpose Functionality
#


def normalize_with_batch(data, normalizer):
    """@private

    Apply `normalizer` to every element along the first axis of `data` and
    stack the results again. Returns `data` unchanged if `normalizer` is None.
    """
    if normalizer is None:
        return data
    return np.concatenate([normalizer(da)[None] for da in data], axis=0)


#
# Utility Functions for Model Export.
#


def get_default_citations(model=None, model_output=None):
    """@private

    Build the default citation list for an exported model.

    Args:
        model: The model (or its class name as string) used to pick the architecture citation.
        model_output: The output type, either 'affinities' or 'boundaries'.

    Returns:
        A list of citation dictionaries with the keys 'text' and 'doi'.
    """
    citations = [{"text": "training library", "doi": "10.5281/zenodo.5108853"}]

    # try to derive the correct network citation from the model class
    if model is not None:
        model_name = model if isinstance(model, str) else str(model.__class__.__name__)

        if model_name.lower() in ("unet2d", "unet_2d", "unet"):
            citations.append(
                {"text": "architecture", "doi": "10.1007/978-3-319-24574-4_28"}
            )
        elif model_name.lower() in ("unet3d", "unet_3d", "anisotropicunet"):
            citations.append(
                {"text": "architecture", "doi": "10.1007/978-3-319-46723-8_49"}
            )
        else:
            # This message was missing the f-prefix and printed the placeholder verbatim.
            warn(f"No citation for architecture {model_name} found.")

    # try to derive the correct segmentation algo citation from the model output type
    if model_output is not None:
        msg = (
            f"No segmentation algorithm for output {model_output} known. "
            "'affinities' and 'boundaries' are supported."
        )
        if model_output == "affinities":
            citations.append(
                {"text": "segmentation algorithm", "doi": "10.1109/TPAMI.2020.2980827"}
            )
        elif model_output == "boundaries":
            citations.append(
                {"text": "segmentation algorithm", "doi": "10.1038/nmeth.4151"}
            )
        else:
            warn(msg)

    return citations


def _get_model(trainer, postprocessing):
    """Return the trainer's model in eval mode and its (cleaned) init kwargs.

    If `postprocessing` is given the model is re-created with this in-model
    postprocessing and the trained state is loaded into the new instance.
    """
    model = trainer.model
    model.eval()
    model_kwargs = model.init_kwargs
    # clear the kwargs of non builtins
    # TODO warn if we strip any non-standard arguments
    model_kwargs = {k: v for k, v in model_kwargs.items() if not isinstance(v, type)}

    # set the in-model postprocessing if given
    if postprocessing is not None:
        assert "postprocessing" in model_kwargs
        model_kwargs["postprocessing"] = postprocessing
        state = model.state_dict()
        model = model.__class__(**model_kwargs)
        model.load_state_dict(state)
        model.eval()

    return model, model_kwargs


def _get_dataset_ndim(trainer):
    """Read the `ndim` attribute of the dataset behind the trainer's training loader.

    Handles plain datasets, `torch.utils.data.Subset` wrappers and datasets
    that expose their members via a `datasets` attribute.
    """
    dataset = trainer.train_loader.dataset
    try:
        if isinstance(dataset, torch.utils.data.dataset.Subset):
            return dataset.dataset.ndim
        return dataset.ndim
    except AttributeError:
        return dataset.datasets[0].ndim


def _pad(input_data, trainer):
    """Prepend singleton axes until `input_data` has `ndim + 2` dimensions (batch, channel, spatial)."""
    target_dims = _get_dataset_ndim(trainer) + 2
    for _ in range(target_dims - input_data.ndim):
        input_data = np.expand_dims(input_data, axis=0)
    return input_data


def _write_data(input_data, model, trainer, export_folder):
    """Run the model on the test inputs and save inputs and outputs as .npy files.

    Returns:
        The list of saved input paths and the list of saved output paths.
    """
    if isinstance(input_data, np.ndarray):
        input_data = [input_data]

    # normalize the input data if we have a normalization function
    normalizer = get_normalizer(trainer)

    # pad to 4d/5d and normalize the input data
    # NOTE we have to save the padded data, but without normalization
    test_inputs = [_pad(input_, trainer) for input_ in input_data]
    normalized = [normalize_with_batch(input_, normalizer) for input_ in test_inputs]

    # run prediction
    with torch.no_grad():
        test_tensors = [torch.from_numpy(norm).to(trainer.device) for norm in normalized]
        test_outputs = model(*test_tensors)
        if torch.is_tensor(test_outputs):
            test_outputs = [test_outputs]
        test_outputs = [out.cpu().numpy() for out in test_outputs]

    # save the input / output
    test_in_paths, test_out_paths = [], []
    for i, input_ in enumerate(test_inputs):
        test_in_path = os.path.join(export_folder, f"test_input_{i}.npy")
        np.save(test_in_path, input_)
        test_in_paths.append(test_in_path)
    for i, out in enumerate(test_outputs):
        test_out_path = os.path.join(export_folder, f"test_output_{i}.npy")
        np.save(test_out_path, out)
        test_out_paths.append(test_out_path)
    return test_in_paths, test_out_paths


def _create_weight_description(model, export_folder, model_kwargs):
    """Save the model state dict to `export_folder` and return the matching weight description."""
    module = str(model.__class__.__module__)
    cls_name = str(model.__class__.__name__)

    if module == "torch_em.model.unet":
        # The unet architecture is shipped as source file with the exported model.
        source_path = os.path.join(os.path.split(__file__)[0], "../model/unet.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(source_path),
            callable=cls_name,
            kwargs=model_kwargs,
        )
    else:
        architecture = spec.ArchitectureFromLibraryDescr(
            import_from=module,
            callable=cls_name,
            kwargs=model_kwargs,
        )

    checkpoint_path = os.path.join(export_folder, "state_dict.pt")
    torch.save(model.state_dict(), checkpoint_path)

    weight_description = spec.WeightsDescr(
        pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
            source=Path(checkpoint_path),
            architecture=architecture,
            pytorch_version=spec.Version(torch.__version__),
        )
    )
    return weight_description


def _default_person_name():
    """Best-effort default name: git `user.name`, then the login name, then the host name."""
    # first try to derive the name from git
    try:
        call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
        name = call_res.stdout.decode("utf8").rstrip("\n")
    except Exception:  # deliberate best-effort: git may be missing or misconfigured
        name = ""
    if name:
        return name

    # otherwise use the username; `os.uname` is not available on all platforms
    # and its second entry is the host name rather than the user name.
    try:
        return getpass.getuser()
    except Exception:  # getuser raises different exception types across python versions
        return platform.node()


def _get_kwargs(
    trainer, name, description, authors, tags, license, documentation,
    git_repo, cite, maintainers, export_folder, input_optional_parameters
):
    """Collect the general metadata for the model description.

    Values that are None are replaced by a default. If `input_optional_parameters`
    is True the user is asked for these values on the command line, with the default on offer.
    The documentation is written to 'documentation.md' in `export_folder`
    and the path to this file is stored instead of the text.
    """
    if input_optional_parameters:
        print("Enter values for the optional parameters.")
        print("If the default value in [] is satisfactory, press enter without additional input.")
        print("Please enter lists using json syntax.")

    def _get_kwarg(kwarg_name, val, default, is_list=False, fname=None):
        # if we don't have a value, we either ask user for input (offering the default)
        # or just use the default if input_optional_parameters is False
        if val is None and input_optional_parameters:
            default_val = default()
            choice = input(f"{kwarg_name} [{default_val}]: ")
            val = choice if choice else default_val
        elif val is None:
            val = default()

        if fname is not None:
            save_path = os.path.join(export_folder, fname)
            with open(save_path, "w") as f:
                f.write(val)
            return save_path

        if is_list and isinstance(val, str):
            val = val.replace("'", '"')  # enable single quotes
            val = json.loads(val)
        if is_list:
            assert isinstance(val, (list, tuple)), type(val)
        return val

    def _default_authors():
        return [{"name": _default_person_name()}]

    def _default_repo():
        # There is no default repository. The original code returned None before
        # its (unreachable) attempt to parse the output of `git remote -v`.
        return None

    def _default_maintainers():
        return [{"github_user": _default_person_name()}]

    # TODO derive better default values:
    # - description: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    # - tags: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    # - documentation: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    kwargs = {
        "name": _get_kwarg("name", name, lambda: trainer.name),
        "description": _get_kwarg("description", description, lambda: trainer.name),
        "authors": _get_kwarg("authors", authors, _default_authors, is_list=True),
        "tags": _get_kwarg("tags", tags, lambda: [trainer.name], is_list=True),
        "license": _get_kwarg("license", license, lambda: "MIT"),
        "documentation": _get_kwarg(
            "documentation", documentation, lambda: trainer.name, fname="documentation.md"
        ),
        "git_repo": _get_kwarg("git_repo", git_repo, _default_repo),
        "cite": _get_kwarg("cite", cite, get_default_citations),
        "maintainers": _get_kwarg("maintainers", maintainers, _default_maintainers, is_list=True),
    }

    return kwargs


def _get_preprocessing(trainer):
    """Translate the trainer's normalization function into a bioimageio preprocessing description.

    Returns:
        A list with a single preprocessing description, or None if the trainer has
        no normalizer or the normalizer is not one of the known torch_em functions.
    """
    ndim = _get_dataset_ndim(trainer)
    normalizer = get_normalizer(trainer)
    if normalizer is None:
        # Nothing to describe; previously this failed with an AttributeError below.
        return None

    if isinstance(normalizer, functools.partial):
        kwargs = normalizer.keywords
        normalizer = normalizer.func
    else:
        kwargs = {}

    def _get_axes(axis):
        all_axes = ["channel", "y", "x"] if ndim == 2 else ["channel", "z", "y", "x"]
        if axis is None:
            return all_axes
        if isinstance(axis, int):
            axis = (axis,)
        # This used to iterate over the not yet defined name 'axes'.
        return [all_axes[i] for i in axis]

    name = f"{normalizer.__module__}.{normalizer.__name__}"
    axes = _get_axes(kwargs.get("axis", None))

    if name == "torch_em.transform.raw.normalize":

        min_, max_ = kwargs.get("minval", None), kwargs.get("maxval", None)
        assert (min_ is None) == (max_ is None)

        if min_ is None:
            # A trailing comma turned this name into a tuple, which broke the lookup below.
            spec_name = "scale_range"
            spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": 0.0, "max_percentile": 100.0}
        else:
            # NOTE(review): gain / offset are the inverse of the mapping in `_load_normalizer`;
            # confirm that they also match the arithmetic of `torch_em.transform.raw.normalize`.
            spec_name = "scale_linear"
            spec_kwargs = {"gain": 1.0 / max_, "offset": -min_, "axes": axes}

    elif name == "torch_em.transform.raw.standardize":
        spec_kwargs = {"axes": axes}
        mean, std = kwargs.get("mean", None), kwargs.get("std", None)
        if (mean is None) and (std is None):
            spec_name = "zero_mean_unit_variance"
        else:
            spec_name = "fixed_zero_mean_unit_variance"
            spec_kwargs.update({"mean": mean, "std": std})

    elif name == "torch_em.transform.raw.normalize_percentile":
        lower, upper = kwargs.get("lower", 1.0), kwargs.get("upper", 99.0)
        spec_name = "scale_range"
        spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": lower, "max_percentile": upper}

    else:
        warn(f"Could not parse the normalization function {name}, 'preprocessing' field will be empty.")
        return None

    name_to_cls = {
        "scale_linear": spec.ScaleLinearDescr,
        "scale_range": spec.ScaleRangeDescr,
        "zero_mean_unit_variance": spec.ZeroMeanUnitVarianceDescr,
        "fixed_zero_mean_unit_variance": spec.FixedZeroMeanUnitVarianceDescr,
    }
    preprocessing = name_to_cls[spec_name](kwargs=spec_kwargs)

    return [preprocessing]


def _get_inout_descriptions(trainer, model, model_kwargs, input_tensors, output_tensors, min_shape, halo):
    """Create the input and output tensor descriptions for the exported model.

    Args:
        trainer: The trainer, used to derive the preprocessing.
        model: The model to export. Only models from 'torch_em.model.unet' are supported.
        model_kwargs: The keyword arguments used to construct the model.
        input_tensors: Paths to the saved test inputs (exactly one is expected).
        output_tensors: Paths to the saved test outputs (exactly one is expected).
        min_shape: The minimal spatial input shape, derived from the model if None.
        halo: The spatial halo of the output, derived from the model if None.

    Returns:
        The input descriptions, the output descriptions and the notebook link.

    Raises:
        RuntimeError: For an unknown model class in 'torch_em.model.unet'.
        NotImplementedError: For models from any other module.
    """
    notebook_link = None
    module = str(model.__class__.__module__)
    name = str(model.__class__.__name__)

    # can derive tensor kwargs only for known torch_em models (only unet for now)
    if module == "torch_em.model.unet":
        assert len(input_tensors) == len(output_tensors) == 1
        inc, outc = model_kwargs["in_channels"], model_kwargs["out_channels"]

        postprocessing = model_kwargs.get("postprocessing", None)
        if isinstance(postprocessing, str) and postprocessing.startswith("affinities_to_boundaries"):
            outc = 1
        elif isinstance(postprocessing, str) and postprocessing.startswith("affinities_with_foreground_to_boundaries"):
            outc = 2
        elif postprocessing is not None:
            warn(f"The model has the post-processing {postprocessing} which cannot be interpreted")

        if name == "UNet2d":
            depth = model_kwargs["depth"]
            step = [2 ** depth] * 2
            if min_shape is None:
                min_shape = [2 ** (depth + 1)] * 2
            else:
                assert len(min_shape) == 2
                min_shape = list(min_shape)
            notebook_link = "ilastik/torch-em-2d-unet-notebook"

        elif name == "UNet3d":
            depth = model_kwargs["depth"]
            step = [2 ** depth] * 3
            if min_shape is None:
                min_shape = [2 ** (depth + 1)] * 3
            else:
                assert len(min_shape) == 3
                min_shape = list(min_shape)
            notebook_link = "ilastik/torch-em-3d-unet-notebook"

        elif name == "AnisotropicUNet":
            scale_factors = model_kwargs["scale_factors"]
            # The step per axis is the product of the scale factors over all levels.
            scale_prod = [
                int(np.prod([scale_factors[i][d] for i in range(len(scale_factors))]))
                for d in range(3)
            ]
            assert len(scale_prod) == 3
            step = scale_prod
            if min_shape is None:
                min_shape = [2 * sp for sp in scale_prod]
            else:
                min_shape = list(min_shape)
            notebook_link = "ilastik/torch-em-3d-unet-notebook"

        else:
            raise RuntimeError(f"Cannot derive tensor parameters for {module}.{name}")

        if halo is None:   # default halo = step // 2
            halo = [st // 2 for st in step]
        else:  # make sure the passed halo has the same length as step, by padding with zeros
            # `list(halo)` so that a tuple halo (as in the type hint) can be concatenated.
            halo = [0] * (len(step) - len(halo)) + list(halo)
        assert len(halo) == len(step), f"{len(halo)}, {len(step)}"

        # Create the input axis description.
        input_axes = [
            spec.BatchAxis(),
            spec.ChannelAxis(channel_names=[spec.Identifier(f"channel_{i}") for i in range(inc)]),
        ]
        input_ndim = np.load(input_tensors[0]).ndim
        assert input_ndim in (4, 5)
        axis_names = "zyx" if input_ndim == 5 else "yx"
        assert len(axis_names) == len(min_shape) == len(step)
        input_axes += [
            spec.SpaceInputAxis(id=spec.AxisId(ax_name), size=spec.ParameterizedSize(min=ax_min, step=ax_step))
            for ax_name, ax_min, ax_step in zip(axis_names, min_shape, step)
        ]

        # Create the rest of the input description.
        preprocessing = _get_preprocessing(trainer)
        input_description = [spec.InputTensorDescr(
            id=spec.TensorId("image"),
            axes=input_axes,
            test_tensor=spec.FileDescr(source=Path(input_tensors[0])),
            preprocessing=preprocessing,
        )]

        # Create the output axis description.
        output_axes = [
            spec.BatchAxis(),
            spec.ChannelAxis(channel_names=[spec.Identifier(f"out_channel_{i}") for i in range(outc)]),
        ]
        output_ndim = np.load(output_tensors[0]).ndim
        assert output_ndim in (4, 5)
        axis_names = "zyx" if output_ndim == 5 else "yx"
        assert len(axis_names) == len(halo)
        output_axes += [
            spec.SpaceOutputAxisWithHalo(
                id=spec.AxisId(ax_name),
                size=spec.SizeReference(
                    tensor_id=spec.TensorId("image"), axis_id=spec.AxisId(ax_name)
                ),
                halo=halo_val,
            ) for ax_name, halo_val in zip(axis_names, halo)
        ]

        # Create the rest of the output description.
        output_description = [spec.OutputTensorDescr(
            id=spec.TensorId("prediction"),
            axes=output_axes,
            test_tensor=spec.FileDescr(source=Path(output_tensors[0]))
        )]

    else:
        raise NotImplementedError("Model export currently only works for torch_em.model.unet.")

    return input_description, output_description, notebook_link


def _validate_model(spec_path):
    """Re-import the exported model and check that it reproduces the stored test outputs.

    Returns:
        False if `spec_path` does not exist or any output differs from the expected output, else True.
    """
    if not os.path.exists(spec_path):
        return False

    model, normalizer, model_spec = import_bioimageio_model(spec_path, return_spec=True)
    root = model_spec.root

    input_paths = [os.path.join(root, ipt.test_tensor.source.path) for ipt in model_spec.inputs]
    inputs = [normalize_with_batch(np.load(ipt), normalizer) for ipt in input_paths]

    expected_paths = [os.path.join(root, opt.test_tensor.source.path) for opt in model_spec.outputs]
    expected = [np.load(opt) for opt in expected_paths]

    with torch.no_grad():
        inputs = [torch.from_numpy(input_) for input_ in inputs]
        outputs = model(*inputs)
        if torch.is_tensor(outputs):
            outputs = [outputs]
        outputs = [out.numpy() for out in outputs]

    return all(np.allclose(out, exp) for out, exp in zip(outputs, expected))


#
# Model Export Functionality
#

def _get_input_data(trainer):
    """Return the first element of the first batch of the validation loader as numpy array."""
    loader = trainer.val_loader
    x = next(iter(loader))[0].numpy()
    return x


def export_bioimageio_model(
    checkpoint: str,
    output_path: str,
    input_data: Optional[np.ndarray] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    authors: Optional[List[Dict[str, str]]] = None,
    tags: Optional[List[str]] = None,
    license: Optional[str] = None,
    documentation: Optional[str] = None,
    covers: Optional[str] = None,
    git_repo: Optional[str] = None,
    cite: Optional[List[Dict[str, str]]] = None,
    input_optional_parameters: bool = True,
    model_postprocessing: Optional[str] = None,
    for_deepimagej: bool = False,
    links: Optional[List[str]] = None,
    maintainers: Optional[List[Dict[str, str]]] = None,
    min_shape: Optional[Tuple[int, ...]] = None,
    halo: Optional[Tuple[int, ...]] = None,
    checkpoint_name: str = "best",
    config: Dict = {},
) -> bool:
    """Export model to bioimage.io model format.

    Args:
        checkpoint: The path to the checkpoint with the model to export.
        output_path: The output path for saving the model.
        input_data: The input data for creating model test data.
        name: The export name of the model.
        description: The description of the model.
        authors: The authors that created this model.
        tags: List of tags for this model.
        license: The license under which to publish the model.
        documentation: The documentation of the model.
        covers: The covers to show when displaying the model.
        git_repo: A github repository associated with this model.
        cite: References to cite for this model.
        input_optional_parameters: Whether to input optional parameters via the command line.
        model_postprocessing: Postprocessing to apply to the model outputs.
        for_deepimagej: Whether this model can be run in DeepImageJ.
            This argument is currently not used by the export.
        links: Linked modelzoo apps or software for this model.
        maintainers: The maintainers of this model.
        min_shape: The minimal valid input shape for the model.
        halo: The halo to cut away from model outputs.
        checkpoint_name: The name of the model checkpoint to load for the export.
        config: Dictionary with additional configuration for this model.

    Returns:
        Whether the exported model was successfully validated.
    """
    # Load the trainer and model.
    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
    model, model_kwargs = _get_model(trainer, model_postprocessing)

    # Get input data from the trainer if it is not given.
    if input_data is None:
        input_data = _get_input_data(trainer)

    with tempfile.TemporaryDirectory() as export_folder:

        # Create the weight description.
        weight_description = _create_weight_description(model, export_folder, model_kwargs)

        # Create the test input/output files.
        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
        # Get the descriptions for inputs, outputs and notebook links.
        input_description, output_description, notebook_link = _get_inout_descriptions(
            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
        )

        # Get the additional kwargs.
        kwargs = _get_kwargs(
            trainer, name, description,
            authors, tags, license, documentation,
            git_repo, cite, maintainers,
            export_folder, input_optional_parameters
        )

        # TODO double check the current link policy
        # The apps to link with this model, by default ilastik.
        # Work on a copy so that the list passed by the caller is not modified.
        links = [] if links is None else list(links)
        links.append("ilastik/ilastik")
        # add the notebook link, if available
        if notebook_link is not None:
            links.append(notebook_link)
        kwargs.update({"links": links})

        if covers is not None:
            kwargs["covers"] = covers

        model_description = spec.ModelDescr(
            inputs=input_description,
            outputs=output_description,
            weights=weight_description,
            # Pass a copy so that the shared default dictionary can never be modified.
            config=dict(config),
            **kwargs,
        )

        save_bioimageio_package(model_description, output_path=output_path)

    # Validate the model.
    val_success = _validate_model(output_path)
    if val_success:
        print(f"The model was successfully exported to '{output_path}'.")
    else:
        warn(f"Validation of the bioimageio model exported to '{output_path}' has failed. " +
             "You can use this model, but it will probably yield incorrect results.")
    return val_success


# TODO support bounding boxes
def _load_data(path, key):
    """Load test data from a .npy file, an image file (key is None) or a container format (key given)."""
    if key is None:
        ext = os.path.splitext(path)[-1]
        if ext == ".npy":
            return np.load(path)
        else:
            return imageio.imread(path)
    else:
        return open_file(path, mode="r")[key][:]


def main():
    """@private

    Command line entry point for exporting a model to the BioImage.IO format.
    """
    # The text is the description of the program; it used to be passed as the
    # program name and its sentences were joined without separating spaces.
    parser = argparse.ArgumentParser(
        description="Export model trained with torch_em to the BioImage.IO model format. "
        "The exported model can be run in any tool supporting BioImage.IO. "
        "For details check out https://bioimage.io/#/."
    )
    parser.add_argument("-p", "--path", required=True,
                        help="Path to the model checkpoint to export to the BioImage.IO model format.")
    parser.add_argument("-d", "--data", required=True,
                        help="Path to the test data to use for creating the exported model.")
    parser.add_argument("-f", "--export_folder", required=True,
                        help="Where to save the exported model. The exported model is stored as a zip in the folder.")
    parser.add_argument("-k", "--key",
                        help="The key for the test data. Required for container data formats like hdf5 or zarr.")
    parser.add_argument("-n", "--name", help="The name of the exported model.")

    args = parser.parse_args()
    name = os.path.basename(args.path) if args.name is None else args.name
    export_bioimageio_model(args.path, args.export_folder, _load_data(args.data, args.key), name=name)


#
# model import functionality
#

def _load_model(model_spec, device):
    """Instantiate the network from the pytorch state dict weights of `model_spec` and load its state."""
    weight_spec = model_spec.weights.pytorch_state_dict
    model = PytorchModelAdapter.get_network(weight_spec)
    weight_file = weight_spec.source.path
    if not os.path.exists(weight_file):
        # Fall back to the folder with the unzipped model package.
        root_folder = f"{model_spec.root.filename}.unzip"
        assert os.path.exists(root_folder), root_folder
        weight_file = os.path.join(root_folder, weight_file)
    assert os.path.exists(weight_file), weight_file
    state = torch.load(weight_file, map_location=device, weights_only=False)
    model.load_state_dict(state)
    model.eval()
    return model


def _load_normalizer(model_spec):
    """Translate the preprocessing of the first model input into a torch_em normalization function.

    Returns:
        The normalization function or None if the model has no preprocessing besides 'ensure_dtype'.

    Raises:
        RuntimeError: If the preprocessing is not supported by torch_em.
    """
    inputs = model_spec.inputs[0]
    preprocessing = inputs.preprocessing

    # Filter out ensure dtype.
    preprocessing = [preproc for preproc in preprocessing if preproc.id != "ensure_dtype"]
    if len(preprocessing) == 0:
        return None

    ndim = len(inputs.axes) - 2
    shape = inputs.shape
    if hasattr(shape, "min"):
        shape = shape.min

    # Only the first preprocessing step is translated.
    conf = preprocessing[0]
    name = conf.id
    spec_kwargs = conf.kwargs

    def _get_axis(axes):
        # No axes given: normalize over everything, like for the full set of axes.
        if axes is None:
            return None

        label_to_id = {"channel": 0, "z": 1, "y": 2, "x": 3} if ndim == 3 else\
            {"channel": 0, "y": 1, "x": 2}
        axis = tuple(label_to_id[ax] for ax in axes)

        # Is the axis full? Then we don't need to specify it.
        if len(axis) == ndim + 1:
            return None

        # Drop the channel axis if we have only a single channel.
        # Because torch_em squeezes the channel axis in this case.
        if shape[1] == 1:
            axis = tuple(ax - 1 for ax in axis if ax > 0)
        return axis

    axis = _get_axis(spec_kwargs.get("axes", None))
    if name == "zero_mean_unit_variance":
        kwargs = {"axis": axis}
        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)

    elif name == "fixed_zero_mean_unit_variance":
        kwargs = {"axis": axis, "mean": spec_kwargs["mean"], "std": spec_kwargs["std"]}
        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)

    elif name == "scale_linear":
        min_ = -spec_kwargs["offset"]
        max_ = 1. / spec_kwargs["gain"]
        kwargs = {"axis": axis, "minval": min_, "maxval": max_}
        normalizer = functools.partial(torch_em.transform.raw.normalize, **kwargs)

    elif name == "scale_range":
        assert spec_kwargs["mode"] == "per_sample"  # Can't parse the other modes right now.
        lower, upper = spec_kwargs["min_percentile"], spec_kwargs["max_percentile"]
        if np.isclose(lower, 0.0) and np.isclose(upper, 100.0):
            normalizer = functools.partial(torch_em.transform.raw.normalize, axis=axis)
        else:
            kwargs = {"axis": axis, "lower": lower, "upper": upper}
            normalizer = functools.partial(torch_em.transform.raw.normalize_percentile, **kwargs)

    else:
        msg = f"torch_em does not support the use of the biomageio preprocessing function {name}."
        raise RuntimeError(msg)

    return normalizer


def import_bioimageio_model(spec_path, return_spec=False, device="cpu"):
    """@private

    Load the model and the normalization function from a bioimageio model description.

    Returns:
        The model and the normalizer, and in addition the model description if `return_spec` is True.
    """
    model_spec = core.load_description(spec_path)

    model = _load_model(model_spec, device=device)
    normalizer = _load_normalizer(model_spec)

    if return_spec:
        return model, normalizer, model_spec
    else:
        return model, normalizer


# TODO: the weight conversion needs to be updated once the
# corresponding functionality in bioimageio.core is updated
#
# Weight Conversion
#


def _convert_impl(spec_path, weight_name, converter, weight_type, **kwargs):
    """Run `converter` on the model at `spec_path` and save the model with the additional weights."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        weight_path = os.path.join(tmp_dir, weight_name)
        model_spec = core.load_description(spec_path)
        weight_descr = converter(model_spec, weight_path, **kwargs)
        # TODO double check
        setattr(model_spec.weights, weight_type, weight_descr)
        save_bioimageio_package(model_spec, output_path=spec_path)


def convert_to_onnx(spec_path, opset_version=12):
    """@private

    Not implemented yet, always raises a NotImplementedError.
    """
    raise NotImplementedError
    # converter = weight_converter.convert_weights_to_onnx
    # _convert_impl(spec_path, "weights.onnx", converter, "onnx", opset_version=opset_version)
    # return None


def convert_to_torchscript(model_path):
    """@private

    Not implemented yet, always raises a NotImplementedError.
    """
    raise NotImplementedError
    # from bioimageio.core.weight_converter.torch._torchscript import convert_weights_to_torchscript

    # weight_name = "weights-torchscript.pt"
    # _convert_impl(model_path, weight_name, convert_weights_to_torchscript, "torchscript")

    # # Check that we can load the converted weights.
    # model_spec = core.load_description(model_path)
    # weight_path = model_spec.weights.torchscript.weights
    # try:
    #     torch.jit.load(weight_path)
    #     return None
    # except Exception as e:
    #     return e


def add_weight_formats(model_path, additional_formats):
    """@private

    Add the weight formats in `additional_formats` ('onnx' or 'torchscript') to the model.

    Raises:
        ValueError: For an unknown weight format.
    """
    for add_format in additional_formats:

        if add_format == "onnx":
            ret = convert_to_onnx(model_path)
        elif add_format == "torchscript":
            ret = convert_to_torchscript(model_path)
        else:
            # Without this branch 'ret' was undefined or left over from the previous format.
            raise ValueError(f"Invalid weight format {add_format}, expected 'onnx' or 'torchscript'.")

        if ret is None:
            print("Successfully added", add_format, "weights")
        else:
            warn(f"Added {add_format} weights, but got exception {ret} when loading the weights again.")


def convert_main():
    """@private

    Command line entry point for the weight conversion.
    """
    parser = argparse.ArgumentParser("Convert weights from native pytorch format to onnx or torchscript")
    parser.add_argument("-f", "--model_folder", required=True, help="")
    # The choices are validated by argparse, so the check also holds when asserts are disabled.
    parser.add_argument("-w", "--weight_format", required=True, choices=("onnx", "torchscript"), help="")
    args = parser.parse_args()
    if args.weight_format == "onnx":
        convert_to_onnx(args.model_folder)
    else:
        convert_to_torchscript(args.model_folder)


#
# Misc Functionality
#

def export_parser_helper():
    """@private

    Create an argument parser with the common arguments of model export scripts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--checkpoint", required=True)
    parser.add_argument("-i", "--input", required=True)
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-a", "--affs_to_bd", default=0, type=int)
    parser.add_argument("-f", "--additional_formats", type=str, nargs="+")
    return parser


def get_mws_config(offsets, config=None):
    """@private

    Store the mutex watershed `offsets` under the key 'mws' of `config`.
    A given `config` dictionary is updated in place, otherwise a new one is created.
    """
    mws_config = {"offsets": offsets}
    if config is None:
        config = {"mws": mws_config}
    else:
        assert isinstance(config, dict)
        config["mws"] = mws_config
    return config

def get_shallow2deep_config(rf_path, config=None):
    """@private

    Store the feature settings of a pickled random forest under the key 'shallow2deep' of `config`.

    Args:
        rf_path: Path to a pickled random forest, or to a folder with '*.pkl' files.
            For a folder the first file in sorted order is used.
        config: Optional dictionary to update in place. A new one is created if None.

    Returns:
        The configuration dictionary with the 'shallow2deep' entry.

    Raises:
        FileNotFoundError: If `rf_path` is a folder that does not contain any '*.pkl' file.
    """
    if os.path.isdir(rf_path):
        # Sort the matches, because the order returned by glob is arbitrary.
        candidates = sorted(glob(os.path.join(rf_path, "*.pkl")))
        if not candidates:
            raise FileNotFoundError(f"Could not find a '*.pkl' file in the folder {rf_path}.")
        rf_path = candidates[0]
    assert os.path.exists(rf_path), rf_path

    # NOTE: unpickling can execute arbitrary code, only use this with random forests from a trusted source.
    with open(rf_path, "rb") as f:
        rf = pickle.load(f)

    shallow2deep_config = {"ndim": rf.feature_ndim, "features": rf.feature_config}
    if config is None:
        config = {"shallow2deep": shallow2deep_config}
    else:
        assert isinstance(config, dict)
        config["shallow2deep"] = shallow2deep_config
    return config
def
export_bioimageio_model( checkpoint: str, output_path: str, input_data: Optional[numpy.ndarray] = None, name: Optional[str] = None, description: Optional[str] = None, authors: Optional[List[Dict[str, str]]] = None, tags: Optional[List[str]] = None, license: Optional[str] = None, documentation: Optional[str] = None, covers: Optional[str] = None, git_repo: Optional[str] = None, cite: Optional[List[Dict[str, str]]] = None, input_optional_parameters: bool = True, model_postprocessing: Optional[str] = None, for_deepimagej: bool = False, links: Optional[List[str]] = None, maintainers: Optional[List[Dict[str, str]]] = None, min_shape: Tuple[int, ...] = None, halo: Tuple[int, ...] = None, checkpoint_name: str = 'best', config: Dict = {}) -> bool:
def export_bioimageio_model(
    checkpoint: str,
    output_path: str,
    input_data: Optional[np.ndarray] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    authors: Optional[List[Dict[str, str]]] = None,
    tags: Optional[List[str]] = None,
    license: Optional[str] = None,
    documentation: Optional[str] = None,
    covers: Optional[str] = None,
    git_repo: Optional[str] = None,
    cite: Optional[List[Dict[str, str]]] = None,
    input_optional_parameters: bool = True,
    model_postprocessing: Optional[str] = None,
    for_deepimagej: bool = False,
    links: Optional[List[str]] = None,
    maintainers: Optional[List[Dict[str, str]]] = None,
    min_shape: Tuple[int, ...] = None,
    halo: Tuple[int, ...] = None,
    checkpoint_name: str = "best",
    config: Dict = {},
) -> bool:
    """Export a model to the bioimage.io model format.

    The trainer is loaded from the checkpoint on the CPU, the model description
    is assembled inside a temporary folder, the package is saved to
    `output_path` and the saved package is validated afterwards.

    Args:
        checkpoint: The path to the checkpoint with the model to export.
        output_path: The output path for saving the model.
        input_data: The input data for creating model test data.
            If None, the input data is obtained from the trainer.
        name: The export name of the model.
        description: The description of the model.
        authors: The authors that created this model.
        tags: List of tags for this model.
        license: The license under which to publish the model.
        documentation: The documentation of the model.
        covers: The covers to show when displaying the model.
            Only added to the model description if it is not None.
        git_repo: A GitHub repository associated with this model.
        cite: References to cite for this model.
        input_optional_parameters: Whether to input optional parameters via the command line.
        model_postprocessing: Postprocessing to apply to the model outputs.
        for_deepimagej: Whether this model can be run in DeepImageJ.
            NOTE: this argument is currently not read by this function.
        links: Linked modelzoo apps or software for this model. "ilastik/ilastik" is always
            added, as is the notebook link if one is available. The passed list is not modified.
        maintainers: The maintainers of this model.
        min_shape: The minimal valid input shape for the model.
        halo: The halo to cut away from model outputs.
        checkpoint_name: The name of the model checkpoint to load for the export.
        config: Dictionary with additional configuration for this model.

    Returns:
        Whether the exported model was successfully validated.
    """
    # NOTE: `config` has a mutable default value. It is only passed through to the
    # model description below; this function itself never modifies it.

    # Load the trainer and model.
    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
    model, model_kwargs = _get_model(trainer, model_postprocessing)

    # Get input data from the trainer if it is not given.
    if input_data is None:
        input_data = _get_input_data(trainer)

    # Work on a copy of the links so that the list passed by the caller is not mutated
    # (appending to it in-place would leak the default entries back to the caller
    # and accumulate duplicates when the same list is reused for several exports).
    links = [] if links is None else list(links)

    # The temporary folder is handed to the helpers that create the weights,
    # the test data and the additional kwargs.
    with tempfile.TemporaryDirectory() as export_folder:

        # Create the weight description.
        weight_description = _create_weight_description(model, export_folder, model_kwargs)

        # Create the test input/output files.
        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)

        # Get the descriptions for inputs, outputs and notebook links.
        input_description, output_description, notebook_link = _get_inout_descriptions(
            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
        )

        # Get the additional kwargs.
        kwargs = _get_kwargs(
            trainer, name, description,
            authors, tags, license, documentation,
            git_repo, cite, maintainers,
            export_folder, input_optional_parameters
        )

        # TODO double check the current link policy
        # The apps to link with this model, by default ilastik.
        links.append("ilastik/ilastik")
        # Add the notebook link, if available.
        if notebook_link is not None:
            links.append(notebook_link)
        kwargs["links"] = links

        if covers is not None:
            kwargs["covers"] = covers

        model_description = spec.ModelDescr(
            inputs=input_description,
            outputs=output_description,
            weights=weight_description,
            config=config,
            **kwargs,
        )

        # Save while the temporary folder still exists; presumably the description
        # refers to files that were written into it by the helpers above.
        save_bioimageio_package(model_description, output_path=output_path)

    # Validate the model.
    val_success = _validate_model(output_path)
    if val_success:
        print(f"The model was successfully exported to '{output_path}'.")
    else:
        warn(
            f"Validation of the bioimageio model exported to '{output_path}' has failed. "
            "You can use this model, but it will probably yield incorrect results."
        )
    return val_success
Export a model to the bioimage.io model format.
Arguments:
- checkpoint: The path to the checkpoint with the model to export.
- output_path: The output path for saving the model.
- input_data: The input data for creating model test data.
- name: The export name of the model.
- description: The description of the model.
- authors: The authors that created this model.
- tags: List of tags for this model.
- license: The license under which to publish the model.
- documentation: The documentation of the model.
- covers: The covers to show when displaying the model.
- git_repo: A GitHub repository associated with this model.
- cite: References to cite for this model.
- input_optional_parameters: Whether to input optional parameters via the command line.
- model_postprocessing: Postprocessing to apply to the model outputs.
- for_deepimagej: Whether this model can be run in DeepImageJ. Note that this argument is currently not used by the function body.
- links: Linked modelzoo apps or software for this model.
- maintainers: The maintainers of this model.
- min_shape: The minimal valid input shape for the model.
- halo: The halo to cut away from model outputs.
- checkpoint_name: The name of the model checkpoint to load for the export.
- config: Dictionary with additional configuration for this model.
Returns:
Whether the exported model was successfully validated.