torch_em.util.modelzoo

  1import argparse
  2import functools
  3import json
  4import os
  5import pickle
  6import subprocess
  7import tempfile
  8
  9from glob import glob
 10from pathlib import Path
 11from typing import Dict, List, Optional, Tuple
 12from warnings import warn
 13
 14import imageio
 15import numpy as np
 16import torch
 17import torch_em
 18
 19import bioimageio.core as core
 20import bioimageio.spec.model.v0_5 as spec
 21from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
 22from bioimageio.spec import save_bioimageio_package
 23
 24from elf.io import open_file
 25from .util import get_trainer, get_normalizer
 26
 27
 28#
 29# General Purpose Functionality
 30#
 31
 32
 33def normalize_with_batch(data, normalizer):
 34    """@private
 35    """
 36    if normalizer is None:
 37        return data
 38    normalized = np.concatenate([normalizer(da)[None] for da in data], axis=0)
 39    return normalized
 40
 41
 42#
 43# Utility Functions for Model Export.
 44#
 45
 46
 47def get_default_citations(model=None, model_output=None):
 48    """@private
 49    """
 50    citations = [{"text": "training library", "doi": "10.5281/zenodo.5108853"}]
 51
 52    # try to derive the correct network citation from the model class
 53    if model is not None:
 54        if isinstance(model, str):
 55            model_name = model
 56        else:
 57            model_name = str(model.__class__.__name__)
 58
 59        if model_name.lower() in ("unet2d", "unet_2d", "unet"):
 60            citations.append(
 61                {"text": "architecture", "doi": "10.1007/978-3-319-24574-4_28"}
 62            )
 63        elif model_name.lower() in ("unet3d", "unet_3d", "anisotropicunet"):
 64            citations.append(
 65                {"text": "architecture", "doi": "10.1007/978-3-319-46723-8_49"}
 66            )
 67        else:
 68            warn("No citation for architecture {model_name} found.")
 69
 70    # try to derive the correct segmentation algo citation from the model output type
 71    if model_output is not None:
 72        msg = f"No segmentation algorithm for output {model_output} known. 'affinities' and 'boundaries' are supported."
 73        if model_output == "affinities":
 74            citations.append(
 75                {"text": "segmentation algorithm", "doi": "10.1109/TPAMI.2020.2980827"}
 76            )
 77        elif model_output == "boundaries":
 78            citations.append(
 79                {"text": "segmentation algorithm", "doi": "10.1038/nmeth.4151"}
 80            )
 81        else:
 82            warn(msg)
 83
 84    return citations
 85
 86
 87def _get_model(trainer, postprocessing):
 88    model = trainer.model
 89    model.eval()
 90    model_kwargs = model.init_kwargs
 91    # clear the kwargs of non builtins
 92    # TODO warn if we strip any non-standard arguments
 93    model_kwargs = {k: v for k, v in model_kwargs.items() if not isinstance(v, type)}
 94
 95    # set the in-model postprocessing if given
 96    if postprocessing is not None:
 97        assert "postprocessing" in model_kwargs
 98        model_kwargs["postprocessing"] = postprocessing
 99        state = model.state_dict()
100        model = model.__class__(**model_kwargs)
101        model.load_state_dict(state)
102        model.eval()
103
104    return model, model_kwargs
105
106
def _pad(input_data, trainer):
    # Look up the dimensionality of the training data. A `Subset` wraps the actual dataset;
    # datasets without an `ndim` attribute are expected to expose a `datasets` list instead.
    dataset = trainer.train_loader.dataset
    try:
        wrapped = dataset.dataset if isinstance(dataset, torch.utils.data.dataset.Subset) else dataset
        ndim = wrapped.ndim
    except AttributeError:
        ndim = dataset.datasets[0].ndim

    # Prepend singleton axes until the data has the spatial dimensions plus two extra axes.
    while input_data.ndim < ndim + 2:
        input_data = np.expand_dims(input_data, axis=0)
    return input_data
119
120
def _write_data(input_data, model, trainer, export_folder):
    # A single array is treated as a list holding one input.
    if isinstance(input_data, np.ndarray):
        input_data = [input_data]

    # The normalization function of the trainer (if any) is applied before prediction.
    normalizer = get_normalizer(trainer)

    # Pad to 4d/5d. NOTE: the padded but *not* normalized data is what gets saved as test input.
    padded_inputs = [_pad(raw, trainer) for raw in input_data]
    network_inputs = [normalize_with_batch(raw, normalizer) for raw in padded_inputs]

    # Run the prediction and bring every output back to numpy.
    with torch.no_grad():
        tensors = [torch.from_numpy(arr).to(trainer.device) for arr in network_inputs]
        predictions = model(*tensors)
        if torch.is_tensor(predictions):
            predictions = [predictions]
        predictions = [pred.cpu().numpy() for pred in predictions]

    # Save the test inputs.
    test_in_paths = []
    for i, arr in enumerate(padded_inputs):
        in_path = os.path.join(export_folder, f"test_input_{i}.npy")
        np.save(in_path, arr)
        test_in_paths.append(in_path)

    # Save the test outputs.
    test_out_paths = []
    for i, arr in enumerate(predictions):
        out_path = os.path.join(export_folder, f"test_output_{i}.npy")
        np.save(out_path, arr)
        test_out_paths.append(out_path)

    return test_in_paths, test_out_paths
155
156
def _create_weight_description(model, export_folder, model_kwargs):
    model_cls = model.__class__
    module, cls_name = str(model_cls.__module__), str(model_cls.__name__)

    # Models from torch_em's unet module are described via the source file,
    # all other models via the library they can be imported from.
    if module != "torch_em.model.unet":
        architecture = spec.ArchitectureFromLibraryDescr(
            import_from=module, callable=cls_name, kwargs=model_kwargs,
        )
    else:
        unet_source = os.path.join(os.path.dirname(__file__), "../model/unet.py")
        architecture = spec.ArchitectureFromFileDescr(
            source=Path(unet_source), callable=cls_name, kwargs=model_kwargs,
        )

    # Serialize the model weights into the export folder.
    checkpoint_path = os.path.join(export_folder, "state_dict.pt")
    torch.save(model.state_dict(), checkpoint_path)

    state_dict_description = spec.PytorchStateDictWeightsDescr(
        source=Path(checkpoint_path),
        architecture=architecture,
        pytorch_version=spec.Version(torch.__version__),
    )
    return spec.WeightsDescr(pytorch_state_dict=state_dict_description)
186
187
def _get_kwargs(
    trainer, name, description, authors, tags, license, documentation,
    git_repo, cite, maintainers, export_folder, input_optional_parameters
):
    """Assemble the general metadata fields of the model description.

    Every field that is passed as None is filled with a default value. If `input_optional_parameters`
    is True, the user is asked for a value on the command line instead and the default is only used
    when the input is left empty.

    Args:
        trainer: The trainer of the exported model; its `name` is the default for several fields.
        name: The name of the model.
        description: The description of the model.
        authors: The authors of the model, as list of dictionaries.
        tags: The tags of the model.
        license: The license of the model.
        documentation: The documentation text. It is written to 'documentation.md' in `export_folder`.
        git_repo: The repository associated with the model.
        cite: The citations for the model, as list of dictionaries.
        maintainers: The maintainers of the model, as list of dictionaries.
        export_folder: The folder where the documentation file is written.
        input_optional_parameters: Whether to ask for missing values on the command line.

    Returns:
        Dictionary with the keys "name", "description", "authors", "tags", "license",
        "documentation", "git_repo", "cite" and "maintainers".
    """
    if input_optional_parameters:
        print("Enter values for the optional parameters.")
        print("If the default value in [] is satisfactory, press enter without additional input.")
        print("Please enter lists using json syntax.")

    def _get_kwarg(kwarg_name, val, default, is_list=False, fname=None):
        # If we don't have a value, we either ask the user for input (offering the default)
        # or just use the default if input_optional_parameters is False.
        if val is None:
            val = default()
            if input_optional_parameters:
                choice = input(f"{kwarg_name} [{val}]: ")
                val = choice if choice else val

        # Values with a file name are written to the export folder and referenced by path.
        if fname is not None:
            save_path = os.path.join(export_folder, fname)
            with open(save_path, "w", encoding="utf-8") as f:
                f.write(val)
            return save_path

        # Lists entered on the command line arrive as a string and are parsed as json.
        if is_list and isinstance(val, str):
            val = val.replace("'", '"')  # enable single quotes
            val = json.loads(val)
        if is_list:
            assert isinstance(val, (list, tuple)), type(val)
        return val

    def _default_user():
        # First try to derive the name from the git configuration.
        try:
            call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
            user = call_res.stdout.decode("utf8").rstrip("\n")
        except Exception:  # best effort: git may not be installed
            user = ""
        if user:
            return user

        # Otherwise use the name of the logged-in user. `os.uname()[1]` is not used for this,
        # because it is the host name and is not available on all platforms.
        import getpass
        try:
            return getpass.getuser()
        except Exception:
            return "unknown"

    def _default_authors():
        return [{"name": _default_user()}]

    def _default_repo():
        # No repository is derived automatically; it has to be passed explicitly.
        return None

    def _default_maintainers():
        return [{"github_user": _default_user()}]

    # TODO derive better default values:
    # - description: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    # - tags: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    # - documentation: derive something from trainer.ndim, trainer.loss, trainer.model, ...
    kwargs = {
        "name": _get_kwarg("name", name, lambda: trainer.name),
        "description": _get_kwarg("description", description, lambda: trainer.name),
        "authors": _get_kwarg("authors", authors, _default_authors, is_list=True),
        "tags": _get_kwarg("tags", tags, lambda: [trainer.name], is_list=True),
        "license": _get_kwarg("license", license, lambda: "MIT"),
        "documentation": _get_kwarg(
            "documentation", documentation, lambda: trainer.name, fname="documentation.md"
        ),
        "git_repo": _get_kwarg("git_repo", git_repo, _default_repo),
        # Citations are a list as well, so that values typed on the command line are parsed.
        "cite": _get_kwarg("cite", cite, get_default_citations, is_list=True),
        "maintainers": _get_kwarg("maintainers", maintainers, _default_maintainers, is_list=True),
    }

    return kwargs
279
280
def _get_preprocessing(trainer):
    """Derive the bioimage.io preprocessing description from the normalization of the trainer.

    The normalization function is obtained with `get_normalizer` and matched by its qualified name
    against `normalize`, `standardize` and `normalize_percentile` from `torch_em.transform.raw`.

    Args:
        trainer: The trainer of the model that is exported.

    Returns:
        A list with a single preprocessing description, or None if there is no normalization
        function or if it cannot be translated.
    """
    try:
        if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset):
            ndim = trainer.train_loader.dataset.dataset.ndim
        else:
            ndim = trainer.train_loader.dataset.ndim
    except AttributeError:
        # Datasets without `ndim` are expected to expose a list of `datasets` instead.
        ndim = trainer.train_loader.dataset.datasets[0].ndim
    normalizer = get_normalizer(trainer)

    # Without a normalization function there is nothing to describe.
    if normalizer is None:
        return None

    # Arguments bound via functools.partial carry the normalization parameters.
    if isinstance(normalizer, functools.partial):
        kwargs = normalizer.keywords
        normalizer = normalizer.func
    else:
        kwargs = {}

    def _get_axes(axis):
        all_axes = ["channel", "y", "x"] if ndim == 2 else ["channel", "z", "y", "x"]
        if axis is None:
            return all_axes
        # Accept a single axis index as well as a sequence of indices.
        if isinstance(axis, (int, np.integer)):
            axis = (axis,)
        return [all_axes[i] for i in axis]

    name = f"{getattr(normalizer, '__module__', None)}.{getattr(normalizer, '__name__', None)}"
    axes = _get_axes(kwargs.get("axis"))

    if name == "torch_em.transform.raw.normalize":

        min_, max_ = kwargs.get("minval"), kwargs.get("maxval")
        assert (min_ is None) == (max_ is None)

        if min_ is None:
            spec_name = "scale_range"
            spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": 0.0, "max_percentile": 100.0}
        else:
            spec_name = "scale_linear"
            spec_kwargs = {"gain": 1.0 / max_, "offset": -min_, "axes": axes}

    elif name == "torch_em.transform.raw.standardize":
        spec_kwargs = {"axes": axes}
        mean, std = kwargs.get("mean"), kwargs.get("std")
        if (mean is None) and (std is None):
            spec_name = "zero_mean_unit_variance"
        else:
            spec_name = "fixed_zero_mean_unit_variance"
            spec_kwargs.update({"mean": mean, "std": std})

    elif name == "torch_em.transform.raw.normalize_percentile":
        lower, upper = kwargs.get("lower", 1.0), kwargs.get("upper", 99.0)
        spec_name = "scale_range"
        spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": lower, "max_percentile": upper}

    else:
        warn(f"Could not parse the normalization function {name}, 'preprocessing' field will be empty.")
        return None

    # The keys have to match the values assigned to `spec_name` above.
    name_to_cls = {
        "scale_linear": spec.ScaleLinearDescr,
        "scale_range": spec.ScaleRangeDescr,
        "zero_mean_unit_variance": spec.ZeroMeanUnitVarianceDescr,
        "fixed_zero_mean_unit_variance": spec.FixedZeroMeanUnitVarianceDescr,
    }
    preprocessing = name_to_cls[spec_name](kwargs=spec_kwargs)

    return [preprocessing]
347
348
def _get_inout_descriptions(trainer, model, model_kwargs, input_tensors, output_tensors, min_shape, halo):
    """Create the input and output tensor descriptions for the exported model.

    Args:
        trainer: The trainer of the model; used to derive the preprocessing.
        model: The model that is exported. Only models from `torch_em.model.unet` are supported.
        model_kwargs: The keyword arguments used to initialize the model.
        input_tensors: Paths to the saved test inputs (exactly one).
        output_tensors: Paths to the saved test outputs (exactly one).
        min_shape: The minimal spatial input shape. If None it is derived from the model.
        halo: The halo per spatial axis, as list or tuple. If None, half of the step size is used.
            A halo shorter than the number of spatial axes is padded with leading zeros.

    Returns:
        The list of input descriptions, the list of output descriptions and
        the link to a matching notebook.

    Raises:
        NotImplementedError: If the model does not come from `torch_em.model.unet`.
        RuntimeError: If the model class from `torch_em.model.unet` is not known.
    """
    notebook_link = None
    module = str(model.__class__.__module__)
    name = str(model.__class__.__name__)

    # can derive tensor kwargs only for known torch_em models (only unet for now)
    if module != "torch_em.model.unet":
        raise NotImplementedError("Model export currently only works for torch_em.model.unet.")

    assert len(input_tensors) == len(output_tensors) == 1
    inc, outc = model_kwargs["in_channels"], model_kwargs["out_channels"]

    # In-model postprocessing can change the number of output channels.
    postprocessing = model_kwargs.get("postprocessing", None)
    if isinstance(postprocessing, str) and postprocessing.startswith("affinities_to_boundaries"):
        outc = 1
    elif isinstance(postprocessing, str) and postprocessing.startswith("affinities_with_foreground_to_boundaries"):
        outc = 2
    elif postprocessing is not None:
        warn(f"The model has the post-processing {postprocessing} which cannot be interpreted")

    # Derive the step size and the default minimal shape from the model architecture.
    if name in ("UNet2d", "UNet3d"):
        n_spatial = 2 if name == "UNet2d" else 3
        depth = model_kwargs["depth"]
        step = [2 ** depth] * n_spatial
        if min_shape is None:
            min_shape = [2 ** (depth + 1)] * n_spatial
        else:
            assert len(min_shape) == n_spatial
            min_shape = list(min_shape)
        if n_spatial == 2:
            notebook_link = "ilastik/torch-em-2d-unet-notebook"
        else:
            notebook_link = "ilastik/torch-em-3d-unet-notebook"

    elif name == "AnisotropicUNet":
        scale_factors = model_kwargs["scale_factors"]
        # The step per axis is the product of the scale factors of all levels along this axis.
        scale_prod = [
            int(np.prod([scale_factors[i][d] for i in range(len(scale_factors))]))
            for d in range(3)
        ]
        assert len(scale_prod) == 3
        step = scale_prod
        if min_shape is None:
            min_shape = [2 * sp for sp in scale_prod]
        else:
            min_shape = list(min_shape)
        notebook_link = "ilastik/torch-em-3d-unet-notebook"

    else:
        raise RuntimeError(f"Cannot derive tensor parameters for {module}.{name}")

    if halo is None:   # default halo = step // 2
        halo = [st // 2 for st in step]
    else:  # make sure the passed halo has the same length as step, by padding with zeros
        # Convert to a list first, so that a halo given as tuple can be concatenated.
        halo = list(halo)
        halo = [0] * (len(step) - len(halo)) + halo
    assert len(halo) == len(step), f"{len(halo)}, {len(step)}"

    # Create the input axis description.
    input_axes = [
        spec.BatchAxis(),
        spec.ChannelAxis(channel_names=[spec.Identifier(f"channel_{i}") for i in range(inc)]),
    ]
    input_ndim = np.load(input_tensors[0]).ndim
    assert input_ndim in (4, 5)
    axis_names = "zyx" if input_ndim == 5 else "yx"
    assert len(axis_names) == len(min_shape) == len(step)
    input_axes += [
        spec.SpaceInputAxis(id=spec.AxisId(ax_name), size=spec.ParameterizedSize(min=ax_min, step=ax_step))
        for ax_name, ax_min, ax_step in zip(axis_names, min_shape, step)
    ]

    # Create the rest of the input description.
    preprocessing = _get_preprocessing(trainer)
    input_description = [spec.InputTensorDescr(
        id=spec.TensorId("image"),
        axes=input_axes,
        test_tensor=spec.FileDescr(source=Path(input_tensors[0])),
        preprocessing=preprocessing,
    )]

    # Create the output axis description.
    output_axes = [
        spec.BatchAxis(),
        spec.ChannelAxis(channel_names=[spec.Identifier(f"out_channel_{i}") for i in range(outc)]),
    ]
    output_ndim = np.load(output_tensors[0]).ndim
    assert output_ndim in (4, 5)
    axis_names = "zyx" if output_ndim == 5 else "yx"
    assert len(axis_names) == len(halo)
    # The size of each spatial output axis refers to the same axis of the input tensor.
    output_axes += [
        spec.SpaceOutputAxisWithHalo(
            id=spec.AxisId(ax_name),
            size=spec.SizeReference(
                tensor_id=spec.TensorId("image"), axis_id=spec.AxisId(ax_name)
            ),
            halo=halo_val,
        ) for ax_name, halo_val in zip(axis_names, halo)
    ]

    # Create the rest of the output description.
    output_description = [spec.OutputTensorDescr(
        id=spec.TensorId("prediction"),
        axes=output_axes,
        test_tensor=spec.FileDescr(source=Path(output_tensors[0]))
    )]

    return input_description, output_description, notebook_link
464
465
def _validate_model(spec_path):
    # A model that was not written cannot be valid.
    if not os.path.exists(spec_path):
        return False

    model, normalizer, model_spec = import_bioimageio_model(spec_path, return_spec=True)
    root = model_spec.root

    def _test_tensor_paths(tensor_descriptions):
        # Resolve the test tensor of each description relative to the model root.
        return [os.path.join(root, descr.test_tensor.source.path) for descr in tensor_descriptions]

    # Load the test inputs (normalized like during export) and the expected outputs.
    inputs = [normalize_with_batch(np.load(path), normalizer) for path in _test_tensor_paths(model_spec.inputs)]
    expected = [np.load(path) for path in _test_tensor_paths(model_spec.outputs)]

    # Re-run the prediction with the imported model.
    with torch.no_grad():
        tensors = [torch.from_numpy(arr) for arr in inputs]
        outputs = model(*tensors)
        if torch.is_tensor(outputs):
            outputs = [outputs]
        outputs = [out.numpy() for out in outputs]

    # The model is valid if every prediction reproduces the stored test output.
    return all(np.allclose(out, exp) for out, exp in zip(outputs, expected))
490
491
492#
493# Model Export Functionality
494#
495
def _get_input_data(trainer):
    # Take the first batch of the validation loader and return its first entry as numpy array.
    first_batch = next(iter(trainer.val_loader))
    return first_batch[0].numpy()
500
501
def export_bioimageio_model(
    checkpoint: str,
    output_path: str,
    input_data: Optional[np.ndarray] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    authors: Optional[List[Dict[str, str]]] = None,
    tags: Optional[List[str]] = None,
    license: Optional[str] = None,
    documentation: Optional[str] = None,
    covers: Optional[str] = None,
    git_repo: Optional[str] = None,
    cite: Optional[List[Dict[str, str]]] = None,
    input_optional_parameters: bool = True,
    model_postprocessing: Optional[str] = None,
    for_deepimagej: bool = False,
    links: Optional[List[str]] = None,
    maintainers: Optional[List[Dict[str, str]]] = None,
    min_shape: Optional[Tuple[int, ...]] = None,
    halo: Optional[Tuple[int, ...]] = None,
    checkpoint_name: str = "best",
    config: Dict = {},
) -> bool:
    """Export model to bioimage.io model format.

    Args:
        checkpoint: The path to the checkpoint with the model to export.
        output_path: The output path for saving the model.
        input_data: The input data for creating model test data.
            If None, the first batch of the trainer's validation loader is used.
        name: The export name of the model.
        description: The description of the model.
        authors: The authors that created this model.
        tags: List of tags for this model.
        license: The license under which to publish the model.
        documentation: The documentation of the model.
        covers: The covers to show when displaying the model.
        git_repo: A github repository associated with this model.
        cite: References to cite for this model.
        input_optional_parameters: Whether to input optional parameters via the command line.
        model_postprocessing: Postprocessing to apply to the model outputs.
        for_deepimagej: Whether this model can be run in DeepImageJ.
            This argument is currently not used by the export.
        links: Linked modelzoo apps or software for this model.
            The list that is passed is not modified; "ilastik/ilastik" is always added to the exported links.
        maintainers: The maintainers of this model.
        min_shape: The minimal valid input shape for the model.
        halo: The halo to cut away from model outputs.
        checkpoint_name: The name of the model checkpoint to load for the export.
        config: Dictionary with additional configuration for this model.

    Returns:
        Whether the exported model was successfully validated.
    """
    # Load the trainer and model.
    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
    model, model_kwargs = _get_model(trainer, model_postprocessing)

    # Get input data from the trainer if it is not given.
    if input_data is None:
        input_data = _get_input_data(trainer)

    # All intermediate files are written to a temporary folder, which is removed
    # once the model package has been saved to the output path.
    with tempfile.TemporaryDirectory() as export_folder:

        # Create the weight description.
        weight_description = _create_weight_description(model, export_folder, model_kwargs)

        # Create the test input/output files.
        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
        # Get the descriptions for inputs, outputs and notebook links.
        input_description, output_description, notebook_link = _get_inout_descriptions(
            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
        )

        # Get the additional kwargs.
        kwargs = _get_kwargs(
            trainer, name, description,
            authors, tags, license, documentation,
            git_repo, cite, maintainers,
            export_folder, input_optional_parameters
        )

        # TODO double check the current link policy
        # The apps to link with this model, by default ilastik.
        # Work on a copy, so that the list passed by the caller is not modified.
        links = [] if links is None else list(links)
        links.append("ilastik/ilastik")
        # add the notebook link, if available
        if notebook_link is not None:
            links.append(notebook_link)
        kwargs.update({"links": links})

        if covers is not None:
            kwargs["covers"] = covers

        model_description = spec.ModelDescr(
            inputs=input_description,
            outputs=output_description,
            weights=weight_description,
            config=config,
            **kwargs,
        )

        save_bioimageio_package(model_description, output_path=output_path)

    # Validate the model.
    val_success = _validate_model(output_path)
    if val_success:
        print(f"The model was successfully exported to '{output_path}'.")
    else:
        warn(f"Validation of the bioimageio model exported to '{output_path}' has failed. " +
             "You can use this model, but it will probably yield incorrect results.")
    return val_success
612
613
614# TODO support bounding boxes
def _load_data(path, key):
    # With a key the data is read from inside a container file opened with `open_file`.
    if key is not None:
        return open_file(path, mode="r")[key][:]

    # Without a key the loader is chosen by the file extension:
    # numpy for .npy files, imageio for everything else.
    extension = os.path.splitext(path)[-1]
    if extension == ".npy":
        return np.load(path)
    return imageio.imread(path)
624
625
def main():
    """@private

    Command line entry point for exporting a torch_em checkpoint to the BioImage.IO model format.
    """
    # The text is passed as `description`; the first positional argument of ArgumentParser
    # is the program name shown in the usage line.
    parser = argparse.ArgumentParser(
        description=(
            "Export model trained with torch_em to the BioImage.IO model format. "
            "The exported model can be run in any tool supporting BioImage.IO. "
            "For details check out https://bioimage.io/#/."
        )
    )
    parser.add_argument("-p", "--path", required=True,
                        help="Path to the model checkpoint to export to the BioImage.IO model format.")
    parser.add_argument("-d", "--data", required=True,
                        help="Path to the test data to use for creating the exported model.")
    parser.add_argument("-f", "--export_folder", required=True,
                        help="Where to save the exported model. The exported model is stored as a zip in the folder.")
    parser.add_argument("-k", "--key",
                        help="The key for the test data. Required for container data formats like hdf5 or zarr.")
    parser.add_argument("-n", "--name", help="The name of the exported model.")

    args = parser.parse_args()
    # By default the model is named after the checkpoint.
    name = os.path.basename(args.path) if args.name is None else args.name
    export_bioimageio_model(args.path, args.export_folder, _load_data(args.data, args.key), name=name)
647
648
649#
650# model import functionality
651#
652
def _load_model(model_spec, device):
    # Build the network from the description of the pytorch state dict weights.
    state_dict_descr = model_spec.weights.pytorch_state_dict
    network = PytorchModelAdapter.get_network(state_dict_descr)

    # Locate the weight file. If the path does not exist as given, fall back to
    # the folder named "<root filename>.unzip".
    weight_path = state_dict_descr.source.path
    if not os.path.exists(weight_path):
        unzip_folder = f"{model_spec.root.filename}.unzip"
        assert os.path.exists(unzip_folder), unzip_folder
        weight_path = os.path.join(unzip_folder, weight_path)
    assert os.path.exists(weight_path), weight_path

    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects.
    weights = torch.load(weight_path, map_location=device, weights_only=False)
    network.load_state_dict(weights)
    network.eval()
    return network
666
667
def _load_normalizer(model_spec):
    """Translate the preprocessing of a bioimage.io model into a torch_em normalization function.

    Only the first input tensor is considered and, after dropping 'ensure_dtype',
    only its first preprocessing step is translated.

    Args:
        model_spec: The loaded bioimage.io model description.

    Returns:
        A `functools.partial` of a function from `torch_em.transform.raw`,
        or None if the model has no preprocessing besides 'ensure_dtype'.

    Raises:
        RuntimeError: If the preprocessing step is not supported.
    """
    inputs = model_spec.inputs[0]
    preprocessing = inputs.preprocessing

    # Filter out ensure dtype.
    preprocessing = [preproc for preproc in preprocessing if preproc.id != "ensure_dtype"]
    if len(preprocessing) == 0:
        return None

    # The axes are expected to hold batch and channel in addition to the spatial axes.
    ndim = len(inputs.axes) - 2
    shape = inputs.shape
    if hasattr(shape, "min"):
        shape = shape.min

    conf = preprocessing[0]
    name = conf.id
    spec_kwargs = conf.kwargs

    def _get_axis(axes):
        # Without an axes argument the normalization is not restricted to specific axes.
        if axes is None:
            return None

        label_to_id = {"channel": 0, "z": 1, "y": 2, "x": 3} if ndim == 3 else\
            {"channel": 0, "y": 1, "x": 2}
        axis = tuple(label_to_id[ax] for ax in axes)

        # Is the axis full? Then we don't need to specify it.
        if len(axis) == ndim + 1:
            return None

        # Drop the channel axis if we have only a single channel.
        # Because torch_em squeezes the channel axis in this case.
        if shape[1] == 1:
            axis = tuple(ax - 1 for ax in axis if ax > 0)
        return axis

    axis = _get_axis(spec_kwargs.get("axes", None))
    if name == "zero_mean_unit_variance":
        kwargs = {"axis": axis}
        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)

    elif name == "fixed_zero_mean_unit_variance":
        kwargs = {"axis": axis, "mean": spec_kwargs["mean"], "std": spec_kwargs["std"]}
        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)

    elif name == "scale_linear":
        min_ = -spec_kwargs["offset"]
        max_ = 1. / spec_kwargs["gain"]
        kwargs = {"axis": axis, "minval": min_, "maxval": max_}
        normalizer = functools.partial(torch_em.transform.raw.normalize, **kwargs)

    elif name == "scale_range":
        assert spec_kwargs["mode"] == "per_sample"  # Can't parse the other modes right now.
        lower, upper = spec_kwargs["min_percentile"], spec_kwargs["max_percentile"]
        # The full percentile range corresponds to a plain min-max normalization.
        if np.isclose(lower, 0.0) and np.isclose(upper, 100.0):
            normalizer = functools.partial(torch_em.transform.raw.normalize, axis=axis)
        else:
            kwargs = {"axis": axis, "lower": lower, "upper": upper}
            normalizer = functools.partial(torch_em.transform.raw.normalize_percentile, **kwargs)

    else:
        msg = f"torch_em does not support the use of the biomageio preprocessing function {name}."
        raise RuntimeError(msg)

    return normalizer
730
731
def import_bioimageio_model(spec_path, return_spec=False, device="cpu"):
    """@private
    """
    description = core.load_description(spec_path)

    # Re-create the network and the matching normalization function from the description.
    network = _load_model(description, device=device)
    normalizer = _load_normalizer(description)

    # The description itself is only handed back on request.
    if not return_spec:
        return network, normalizer
    return network, normalizer, description
744
745
746# TODO: the weight conversion needs to be updated once the
747# corresponding functionality in bioimageio.core is updated
748#
749# Weight Conversion
750#
751
752
def _convert_impl(spec_path, weight_name, converter, weight_type, **kwargs):
    # The converted weights are written to a scratch folder that only exists
    # until the updated model package has been saved.
    with tempfile.TemporaryDirectory() as scratch_dir:
        converted_path = os.path.join(scratch_dir, weight_name)
        description = core.load_description(spec_path)
        converted_weights = converter(description, converted_path, **kwargs)
        # TODO double check
        setattr(description.weights, weight_type, converted_weights)
        # The package at spec_path is overwritten with the extended description.
        save_bioimageio_package(description, output_path=spec_path)
761
762
def convert_to_onnx(spec_path, opset_version=12):
    """@private
    """
    # The conversion is currently not available. Previous implementation, kept for reference:
    #   converter = weight_converter.convert_weights_to_onnx
    #   _convert_impl(spec_path, "weights.onnx", converter, "onnx", opset_version=opset_version)
    #   return None
    raise NotImplementedError
770
771
def convert_to_torchscript(model_path):
    """@private
    """
    # The conversion is currently not available. Previous implementation, kept for reference:
    #   from bioimageio.core.weight_converter.torch._torchscript import convert_weights_to_torchscript
    #
    #   weight_name = "weights-torchscript.pt"
    #   _convert_impl(model_path, weight_name, convert_weights_to_torchscript, "torchscript")
    #
    #   # Check that we can load the converted weights.
    #   model_spec = core.load_description(model_path)
    #   weight_path = model_spec.weights.torchscript.weights
    #   try:
    #       torch.jit.load(weight_path)
    #       return None
    #   except Exception as e:
    #       return e
    raise NotImplementedError
790
791
def add_weight_formats(model_path, additional_formats):
    """@private

    Add weights in further formats to an exported model.

    Args:
        model_path: The path to the exported model.
        additional_formats: The weight formats to add. Supported are "onnx" and "torchscript".
            May be None or empty, in which case nothing is done.

    Raises:
        ValueError: If a weight format is not supported.
    """
    # Treat a missing list of formats like an empty one.
    for add_format in (additional_formats or []):

        if add_format == "onnx":
            ret = convert_to_onnx(model_path)
        elif add_format == "torchscript":
            ret = convert_to_torchscript(model_path)
        else:
            # Fail with a clear message instead of reading an unset or stale result below.
            raise ValueError(
                f"Unsupported weight format {add_format}, expected 'onnx' or 'torchscript'."
            )

        # The converters return None on success and the exception raised when
        # loading the converted weights otherwise.
        if ret is None:
            print("Successfully added", add_format, "weights")
        else:
            warn(f"Added {add_format} weights, but got exception {ret} when loading the weights again.")
806
807
def convert_main():
    """@private

    Command line entry point for converting the weights of an exported model.
    """
    # The text is passed as `description`; the first positional argument of ArgumentParser
    # is the program name shown in the usage line.
    parser = argparse.ArgumentParser(
        description="Convert weights from native pytorch format to onnx or torchscript"
    )
    parser.add_argument("-f", "--model_folder", required=True,
                        help="Path to the exported model for which the weights are converted.")
    # Invalid formats are rejected by argparse with a usage message.
    parser.add_argument("-w", "--weight_format", required=True, choices=("onnx", "torchscript"),
                        help="The weight format to convert to.")
    args = parser.parse_args()
    if args.weight_format == "onnx":
        convert_to_onnx(args.model_folder)
    else:
        convert_to_torchscript(args.model_folder)
821
822
823#
824# Misc Functionality
825#
826
def export_parser_helper():
    """@private
    """
    parser = argparse.ArgumentParser()

    # Mandatory options: checkpoint, input and output.
    required_options = (("-c", "--checkpoint"), ("-i", "--input"), ("-o", "--output"))
    for short_flag, long_flag in required_options:
        parser.add_argument(short_flag, long_flag, required=True)

    # Optional settings.
    parser.add_argument("-a", "--affs_to_bd", type=int, default=0)
    parser.add_argument("-f", "--additional_formats", nargs="+", type=str)
    return parser
837
838
def get_mws_config(offsets, config=None):
    """@private
    """
    # Without an existing config a fresh dictionary is created.
    if config is None:
        return {"mws": {"offsets": offsets}}

    # Otherwise the entry is added to (or replaced in) the given dictionary, which is returned.
    assert isinstance(config, dict)
    config["mws"] = {"offsets": offsets}
    return config
849
850
def get_shallow2deep_config(rf_path, config=None):
    """@private

    Create the 'shallow2deep' config entry from a pickled random forest.

    Args:
        rf_path: The path to the pickled random forest, or to a folder with '*.pkl' files.
            For a folder, the first pickle file in sorted order is used.
        config: An existing config dictionary to add the entry to. If None, a new dictionary is created.

    Returns:
        The config dictionary with the key "shallow2deep".

    Raises:
        FileNotFoundError: If `rf_path` is a folder that does not contain any '*.pkl' file.
    """
    if os.path.isdir(rf_path):
        # Sort the files so that the choice does not depend on the listing order of the file system.
        candidates = sorted(glob(os.path.join(rf_path, "*.pkl")))
        if not candidates:
            raise FileNotFoundError(f"Could not find a '*.pkl' file in the folder {rf_path}.")
        rf_path = candidates[0]
    assert os.path.exists(rf_path), rf_path

    # NOTE: unpickling can execute arbitrary code, only load random forests from trusted sources.
    with open(rf_path, "rb") as f:
        rf = pickle.load(f)

    shallow2deep_config = {"ndim": rf.feature_ndim, "features": rf.feature_config}
    if config is None:
        config = {"shallow2deep": shallow2deep_config}
    else:
        assert isinstance(config, dict)
        config["shallow2deep"] = shallow2deep_config
    return config
def export_bioimageio_model( checkpoint: str, output_path: str, input_data: Optional[numpy.ndarray] = None, name: Optional[str] = None, description: Optional[str] = None, authors: Optional[List[Dict[str, str]]] = None, tags: Optional[List[str]] = None, license: Optional[str] = None, documentation: Optional[str] = None, covers: Optional[str] = None, git_repo: Optional[str] = None, cite: Optional[List[Dict[str, str]]] = None, input_optional_parameters: bool = True, model_postprocessing: Optional[str] = None, for_deepimagej: bool = False, links: Optional[List[str]] = None, maintainers: Optional[List[Dict[str, str]]] = None, min_shape: Tuple[int, ...] = None, halo: Tuple[int, ...] = None, checkpoint_name: str = 'best', config: Dict = {}) -> bool:
def export_bioimageio_model(
    checkpoint: str,
    output_path: str,
    input_data: Optional[np.ndarray] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    authors: Optional[List[Dict[str, str]]] = None,
    tags: Optional[List[str]] = None,
    license: Optional[str] = None,
    documentation: Optional[str] = None,
    covers: Optional[str] = None,
    git_repo: Optional[str] = None,
    cite: Optional[List[Dict[str, str]]] = None,
    input_optional_parameters: bool = True,
    model_postprocessing: Optional[str] = None,
    for_deepimagej: bool = False,
    links: Optional[List[str]] = None,
    maintainers: Optional[List[Dict[str, str]]] = None,
    min_shape: Optional[Tuple[int, ...]] = None,
    halo: Optional[Tuple[int, ...]] = None,
    checkpoint_name: str = "best",
    config: Optional[Dict] = None,
) -> bool:
    """Export model to bioimage.io model format.

    Args:
        checkpoint: The path to the checkpoint with the model to export.
        output_path: The output path for saving the model.
        input_data: The input data for creating model test data.
            If None, the input data is derived from the trainer.
        name: The export name of the model.
        description: The description of the model.
        authors: The authors that created this model.
        tags: List of tags for this model.
        license: The license under which to publish the model.
        documentation: The documentation of the model.
        covers: The covers to show when displaying the model.
        git_repo: A GitHub repository associated with this model.
        cite: References to cite for this model.
        input_optional_parameters: Whether to input optional parameters via the command line.
        model_postprocessing: Postprocessing to apply to the model outputs.
        for_deepimagej: Whether this model can be run in DeepImageJ.
            This argument is currently not used by this function.
        links: Linked modelzoo apps or software for this model.
            "ilastik/ilastik" and, if available, the notebook link are always added.
            The passed list is not modified.
        maintainers: The maintainers of this model.
        min_shape: The minimal valid input shape for the model.
        halo: The halo to cut away from model outputs.
        checkpoint_name: The name of the model checkpoint to load for the export.
        config: Dictionary with additional configuration for this model.
            If None, an empty dictionary is used.

    Returns:
        Whether the exported model was successfully validated.
    """
    # Use a fresh dictionary per call instead of a mutable default shared between calls.
    if config is None:
        config = {}

    # Load the trainer and model.
    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
    model, model_kwargs = _get_model(trainer, model_postprocessing)

    # Get input data from the trainer if it is not given.
    if input_data is None:
        input_data = _get_input_data(trainer)

    # All intermediate files are written to a temporary folder,
    # which is removed after the model package has been saved.
    with tempfile.TemporaryDirectory() as export_folder:

        # Create the weight description.
        weight_description = _create_weight_description(model, export_folder, model_kwargs)

        # Create the test input/output files.
        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
        # Get the descriptions for inputs, outputs and notebook links.
        input_description, output_description, notebook_link = _get_inout_descriptions(
            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
        )

        # Get the additional kwargs.
        kwargs = _get_kwargs(
            trainer, name, description,
            authors, tags, license, documentation,
            git_repo, cite, maintainers,
            export_folder, input_optional_parameters
        )

        # TODO double check the current link policy
        # The apps to link with this model, by default ilastik.
        # Copy the links, so that the list passed by the caller is not extended in place.
        links = [] if links is None else list(links)
        links.append("ilastik/ilastik")
        # add the notebook link, if available
        if notebook_link is not None:
            links.append(notebook_link)
        kwargs.update({"links": links})

        if covers is not None:
            kwargs["covers"] = covers

        model_description = spec.ModelDescr(
            inputs=input_description,
            outputs=output_description,
            weights=weight_description,
            config=config,
            **kwargs,
        )

        save_bioimageio_package(model_description, output_path=output_path)

    # Validate the model.
    val_success = _validate_model(output_path)
    if val_success:
        print(f"The model was successfully exported to '{output_path}'.")
    else:
        warn(
            f"Validation of the bioimageio model exported to '{output_path}' has failed. "
            "You can use this model, but it will probably yield incorrect results."
        )
    return val_success

Export model to bioimage.io model format.

Arguments:
  • checkpoint: The path to the checkpoint with the model to export.
  • output_path: The output path for saving the model.
  • input_data: The input data for creating model test data.
  • name: The export name of the model.
  • description: The description of the model.
  • authors: The authors that created this model.
  • tags: List of tags for this model.
  • license: The license under which to publish the model.
  • documentation: The documentation of the model.
  • covers: The covers to show when displaying the model.
  • git_repo: A GitHub repository associated with this model.
  • cite: References to cite for this model.
  • input_optional_parameters: Whether to input optional parameters via the command line.
  • model_postprocessing: Postprocessing to apply to the model outputs.
  • for_deepimagej: Whether this model can be run in DeepImageJ.
  • links: Linked modelzoo apps or software for this model.
  • maintainers: The maintainers of this model.
  • min_shape: The minimal valid input shape for the model.
  • halo: The halo to cut away from model outputs.
  • checkpoint_name: The name of the model checkpoint to load for the export.
  • config: Dictionary with additional configuration for this model.
Returns:

Whether the exported model was successfully validated.