torch_em.util.modelzoo

  1import argparse
  2import functools
  3import json
  4import os
  5import pickle
  6import subprocess
  7import tempfile
  8
  9from glob import glob
 10from pathlib import Path
 11from shutil import unpack_archive
 12from typing import Callable, Dict, List, Optional, Tuple
 13from warnings import warn
 14
 15import imageio
 16import numpy as np
 17import torch
 18import torch_em
 19
 20import bioimageio.core as core
 21import bioimageio.spec.model.v0_5 as spec
 22from bioimageio.spec import save_bioimageio_package
 23from bioimageio.core.backends.pytorch_backend import load_torch_model
 24
 25from elf.io import open_file
 26from .util import get_trainer, get_normalizer
 27
 28
 29#
 30# General Purpose Functionality
 31#
 32
 33
 34def normalize_with_batch(data, normalizer):
 35    """@private
 36    """
 37    if normalizer is None:
 38        return data
 39    normalized = np.concatenate([normalizer(da)[None] for da in data], axis=0)
 40    return normalized
 41
 42
 43#
 44# Utility Functions for Model Export.
 45#
 46
 47
 48def get_default_citations(model=None, model_output=None):
 49    """@private
 50    """
 51    citations = [{"text": "training library", "doi": "10.5281/zenodo.5108853"}]
 52
 53    # try to derive the correct network citation from the model class
 54    if model is not None:
 55        if isinstance(model, str):
 56            model_name = model
 57        else:
 58            model_name = str(model.__class__.__name__)
 59
 60        if model_name.lower() in ("unet2d", "unet_2d", "unet"):
 61            citations.append(
 62                {"text": "architecture", "doi": "10.1007/978-3-319-24574-4_28"}
 63            )
 64        elif model_name.lower() in ("unet3d", "unet_3d", "anisotropicunet"):
 65            citations.append(
 66                {"text": "architecture", "doi": "10.1007/978-3-319-46723-8_49"}
 67            )
 68        else:
 69            warn(f"No citation for architecture {model_name} found.")
 70
 71    # try to derive the correct segmentation algo citation from the model output type
 72    if model_output is not None:
 73        msg = f"No segmentation algorithm for output {model_output} known. 'affinities' and 'boundaries' are supported."
 74        if model_output == "affinities":
 75            citations.append(
 76                {"text": "segmentation algorithm", "doi": "10.1109/TPAMI.2020.2980827"}
 77            )
 78        elif model_output == "boundaries":
 79            citations.append(
 80                {"text": "segmentation algorithm", "doi": "10.1038/nmeth.4151"}
 81            )
 82        else:
 83            warn(msg)
 84
 85    return citations
 86
 87
 88def _get_model(trainer, postprocessing):
 89    model = trainer.model
 90    model.eval()
 91    model_kwargs = model.init_kwargs
 92    # remove kwargs that are types or classes, since they cannot be serialized
 93    # TODO warn if we strip any non-standard arguments
 94    model_kwargs = {k: v for k, v in model_kwargs.items() if not isinstance(v, type)}
 95
 96    # set the in-model postprocessing if given
 97    if postprocessing is not None:
 98        assert "postprocessing" in model_kwargs
 99        model_kwargs["postprocessing"] = postprocessing
100        state = model.state_dict()
101        model = model.__class__(**model_kwargs)
102        model.load_state_dict(state)
103        model.eval()
104
105    return model, model_kwargs
106
107
108def _pad(input_data, trainer):
109    try:
110        if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset):
111            ndim = trainer.train_loader.dataset.dataset.ndim
112        else:
113            ndim = trainer.train_loader.dataset.ndim
114    except AttributeError:
115        ndim = trainer.train_loader.dataset.datasets[0].ndim
116    target_dims = ndim + 2
117    for _ in range(target_dims - input_data.ndim):
118        input_data = np.expand_dims(input_data, axis=0)
119    return input_data
120
121
122def _write_data(input_data, model, trainer, export_folder):
123    # if input_data is None:
124    #     gen = SampleGenerator(trainer, 1, False, 1)
125    #     input_data = next(gen)
126    if isinstance(input_data, np.ndarray):
127        input_data = [input_data]
128
129    # normalize the input data if we have a normalization function
130    normalizer = get_normalizer(trainer)
131
132    # pad to 4d/5d and normalize the input data
133    # NOTE we have to save the padded data, but without normalization
134    test_inputs = [_pad(input_, trainer) for input_ in input_data]
135    normalized = [normalize_with_batch(input_, normalizer) for input_ in test_inputs]
136
137    # run prediction
138    with torch.no_grad():
139        test_tensors = [torch.from_numpy(norm).to(trainer.device) for norm in normalized]
140        test_outputs = model(*test_tensors)
141        if torch.is_tensor(test_outputs):
142            test_outputs = [test_outputs]
143        test_outputs = [out.cpu().numpy() for out in test_outputs]
144
145    # save the input / output
146    test_in_paths, test_out_paths = [], []
147    for i, input_ in enumerate(test_inputs):
148        test_in_path = os.path.join(export_folder, f"test_input_{i}.npy")
149        np.save(test_in_path, input_)
150        test_in_paths.append(test_in_path)
151    for i, out in enumerate(test_outputs):
152        test_out_path = os.path.join(export_folder, f"test_output_{i}.npy")
153        np.save(test_out_path, out)
154        test_out_paths.append(test_out_path)
155    return test_in_paths, test_out_paths
156
157
158def _create_weight_description(model, export_folder, model_kwargs):
159    module = str(model.__class__.__module__)
160    cls_name = str(model.__class__.__name__)
161
162    if module == "torch_em.model.unet":
163        source_path = os.path.join(os.path.split(__file__)[0], "../model/unet.py")
164        architecture = spec.ArchitectureFromFileDescr(
165            source=Path(source_path),
166            callable=cls_name,
167            kwargs=model_kwargs,
168        )
169    else:
170        architecture = spec.ArchitectureFromLibraryDescr(
171            import_from=module,
172            callable=cls_name,
173            kwargs=model_kwargs,
174        )
175
176    checkpoint_path = os.path.join(export_folder, "state_dict.pt")
177    torch.save(model.state_dict(), checkpoint_path)
178
179    weight_description = spec.WeightsDescr(
180        pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
181            source=Path(checkpoint_path),
182            architecture=architecture,
183            pytorch_version=spec.Version(torch.__version__),
184        )
185    )
186    return weight_description
187
188
189def _get_kwargs(
190    trainer, name, description, authors, tags, license, documentation,
191    git_repo, cite, maintainers, export_folder, input_optional_parameters
192):
193    if input_optional_parameters:
194        print("Enter values for the optional parameters.")
195        print("If the default value in [] is satisfactory, press enter without additional input.")
196        print("Please enter lists using json syntax.")
197
198    def _get_kwarg(kwarg_name, val, default, is_list=False, fname=None):
199        # if we don"t have a value, we either ask user for input (offering the default)
200        # or just use the default if input_optional_parameters is False
201        if val is None and input_optional_parameters:
202            default_val = default()
203            choice = input(f"{kwarg_name} [{default_val}]: ")
204            val = choice if choice else default_val
205        elif val is None:
206            val = default()
207
208        if fname is not None:
209            save_path = os.path.join(export_folder, fname)
210            with open(save_path, "w") as f:
211                f.write(val)
212            return save_path
213
214        if is_list and isinstance(val, str):
215            val = val.replace("'", '"')  # enable single quotes
216            val = json.loads(val)
217        if is_list:
218            assert isinstance(val, (list, tuple)), type(val)
219        return val
220
221    def _default_authors():
222        # first try to derive the author name from git
223        try:
224            call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
225            author = call_res.stdout.decode("utf8").rstrip("\n")
226            author = author if author else None  # in case there was no error, but output is empty
227        except Exception:
228            author = None
229
230        # otherwise fall back to the hostname
231        if author is None:
232            author = os.uname()[1]
233
234        return [{"name": author}]
235
236    def _default_repo():
237        return None  # NOTE: makes the git-remote lookup below unreachable; it is kept for reference
238        try:
239            call_res = subprocess.run(["git", "remote", "-v"], capture_output=True)
240            repo = call_res.stdout.decode("utf8").split("\n")[0].split()[1]
241            repo = repo if repo else None
242        except Exception:
243            repo = None
244        return repo
245
246    def _default_maintainers():
247        # first try to derive the maintainer name from git
248        try:
249            call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
250            maintainer = call_res.stdout.decode("utf8").rstrip("\n")
251            maintainer = maintainer if maintainer else None  # in case there was no error, but output is empty
252        except Exception:
253            maintainer = None
254
255        # otherwise fall back to the hostname
256        if maintainer is None:
257            maintainer = os.uname()[1]
258
259        return [{"github_user": maintainer}]
260
261    # TODO derive better default values:
262    # - description: derive something from trainer.ndim, trainer.loss, trainer.model, ...
263    # - tags: derive something from trainer.ndim, trainer.loss, trainer.model, ...
264    # - documentation: derive something from trainer.ndim, trainer.loss, trainer.model, ...
265    kwargs = {
266        "name": _get_kwarg("name", name, lambda: trainer.name),
267        "description": _get_kwarg("description", description, lambda: trainer.name),
268        "authors": _get_kwarg("authors", authors, _default_authors, is_list=True),
269        "tags": _get_kwarg("tags", tags, lambda: [trainer.name], is_list=True),
270        "license": _get_kwarg("license", license, lambda: "MIT"),
271        "documentation": _get_kwarg(
272            "documentation", documentation, lambda: trainer.name, fname="documentation.md"
273        ),
274        "git_repo": _get_kwarg("git_repo", git_repo, _default_repo),
275        "cite": _get_kwarg("cite", cite, get_default_citations),
276        "maintainers": _get_kwarg("maintainers", maintainers, _default_maintainers, is_list=True),
277    }
278
279    return kwargs
280
281
282def _get_preprocessing(trainer):
283    try:
284        if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset):
285            ndim = trainer.train_loader.dataset.dataset.ndim
286        else:
287            ndim = trainer.train_loader.dataset.ndim
288    except AttributeError:
289        ndim = trainer.train_loader.dataset.datasets[0].ndim
290    normalizer = get_normalizer(trainer)
291
292    if isinstance(normalizer, functools.partial):
293        kwargs = normalizer.keywords
294        normalizer = normalizer.func
295    else:
296        kwargs = {}
297
298    def _get_axes(axis):
299        all_axes = ["channel", "y", "x"] if ndim == 2 else ["channel", "z", "y", "x"]
300        if axis is None:
301            axes = all_axes
302        else:
303            axes = [all_axes[i] for i in axis]
304        return axes
305
306    name = f"{normalizer.__module__}.{normalizer.__name__}"
307    axes = _get_axes(kwargs.get("axis", None))
308
309    if name == "torch_em.transform.raw.normalize":
310
311        min_, max_ = kwargs.get("minval", None), kwargs.get("maxval", None)
312        assert (min_ is None) == (max_ is None)
313
314        if min_ is None:
315            spec_name = "scale_range"
316            spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": 0.0, "max_percentile": 100.0}
317        else:
318            spec_name = "scale_linear"
319            spec_kwargs = {"gain": 1.0 / max_, "offset": -min_, "axes": axes}
320
321    elif name == "torch_em.transform.raw.standardize":
322        spec_kwargs = {"axes": axes}
323        mean, std = kwargs.get("mean", None), kwargs.get("std", None)
324        if (mean is None) and (std is None):
325            spec_name = "zero_mean_unit_variance"
326        else:
327            spec_name = "fixed_zero_mean_unit_variance"
328            spec_kwargs.update({"mean": mean, "std": std})
329
330    elif name == "torch_em.transform.raw.normalize_percentile":
331        lower, upper = kwargs.get("lower", 1.0), kwargs.get("upper", 99.0)
332        spec_name = "scale_range"
333        spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": lower, "max_percentile": upper}
334
335    else:
336        warn(f"Could not parse the normalization function {name}, 'preprocessing' field will be empty.")
337        return None
338
339    name_to_cls = {
340        "scale_linear": spec.ScaleLinearDescr,
341        "scale_rage": spec.ScaleRangeDescr,
342        "zero_mean_unit_variance": spec.ZeroMeanUnitVarianceDescr,
343        "fixed_zero_mean_unit_variance": spec.FixedZeroMeanUnitVarianceDescr,
344    }
345    preprocessing = name_to_cls[spec_name](kwargs=spec_kwargs)
346
347    return [preprocessing]
348
349
350def _get_inout_descriptions(trainer, model, model_kwargs, input_tensors, output_tensors, min_shape, halo):
351
352    notebook_link = None
353    module = str(model.__class__.__module__)
354    name = str(model.__class__.__name__)
355
356    # can derive tensor kwargs only for known torch_em models (only unet for now)
357    if module == "torch_em.model.unet":
358        assert len(input_tensors) == len(output_tensors) == 1
359        inc, outc = model_kwargs["in_channels"], model_kwargs["out_channels"]
360
361        postprocessing = model_kwargs.get("postprocessing", None)
362        if isinstance(postprocessing, str) and postprocessing.startswith("affinities_to_boundaries"):
363            outc = 1
364        elif isinstance(postprocessing, str) and postprocessing.startswith("affinities_with_foreground_to_boundaries"):
365            outc = 2
366        elif postprocessing is not None:
367            warn(f"The model has the post-processing {postprocessing} which cannot be interpreted")
368
369        if name == "UNet2d":
370            depth = model_kwargs["depth"]
371            step = [2 ** depth] * 2
372            if min_shape is None:
373                min_shape = [2 ** (depth + 1)] * 2
374            else:
375                assert len(min_shape) == 2
376                min_shape = list(min_shape)
377            notebook_link = "ilastik/torch-em-2d-unet-notebook"
378
379        elif name == "UNet3d":
380            depth = model_kwargs["depth"]
381            step = [2 ** depth] * 3
382            if min_shape is None:
383                min_shape = [2 ** (depth + 1)] * 3
384            else:
385                assert len(min_shape) == 3
386                min_shape = list(min_shape)
387            notebook_link = "ilastik/torch-em-3d-unet-notebook"
388
389        elif name == "AnisotropicUNet":
390            scale_factors = model_kwargs["scale_factors"]
391            scale_prod = [
392                int(np.prod([scale_factors[i][d] for i in range(len(scale_factors))]))
393                for d in range(3)
394            ]
395            assert len(scale_prod) == 3
396            step = scale_prod
397            if min_shape is None:
398                min_shape = [2 * sp for sp in scale_prod]
399            else:
400                min_shape = list(min_shape)
401            notebook_link = "ilastik/torch-em-3d-unet-notebook"
402
403        else:
404            raise RuntimeError(f"Cannot derive tensor parameters for {module}.{name}")
405
406        if halo is None:   # default halo = step // 2
407            halo = [st // 2 for st in step]
408        else:  # make sure the passed halo has the same length as step, by padding with zeros
409            halo = [0] * (len(step) - len(halo)) + halo
410        assert len(halo) == len(step), f"{len(halo)}, {len(step)}"
411
412        # Create the input axis description.
413        input_axes = [
414            spec.BatchAxis(),
415            spec.ChannelAxis(channel_names=[spec.Identifier(f"channel_{i}") for i in range(inc)]),
416        ]
417        input_ndim = np.load(input_tensors[0]).ndim
418        assert input_ndim in (4, 5)
419        axis_names = "zyx" if input_ndim == 5 else "yx"
420        assert len(axis_names) == len(min_shape) == len(step)
421        input_axes += [
422            spec.SpaceInputAxis(id=spec.AxisId(ax_name), size=spec.ParameterizedSize(min=ax_min, step=ax_step))
423            for ax_name, ax_min, ax_step in zip(axis_names, min_shape, step)
424        ]
425
426        # Create the rest of the input description.
427        preprocessing = _get_preprocessing(trainer)
428        input_description = [spec.InputTensorDescr(
429            id=spec.TensorId("image"),
430            axes=input_axes,
431            test_tensor=spec.FileDescr(source=Path(input_tensors[0])),
432            preprocessing=preprocessing,
433        )]
434
435        # Create the output axis description.
436        output_axes = [
437            spec.BatchAxis(),
438            spec.ChannelAxis(channel_names=[spec.Identifier(f"out_channel_{i}") for i in range(outc)]),
439        ]
440        output_ndim = np.load(output_tensors[0]).ndim
441        assert output_ndim in (4, 5)
442        axis_names = "zyx" if output_ndim == 5 else "yx"
443        assert len(axis_names) == len(halo)
444        output_axes += [
445            spec.SpaceOutputAxisWithHalo(
446                id=spec.AxisId(ax_name),
447                size=spec.SizeReference(
448                    tensor_id=spec.TensorId("image"), axis_id=spec.AxisId(ax_name)
449                ),
450                halo=halo_val,
451            ) for ax_name, halo_val in zip(axis_names, halo)
452        ]
453
454        # Create the rest of the output description.
455        output_description = [spec.OutputTensorDescr(
456            id=spec.TensorId("prediction"),
457            axes=output_axes,
458            test_tensor=spec.FileDescr(source=Path(output_tensors[0]))
459        )]
460
461    else:
462        raise NotImplementedError("Model export currently only works for torch_em.model.unet.")
463
464    return input_description, output_description, notebook_link
465
466
467def _validate_model(spec_path, output_path):
468    if not os.path.exists(spec_path):
469        return False
470
471    try:
472        model, normalizer, model_spec = import_bioimageio_model(spec_path, return_spec=True, output_path=output_path)
473        root = output_path
474
475        input_paths = [os.path.join(root, ipt.test_tensor.source.path) for ipt in model_spec.inputs]
476        inputs = [normalize_with_batch(np.load(ipt), normalizer) for ipt in input_paths]
477
478        expected_paths = [os.path.join(root, opt.test_tensor.source.path) for opt in model_spec.outputs]
479        expected = [np.load(opt) for opt in expected_paths]
480
481        with torch.no_grad():
482            inputs = [torch.from_numpy(input_) for input_ in inputs]
483            outputs = model(*inputs)
484            if torch.is_tensor(outputs):
485                outputs = [outputs]
486            outputs = [out.numpy() for out in outputs]
487
488        for out, exp in zip(outputs, expected):
489            if not np.allclose(out, exp):
490                return False
491
492    except Exception as e:
493        print("Model validation failed with the following exception:")
494        print(e)
495        return False
496
497    return True
498
499
500#
501# Model Export Functionality
502#
503
504def _get_input_data(trainer):
505    loader = trainer.val_loader
506    x = next(iter(loader))[0].numpy()
507    return x
508
509
510def export_bioimageio_model(
511    checkpoint: str,
512    output_path: str,
513    input_data: Optional[np.ndarray] = None,
514    name: Optional[str] = None,
515    description: Optional[str] = None,
516    authors: Optional[List[Dict[str, str]]] = None,
517    tags: Optional[List[str]] = None,
518    license: Optional[str] = None,
519    documentation: Optional[str] = None,
520    covers: Optional[str] = None,
521    git_repo: Optional[str] = None,
522    cite: Optional[List[Dict[str, str]]] = None,
523    input_optional_parameters: bool = True,
524    model_postprocessing: Optional[str] = None,
525    for_deepimagej: bool = False,
526    links: Optional[List[str]] = None,
527    maintainers: Optional[List[Dict[str, str]]] = None,
529    min_shape: Optional[Tuple[int, ...]] = None,
530    halo: Optional[Tuple[int, ...]] = None,
530    checkpoint_name: str = "best",
531    config: Dict = {},
532) -> bool:
533    """Export model to bioimage.io model format.
534
535    Args:
536        checkpoint: The path to the checkpoint with the model to export.
537        output_path: The output path for saving the model.
538        input_data: The input data for creating model test data.
539        name: The export name of the model.
540        description: The description of the model.
541        authors: The authors that created this model.
542        tags: List of tags for this model.
543        license: The license under which to publish the model.
544        documentation: The documentation of the model.
545        covers: The covers to show when displaying the model.
546        git_repo: A github repository associated with this model.
547        cite: References to cite for this model.
548        input_optional_parameters: Whether to input optional parameters via the command line.
549        model_postprocessing: Postprocessing to apply to the model outputs.
550        for_deepimagej: Whether this model can be run in DeepImageJ.
551        links: Linked modelzoo apps or software for this model.
552        maintainers: The maintainers of this model.
553        min_shape: The minimal valid input shape for the model.
554        halo: The halo to cut away from model outputs.
555        checkpoint_name: The name of the model checkpoint to load for the export.
556        config: Dictionary with additional configuration for this model.
557
558    Returns:
559        Whether the exported model was successfully validated.
560    """
561    # Load the trainer and model.
562    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
563    model, model_kwargs = _get_model(trainer, model_postprocessing)
564
565    # Get input data from the trainer if it is not given.
566    if input_data is None:
567        input_data = _get_input_data(trainer)
568
569    with tempfile.TemporaryDirectory() as export_folder:
570
571        # Create the weight description.
572        weight_description = _create_weight_description(model, export_folder, model_kwargs)
573
574        # Create the test input/output files.
575        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
576        # Get the descriptions for inputs, outputs and notebook links.
577        input_description, output_description, notebook_link = _get_inout_descriptions(
578            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
579        )
580
581        # Get the additional kwargs.
582        kwargs = _get_kwargs(
583            trainer, name, description,
584            authors, tags, license, documentation,
585            git_repo, cite, maintainers,
586            export_folder, input_optional_parameters
587        )
588
589        # TODO double check the current link policy
590        # The apps to link with this model, by default ilastik.
591        if links is None:
592            links = []
593        links.append("ilastik/ilastik")
594        # add the notebook link, if available
595        if notebook_link is not None:
596            links.append(notebook_link)
597        kwargs.update({"links": links})
598
599        if covers is not None:
600            kwargs["covers"] = covers
601
602        model_description = spec.ModelDescr(
603            inputs=input_description,
604            outputs=output_description,
605            weights=weight_description,
606            config=config,
607            **kwargs,
608        )
609
610        save_bioimageio_package(model_description, output_path=output_path)
611
612    # Validate the model.
613    with tempfile.TemporaryDirectory() as tmp_dir:
614        val_success = _validate_model(output_path, tmp_dir)
615    if val_success:
616        print(f"The model was successfully exported to '{output_path}'.")
617    else:
618        warn(f"Validation of the bioimageio model exported to '{output_path}' has failed. " +
619             "You can use this model, but it will probably yield incorrect results.")
620    return val_success
621
622
623# TODO support bounding boxes
624def _load_data(path, key):
625    if key is None:
626        ext = os.path.splitext(path)[-1]
627        if ext == ".npy":
628            return np.load(path)
629        else:
630            return imageio.imread(path)
631    else:
632        return open_file(path, mode="r")[key][:]
633
634
635def main():
636    """@private
637    """
638    parser = argparse.ArgumentParser(
639        description="Export a model trained with torch_em to the BioImage.IO model format. "
640        "The exported model can be run in any tool supporting BioImage.IO. "
641        "For details check out https://bioimage.io/#/."
642    )
643    parser.add_argument("-p", "--path", required=True,
644                        help="Path to the model checkpoint to export to the BioImage.IO model format.")
645    parser.add_argument("-d", "--data", required=True,
646                        help="Path to the test data to use for creating the exported model.")
647    parser.add_argument("-f", "--export_folder", required=True,
648                        help="Where to save the exported model. The exported model is stored as a zip in the folder.")
649    parser.add_argument("-k", "--key",
650                        help="The key for the test data. Required for container data formats like hdf5 or zarr.")
651    parser.add_argument("-n", "--name", help="The name of the exported model.")
652
653    args = parser.parse_args()
654    name = os.path.basename(args.path) if args.name is None else args.name
655    export_bioimageio_model(args.path, args.export_folder, _load_data(args.data, args.key), name=name)
656
657
658#
659# model import functionality
660#
661
662
663def _unzip_and_load_model(model_spec, device, spec_path, output_path):
664    unpack_archive(str(spec_path), str(output_path))
665    weight_spec = model_spec.weights.pytorch_state_dict
666    model = load_torch_model(weight_spec, devices=[device])
667    model.eval()
668    return model
669
670
671def _load_model(model_spec, device, spec_path, output_path):
672    if output_path is None:
673        with tempfile.TemporaryDirectory() as tmp_dir:
674            return _unzip_and_load_model(model_spec, device, spec_path, tmp_dir)
675    else:
676        return _unzip_and_load_model(model_spec, device, spec_path, output_path)
677
678
679def _load_normalizer(model_spec):
680    inputs = model_spec.inputs[0]
681    preprocessing = inputs.preprocessing
682
683    # Filter out ensure dtype.
684    preprocessing = [preproc for preproc in preprocessing if preproc.id != "ensure_dtype"]
685    if len(preprocessing) == 0:
686        return None
687
688    ndim = len(inputs.axes) - 2
689    shape = inputs.shape
690    if hasattr(shape, "min"):
691        shape = shape.min
692
693    conf = preprocessing[0]
694    name = conf.id
695    spec_kwargs = conf.kwargs
696
697    def _get_axis(axes):
698        label_to_id = {"channel": 0, "z": 1, "y": 2, "x": 3} if ndim == 3 else\
699            {"channel": 0, "y": 1, "x": 2}
700        axis = tuple(label_to_id[ax] for ax in axes)
701
702        # Is the axis full? Then we don't need to specify it.
703        if len(axis) == ndim + 1:
704            return None
705
706        # Drop the channel axis if we have only a single channel.
707        # Because torch_em squeezes the channel axis in this case.
708        if shape[1] == 1:
709            axis = tuple(ax - 1 for ax in axis if ax > 0)
710        return axis
711
712    axis = _get_axis(spec_kwargs.get("axes", None))
713    if name == "zero_mean_unit_variance":
714        kwargs = {"axis": axis}
715        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)
716
717    elif name == "fixed_zero_mean_unit_variance":
718        kwargs = {"axis": axis, "mean": spec_kwargs["mean"], "std": spec_kwargs["std"]}
719        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)
720
721    elif name == "scale_linear":
722        min_ = -spec_kwargs["offset"]
723        max_ = 1. / spec_kwargs["gain"]
724        kwargs = {"axis": axis, "minval": min_, "maxval": max_}
725        normalizer = functools.partial(torch_em.transform.raw.normalize, **kwargs)
726
727    elif name == "scale_range":
728        assert spec_kwargs["mode"] == "per_sample"  # Can't parse the other modes right now.
729        lower, upper = spec_kwargs["min_percentile"], spec_kwargs["max_percentile"]
730        if np.isclose(lower, 0.0) and np.isclose(upper, 100.0):
731            normalizer = functools.partial(torch_em.transform.raw.normalize, axis=axis)
732        else:
733            kwargs = {"axis": axis, "lower": lower, "upper": upper}
734            normalizer = functools.partial(torch_em.transform.raw.normalize_percentile, **kwargs)
735
736    else:
737        msg = f"torch_em does not support the use of the biomageio preprocessing function {name}."
738        raise RuntimeError(msg)
739
740    return normalizer
741
742
743def import_bioimageio_model(
744    spec_path: str,
745    return_spec: bool = False,
746    device: Optional[str] = None,
747    output_path: Optional[str] = None,
748) -> Tuple[torch.nn.Module, Callable]:
749    """Import a pytorch model from a bioimageio model.
750
751    Args:
752        spec_path: The path to the bioimageio model. Expects a zipped model file.
753        return_spec: Whether to return the deserialized model description in addition to the model.
754        device: The device to use for loading the model.
755        output_path: The output path for deserializing model files.
756            By default the files will be deserialized to a temporary path.
757
758    Returns:
759        The model loaded as torch.nn.Module.
760        The preprocessing function.
761    """
762    model_spec = core.load_description(spec_path)
763
764    device = "cuda" if torch.cuda.is_available() else "cpu"
765    model = _load_model(model_spec, device=device, spec_path=spec_path, output_path=output_path)
766    normalizer = _load_normalizer(model_spec)
767
768    if return_spec:
769        return model, normalizer, model_spec
770    else:
771        return model, normalizer
772
773
774# TODO: the weight conversion needs to be updated once the
775# corresponding functionality in bioimageio.core is updated
776#
777# Weight Conversion
778#
779
780
781def _convert_impl(spec_path, weight_name, converter, weight_type, **kwargs):
782    with tempfile.TemporaryDirectory() as tmp_dir:
783        weight_path = os.path.join(tmp_dir, weight_name)
784        model_spec = core.load_description(spec_path)
785        weight_descr = converter(model_spec, weight_path, **kwargs)
786        # TODO double check
787        setattr(model_spec.weights, weight_type, weight_descr)
788        save_bioimageio_package(model_spec, output_path=spec_path)
789
790
791def convert_to_onnx(spec_path, opset_version=12):
792    """@private
793    """
794    raise NotImplementedError
795    # converter = weight_converter.convert_weights_to_onnx
796    # _convert_impl(spec_path, "weights.onnx", converter, "onnx", opset_version=opset_version)
797    # return None
798
799
800def convert_to_torchscript(model_path):
801    """@private
802    """
803    raise NotImplementedError
804    # from bioimageio.core.weight_converter.torch._torchscript import convert_weights_to_torchscript
805
806    # weight_name = "weights-torchscript.pt"
807    # breakpoint()
808    # _convert_impl(model_path, weight_name, convert_weights_to_torchscript, "torchscript")
809
810    # # Check that we can load the converted weights.
811    # model_spec = core.load_description(model_path)
812    # weight_path = model_spec.weights.torchscript.weights
813    # try:
814    #     torch.jit.load(weight_path)
815    #     return None
816    # except Exception as e:
817    #     return e
818
819
820def add_weight_formats(model_path, additional_formats):
821    """@private
822    """
823    for add_format in additional_formats:
824
825        if add_format == "onnx":
826            ret = convert_to_onnx(model_path)
827        elif add_format == "torchscript":
828            ret = convert_to_torchscript(model_path)
829
830        if ret is None:
831            print("Successfully added", add_format, "weights")
832        else:
833            warn(f"Added {add_format} weights, but got exception {ret} when loading the weights again.")
834
835
836def convert_main():
837    """@private
838    """
839    parser = argparse.ArgumentParser(description="Convert weights from the native pytorch format to onnx or torchscript.")
840    parser.add_argument("-f", "--model_folder", required=True, help="The folder with the bioimageio model.")
841    parser.add_argument("-w", "--weight_format", required=True, help="The weight format to add: 'onnx' or 'torchscript'.")
842    args = parser.parse_args()
843    weight_format = args.weight_format
844    assert weight_format in ("onnx", "torchscript")
845    if weight_format == "onnx":
846        convert_to_onnx(args.model_folder)
847    else:
848        convert_to_torchscript(args.model_folder)
849
850
851#
852# Misc Functionality
853#
854
855def export_parser_helper():
856    """@private
857    """
858    parser = argparse.ArgumentParser()
859    parser.add_argument("-c", "--checkpoint", required=True)
860    parser.add_argument("-i", "--input", required=True)
861    parser.add_argument("-o", "--output", required=True)
862    parser.add_argument("-a", "--affs_to_bd", default=0, type=int)
863    parser.add_argument("-f", "--additional_formats", type=str, nargs="+")
864    return parser
865
866
867def get_mws_config(offsets, config=None):
868    """@private
869    """
870    mws_config = {"offsets": offsets}
871    if config is None:
872        config = {"mws": mws_config}
873    else:
874        assert isinstance(config, dict)
875        config["mws"] = mws_config
876    return config
877
878
879def get_shallow2deep_config(rf_path, config=None):
880    """@private
881    """
882    if os.path.isdir(rf_path):
883        rf_path = glob(os.path.join(rf_path, "*.pkl"))[0]
884    assert os.path.exists(rf_path), rf_path
885    with open(rf_path, "rb") as f:
886        rf = pickle.load(f)
887    shallow2deep_config = {"ndim": rf.feature_ndim, "features": rf.feature_config}
888    if config is None:
889        config = {"shallow2deep": shallow2deep_config}
890    else:
891        assert isinstance(config, dict)
892        config["shallow2deep"] = shallow2deep_config
893    return config
def export_bioimageio_model( checkpoint: str, output_path: str, input_data: Optional[numpy.ndarray] = None, name: Optional[str] = None, description: Optional[str] = None, authors: Optional[List[Dict[str, str]]] = None, tags: Optional[List[str]] = None, license: Optional[str] = None, documentation: Optional[str] = None, covers: Optional[str] = None, git_repo: Optional[str] = None, cite: Optional[List[Dict[str, str]]] = None, input_optional_parameters: bool = True, model_postprocessing: Optional[str] = None, for_deepimagej: bool = False, links: Optional[List[str]] = None, maintainers: Optional[List[Dict[str, str]]] = None, min_shape: Optional[Tuple[int, ...]] = None, halo: Optional[Tuple[int, ...]] = None, checkpoint_name: str = 'best', config: Dict = {}) -> bool:

Export model to bioimage.io model format.

Arguments:
  • checkpoint: The path to the checkpoint with the model to export.
  • output_path: The output path for saving the model.
  • input_data: The input data for creating model test data.
  • name: The export name of the model.
  • description: The description of the model.
  • authors: The authors that created this model.
  • tags: List of tags for this model.
  • license: The license under which to publish the model.
  • documentation: The documentation of the model.
  • covers: The covers to show when displaying the model.
  • git_repo: A github repository associated with this model.
  • cite: References to cite for this model.
  • input_optional_parameters: Whether to input optional parameters via the command line.
  • model_postprocessing: Postprocessing to apply to the model outputs.
  • for_deepimagej: Whether this model can be run in DeepImageJ.
  • links: Linked modelzoo apps or software for this model.
  • maintainers: The maintainers of this model.
  • min_shape: The minimal valid input shape for the model.
  • halo: The halo to cut away from model outputs.
  • checkpoint_name: The name of the model checkpoint to load for the export.
  • config: Dictionary with additional configuration for this model.
Returns:
  Whether the exported model was successfully validated.
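
For reference, here is a minimal usage sketch. The checkpoint path, output path, and metadata below are hypothetical; it assumes a 2d model trained with torch_em:

import numpy as np
from torch_em.util.modelzoo import export_bioimageio_model

# Hypothetical checkpoint folder created by torch_em training and a random test image.
input_data = np.random.rand(256, 256).astype("float32")
success = export_bioimageio_model(
    checkpoint="./checkpoints/my-2d-unet",  # hypothetical checkpoint path
    output_path="./my-2d-unet.zip",
    input_data=input_data,
    name="my-2d-unet",
    description="A 2D U-Net for boundary segmentation.",
    authors=[{"name": "Jane Doe"}],
    tags=["unet", "2d", "segmentation"],
    license="MIT",
    documentation="A 2D U-Net trained with torch_em.",
    input_optional_parameters=False,  # do not prompt for missing metadata interactively
)
print("Export validated:", success)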

def import_bioimageio_model( spec_path: str, return_spec: bool = False, device: Optional[str] = None, output_path: Optional[str] = None) -> Tuple[torch.nn.modules.module.Module, Callable]:

Import a pytorch model from a bioimageio model.

Arguments:
  • spec_path: The path to the bioimageio model. Expects a zipped model file.
  • return_spec: Whether to return the deserialized model description in addition to the model.
  • device: The device to use for loading the model.
  • output_path: The output path for deserializing model files. By default the files will be deserialized to a temporary path.
Returns:
  The model loaded as torch.nn.Module.
  The preprocessing function.
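
For reference, here is a minimal usage sketch. The model path is hypothetical and refers to a zip created with export_bioimageio_model:

import numpy as np
import torch
from torch_em.util.modelzoo import import_bioimageio_model

# Hypothetical path to a zipped bioimageio model.
model, normalizer = import_bioimageio_model("./my-2d-unet.zip", device="cpu")

# Apply the preprocessing function and run prediction for a random 2d image.
raw = np.random.rand(256, 256).astype("float32")
data = raw if normalizer is None else normalizer(raw)
with torch.no_grad():
    input_ = torch.from_numpy(data[None, None])  # add batch and channel axes
    prediction = model(input_).cpu().numpy()
print(prediction.shape)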