torch_em.util.modelzoo

  1import argparse
  2import functools
  3import json
  4import os
  5import pickle
  6import subprocess
  7import tempfile
  8
  9from glob import glob
 10from pathlib import Path
 11from warnings import warn
 12
 13import imageio
 14import numpy as np
 15import torch
 16import torch_em
 17
 18import bioimageio.core as core
 19import bioimageio.spec.model.v0_5 as spec
 20from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
 21from bioimageio.spec import save_bioimageio_package
 22
 23from elf.io import open_file
 24from .util import get_trainer, get_normalizer
 25
 26
 27#
 28# General Purpose Functionality
 29#
 30
 31
 32def normalize_with_batch(data, normalizer):
 33    if normalizer is None:
 34        return data
 35    normalized = np.concatenate(
 36        [normalizer(da)[None] for da in data],
 37        axis=0
 38    )
 39    return normalized
 40
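# Example: a minimal sketch of applying normalize_with_batch to a batch of
# single-channel images, with a normalizer from torch_em (the data is synthetic,
# for illustration only):
#
#     import numpy as np
#     import torch_em
#     batch = np.random.rand(4, 1, 64, 64).astype("float32")
#     normalized = normalize_with_batch(batch, torch_em.transform.raw.standardize)
#     assert normalized.shape == batch.shape  # the normalizer is applied per sample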
 41
 42#
 43# Utility Functions for Model Export.
 44#
 45
 46
 47def get_default_citations(model=None, model_output=None):
 48    citations = [
 49        {"text": "training library", "doi": "10.5281/zenodo.5108853"}
 50    ]
 51
 52    # try to derive the correct network citation from the model class
 53    if model is not None:
 54        if isinstance(model, str):
 55            model_name = model
 56        else:
 57            model_name = str(model.__class__.__name__)
 58
 59        if model_name.lower() in ("unet2d", "unet_2d", "unet"):
 60            citations.append(
 61                {"text": "architecture", "doi": "10.1007/978-3-319-24574-4_28"}
 62            )
 63        elif model_name.lower() in ("unet3d", "unet_3d", "anisotropicunet"):
 64            citations.append(
 65                {"text": "architecture", "doi": "10.1007/978-3-319-46723-8_49"}
 66            )
 67        else:
 68            warn("No citation for architecture {model_name} found.")
 69
 70    # try to derive the correct segmentation algo citation from the model output type
 71    if model_output is not None:
 72        msg = f"No segmentation algorithm for output {model_output} known. 'affinities' and 'boundaries' are supported."
 73        if model_output == "affinities":
 74            citations.append(
 75                {"text": "segmentation algorithm", "doi": "10.1109/TPAMI.2020.2980827"}
 76            )
 77        elif model_output == "boundaries":
 78            citations.append(
 79                {"text": "segmentation algorithm", "doi": "10.1038/nmeth.4151"}
 80            )
 81        else:
 82            warn(msg)
 83
 84    return citations
 85
 86
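# Example: assembling the citation metadata for a 2d U-Net that predicts
# boundaries. The result is a list of {"text": ..., "doi": ...} dictionaries
# that can be passed to export_bioimageio_model via the cite argument:
#
#     cite = get_default_citations(model="UNet2d", model_output="boundaries")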
 87def _get_model(trainer, postprocessing):
 88    model = trainer.model
 89    model.eval()
 90    model_kwargs = model.init_kwargs
 91    # clear the kwargs of non-builtins (i.e. values that are classes), since they cannot be serialized
 92    # TODO warn if we strip any non-standard arguments
 93    model_kwargs = {k: v for k, v in model_kwargs.items() if not isinstance(v, type)}
 94
 95    # set the in-model postprocessing if given
 96    if postprocessing is not None:
 97        assert "postprocessing" in model_kwargs
 98        model_kwargs["postprocessing"] = postprocessing
 99        state = model.state_dict()
100        model = model.__class__(**model_kwargs)
101        model.load_state_dict(state)
102        model.eval()
103
104    return model, model_kwargs
105
106
107def _pad(input_data, trainer):
108    try:
109        if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset):
110            ndim = trainer.train_loader.dataset.dataset.ndim
111        else:
112            ndim = trainer.train_loader.dataset.ndim
113    except AttributeError:
114        ndim = trainer.train_loader.dataset.datasets[0].ndim
115    target_dims = ndim + 2
116    for _ in range(target_dims - input_data.ndim):
117        input_data = np.expand_dims(input_data, axis=0)
118    return input_data
119
120
121def _write_data(input_data, model, trainer, export_folder):
122    # if input_data is None:
123    #     gen = SampleGenerator(trainer, 1, False, 1)
124    #     input_data = next(gen)
125    if isinstance(input_data, np.ndarray):
126        input_data = [input_data]
127
128    # normalize the input data if we have a normalization function
129    normalizer = get_normalizer(trainer)
130
131    # pad to 4d/5d and normalize the input data
132    # NOTE we have to save the padded data, but without normalization
133    test_inputs = [_pad(input_, trainer) for input_ in input_data]
134    normalized = [normalize_with_batch(input_, normalizer) for input_ in test_inputs]
135
136    # run prediction
137    with torch.no_grad():
138        test_tensors = [torch.from_numpy(norm).to(trainer.device) for norm in normalized]
139        test_outputs = model(*test_tensors)
140        if torch.is_tensor(test_outputs):
141            test_outputs = [test_outputs]
142        test_outputs = [out.cpu().numpy() for out in test_outputs]
143
144    # save the input / output
145    test_in_paths, test_out_paths = [], []
146    for i, input_ in enumerate(test_inputs):
147        test_in_path = os.path.join(export_folder, f"test_input_{i}.npy")
148        np.save(test_in_path, input_)
149        test_in_paths.append(test_in_path)
150    for i, out in enumerate(test_outputs):
151        test_out_path = os.path.join(export_folder, f"test_output_{i}.npy")
152        np.save(test_out_path, out)
153        test_out_paths.append(test_out_path)
154    return test_in_paths, test_out_paths
155
156
157def _create_weight_description(model, export_folder, model_kwargs):
158    module = str(model.__class__.__module__)
159    cls_name = str(model.__class__.__name__)
160
161    if module == "torch_em.model.unet":
162        source_path = os.path.join(os.path.split(__file__)[0], "../model/unet.py")
163        architecture = spec.ArchitectureFromFileDescr(
164            source=Path(source_path),
165            callable=cls_name,
166            kwargs=model_kwargs,
167        )
168    else:
169        architecture = spec.ArchitectureFromLibraryDescr(
170            import_from=module,
171            callable=cls_name,
172            kwargs=model_kwargs,
173        )
174
175    checkpoint_path = os.path.join(export_folder, "state_dict.pt")
176    torch.save(model.state_dict(), checkpoint_path)
177
178    weight_description = spec.WeightsDescr(
179        pytorch_state_dict=spec.PytorchStateDictWeightsDescr(
180            source=Path(checkpoint_path),
181            architecture=architecture,
182            pytorch_version=spec.Version(torch.__version__),
183        )
184    )
185    return weight_description
186
187
188def _get_kwargs(
189    trainer, name, description, authors, tags, license, documentation,
190    git_repo, cite, maintainers, export_folder, input_optional_parameters
191):
192    if input_optional_parameters:
193        print("Enter values for the optional parameters.")
194        print("If the default value in [] is satisfactory, press enter without additional input.")
195        print("Please enter lists using json syntax.")
196
197    def _get_kwarg(kwarg_name, val, default, is_list=False, fname=None):
198        # if we don"t have a value, we either ask user for input (offering the default)
199        # or just use the default if input_optional_parameters is False
200        if val is None and input_optional_parameters:
201            default_val = default()
202            choice = input(f"{kwarg_name} [{default_val}]: ")
203            val = choice if choice else default_val
204        elif val is None:
205            val = default()
206
207        if fname is not None:
208            save_path = os.path.join(export_folder, fname)
209            with open(save_path, "w") as f:
210                f.write(val)
211            return save_path
212
213        if is_list and isinstance(val, str):
214            val = val.replace("'", '"')  # enable single quotes by replacing them with double quotes for json
215            val = json.loads(val)
216        if is_list:
217            assert isinstance(val, (list, tuple)), type(val)
218        return val
219
220    def _default_authors():
221        # first try to derive the author name from git
222        try:
223            call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
224            author = call_res.stdout.decode("utf8").rstrip("\n")
225            author = author if author else None  # in case there was no error, but output is empty
226        except Exception:
227            author = None
228
229        # otherwise use the username
230        if author is None:
231            author = os.uname()[1]
232
233        return [{"name": author}]
234
235    def _default_repo():
236        # try to derive the repository from the git remote
237        try:
238            call_res = subprocess.run(["git", "remote", "-v"], capture_output=True)
239            repo = call_res.stdout.decode("utf8").split("\n")[0].split()[1]
240            repo = repo if repo else None
241        except Exception:
242            repo = None
243        return repo
244
245    def _default_maintainers():
246        # first try to derive the maintainer name from git
247        try:
248            call_res = subprocess.run(["git", "config", "user.name"], capture_output=True)
249            maintainer = call_res.stdout.decode("utf8").rstrip("\n")
250            maintainer = maintainer if maintainer else None  # in case there was no error, but output is empty
251        except Exception:
252            maintainer = None
253
254        # otherwise use the username
255        if maintainer is None:
256            maintainer = os.uname()[1]
257
258        return [{"github_user": maintainer}]
259
260    # TODO derive better default values:
261    # - description: derive something from trainer.ndim, trainer.loss, trainer.model, ...
262    # - tags: derive something from trainer.ndim, trainer.loss, trainer.model, ...
263    # - documentation: derive something from trainer.ndim, trainer.loss, trainer.model, ...
264    kwargs = {
265        "name": _get_kwarg("name", name, lambda: trainer.name),
266        "description": _get_kwarg("description", description, lambda: trainer.name),
267        "authors": _get_kwarg("authors", authors, _default_authors, is_list=True),
268        "tags": _get_kwarg("tags", tags, lambda: [trainer.name], is_list=True),
269        "license": _get_kwarg("license", license, lambda: "MIT"),
270        "documentation": _get_kwarg(
271            "documentation", documentation, lambda: trainer.name, fname="documentation.md"
272        ),
273        "git_repo": _get_kwarg("git_repo", git_repo, _default_repo),
274        "cite": _get_kwarg("cite", cite, get_default_citations),
275        "maintainers": _get_kwarg("maintainers", maintainers, _default_maintainers, is_list=True),
276    }
277
278    return kwargs
279
280
281def _get_preprocessing(trainer):
282    try:
283        if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset):
284            ndim = trainer.train_loader.dataset.dataset.ndim
285        else:
286            ndim = trainer.train_loader.dataset.ndim
287    except AttributeError:
288        ndim = trainer.train_loader.dataset.datasets[0].ndim
289    normalizer = get_normalizer(trainer)
290
291    if isinstance(normalizer, functools.partial):
292        kwargs = normalizer.keywords
293        normalizer = normalizer.func
294    else:
295        kwargs = {}
296
297    def _get_axes(axis):
298        all_axes = ["channel", "y", "x"] if ndim == 2 else ["channel", "z", "y", "x"]
299        if axis is None:
300            axes = all_axes
301        else:
302            axes = [all_axes[i] for i in axis]
303        return axes
304
305    name = f"{normalizer.__module__}.{normalizer.__name__}"
306    axes = _get_axes(kwargs.get("axis", None))
307
308    if name == "torch_em.transform.raw.normalize":
309
310        min_, max_ = kwargs.get("minval", None), kwargs.get("maxval", None)
311        assert (min_ is None) == (max_ is None)
312
313        if min_ is None:
314            spec_name = "scale_range",
315            spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": 0.0, "max_percentile": 100.0}
316        else:
317            spec_name = "scale_linear"
318            spec_kwargs = {"gain": 1.0 / max_, "offset": -min_, "axes": axes}
319
320    elif name == "torch_em.transform.raw.standardize":
321        spec_kwargs = {"axes": axes}
322        mean, std = kwargs.get("mean", None), kwargs.get("std", None)
323        if (mean is None) and (std is None):
324            spec_name = "zero_mean_unit_variance"
325        else:
326            spec_name = "fixed_zero_mean_unit_varaince"
327            spec_kwargs.update({"mean": mean, "std": std})
328
329    elif name == "torch_em.transform.raw.normalize_percentile":
330        lower, upper = kwargs.get("lower", 1.0), kwargs.get("upper", 99.0)
331        spec_name = "scale_range"
332        spec_kwargs = {"mode": "per_sample", "axes": axes, "min_percentile": lower, "max_percentile": upper}
333
334    else:
335        warn(f"Could not parse the normalization function {name}, 'preprocessing' field will be empty.")
336        return None
337
338    name_to_cls = {
339        "scale_linear": spec.ScaleLinearDescr,
340        "scale_rage": spec.ScaleRangeDescr,
341        "zero_mean_unit_variance": spec.ZeroMeanUnitVarianceDescr,
342        "fixed_zero_mean_unit_variance": spec.FixedZeroMeanUnitVarianceDescr,
343    }
344    preprocessing = name_to_cls[spec_name](kwargs=spec_kwargs)
345
346    return [preprocessing]
347
348
349def _get_inout_descriptions(trainer, model, model_kwargs, input_tensors, output_tensors, min_shape, halo):
350
351    notebook_link = None
352    module = str(model.__class__.__module__)
353    name = str(model.__class__.__name__)
354
355    # can derive tensor kwargs only for known torch_em models (only unet for now)
356    if module == "torch_em.model.unet":
357        assert len(input_tensors) == len(output_tensors) == 1
358        inc, outc = model_kwargs["in_channels"], model_kwargs["out_channels"]
359
360        postprocessing = model_kwargs.get("postprocessing", None)
361        if isinstance(postprocessing, str) and postprocessing.startswith("affinities_to_boundaries"):
362            outc = 1
363        elif isinstance(postprocessing, str) and postprocessing.startswith("affinities_with_foreground_to_boundaries"):
364            outc = 2
365        elif postprocessing is not None:
366            warn(f"The model has the post-processing {postprocessing} which cannot be interpreted")
367
368        if name == "UNet2d":
369            depth = model_kwargs["depth"]
370            step = [2 ** depth] * 2
371            if min_shape is None:
372                min_shape = [2 ** (depth + 1)] * 2
373            else:
374                assert len(min_shape) == 2
375                min_shape = list(min_shape)
376            notebook_link = "ilastik/torch-em-2d-unet-notebook"
377
378        elif name == "UNet3d":
379            depth = model_kwargs["depth"]
380            step = [2 ** depth] * 3
381            if min_shape is None:
382                min_shape = [2 ** (depth + 1)] * 3
383            else:
384                assert len(min_shape) == 3
385                min_shape = list(min_shape)
386            notebook_link = "ilastik/torch-em-3d-unet-notebook"
387
388        elif name == "AnisotropicUNet":
389            scale_factors = model_kwargs["scale_factors"]
390            scale_prod = [
391                int(np.prod([scale_factors[i][d] for i in range(len(scale_factors))]))
392                for d in range(3)
393            ]
394            assert len(scale_prod) == 3
395            step = scale_prod
396            if min_shape is None:
397                min_shape = [2 * sp for sp in scale_prod]
398            else:
399                min_shape = list(min_shape)
400            notebook_link = "ilastik/torch-em-3d-unet-notebook"
401
402        else:
403            raise RuntimeError(f"Cannot derive tensor parameters for {module}.{name}")
404
405        if halo is None:   # default halo = step // 2
406            halo = [st // 2 for st in step]
407        else:  # make sure the passed halo has the same length as step, by padding with zeros
408            halo = [0] * (len(step) - len(halo)) + halo
409        assert len(halo) == len(step), f"{len(halo)}, {len(step)}"
410
411        # Create the input axis description.
412        input_axes = [
413            spec.BatchAxis(),
414            spec.ChannelAxis(channel_names=[spec.Identifier(f"channel_{i}") for i in range(inc)]),
415        ]
416        input_ndim = np.load(input_tensors[0]).ndim
417        assert input_ndim in (4, 5)
418        axis_names = "zyx" if input_ndim == 5 else "yx"
419        assert len(axis_names) == len(min_shape) == len(step)
420        input_axes += [
421            spec.SpaceInputAxis(id=spec.AxisId(ax_name), size=spec.ParameterizedSize(min=ax_min, step=ax_step))
422            for ax_name, ax_min, ax_step in zip(axis_names, min_shape, step)
423        ]
424
425        # Create the rest of the input description.
426        preprocessing = _get_preprocessing(trainer)
427        input_description = [spec.InputTensorDescr(
428            id=spec.TensorId("image"),
429            axes=input_axes,
430            test_tensor=spec.FileDescr(source=Path(input_tensors[0])),
431            preprocessing=preprocessing,
432        )]
433
434        # Create the output axis description.
435        output_axes = [
436            spec.BatchAxis(),
437            spec.ChannelAxis(channel_names=[spec.Identifier(f"out_channel_{i}") for i in range(outc)]),
438        ]
439        output_ndim = np.load(output_tensors[0]).ndim
440        assert output_ndim in (4, 5)
441        axis_names = "zyx" if output_ndim == 5 else "yx"
442        assert len(axis_names) == len(halo)
443        output_axes += [
444            spec.SpaceOutputAxisWithHalo(
445                id=spec.AxisId(ax_name),
446                size=spec.SizeReference(
447                    tensor_id=spec.TensorId("image"), axis_id=spec.AxisId(ax_name)
448                ),
449                halo=halo_val,
450            ) for ax_name, halo_val in zip(axis_names, halo)
451        ]
452
453        # Create the rest of the output description.
454        output_description = [spec.OutputTensorDescr(
455            id=spec.TensorId("prediction"),
456            axes=output_axes,
457            test_tensor=spec.FileDescr(source=Path(output_tensors[0]))
458        )]
459
460    else:
461        raise NotImplementedError("Model export currently only works for torch_em.model.unet.")
462
463    return input_description, output_description, notebook_link
464
465
466def _validate_model(spec_path):
467    if not os.path.exists(spec_path):
468        return False
469
470    model, normalizer, model_spec = import_bioimageio_model(spec_path, return_spec=True)
471    root = model_spec.root
472
473    input_paths = [os.path.join(root, ipt.test_tensor.source.path) for ipt in model_spec.inputs]
474    inputs = [normalize_with_batch(np.load(ipt), normalizer) for ipt in input_paths]
475
476    expected_paths = [os.path.join(root, opt.test_tensor.source.path) for opt in model_spec.outputs]
477    expected = [np.load(opt) for opt in expected_paths]
478
479    with torch.no_grad():
480        inputs = [torch.from_numpy(input_) for input_ in inputs]
481        outputs = model(*inputs)
482        if torch.is_tensor(outputs):
483            outputs = [outputs]
484        outputs = [out.numpy() for out in outputs]
485
486    for out, exp in zip(outputs, expected):
487        if not np.allclose(out, exp):
488            return False
489    return True
490
491
492#
493# Model Export Functionality
494#
495
496def _get_input_data(trainer):
497    loader = trainer.val_loader
498    x = next(iter(loader))[0].numpy()
499    return x
500
501
502# TODO config: training details derived from loss and optimizer, custom params, e.g. offsets for mws
503def export_bioimageio_model(
504    checkpoint,
505    output_path,
506    input_data=None,
507    name=None,
508    description=None,
509    authors=None,
510    tags=None,
511    license=None,
512    documentation=None,
513    covers=None,
514    git_repo=None,
515    cite=None,
516    input_optional_parameters=True,
517    model_postprocessing=None,
518    for_deepimagej=False,
519    links=None,
520    maintainers=None,
521    min_shape=None,
522    halo=None,
523    checkpoint_name="best",
524    training_data=None,
525    config={}
526):
527    """Export model to bioimage.io model format.
528    """
529    # Load the trainer and model.
530    trainer = get_trainer(checkpoint, name=checkpoint_name, device="cpu")
531    model, model_kwargs = _get_model(trainer, model_postprocessing)
532
533    # Get input data from the trainer if it is not given.
534    if input_data is None:
535        input_data = _get_input_data(trainer)
536
537    with tempfile.TemporaryDirectory() as export_folder:
538
539        # Create the weight description.
540        weight_description = _create_weight_description(model, export_folder, model_kwargs)
541
542        # Create the test input/output files.
543        test_in_paths, test_out_paths = _write_data(input_data, model, trainer, export_folder)
544        # Get the descriptions for inputs, outputs and notebook links.
545        input_description, output_description, notebook_link = _get_inout_descriptions(
546            trainer, model, model_kwargs, test_in_paths, test_out_paths, min_shape, halo
547        )
548
549        # Get the additional kwargs.
550        kwargs = _get_kwargs(
551            trainer, name, description,
552            authors, tags, license, documentation,
553            git_repo, cite, maintainers,
554            export_folder, input_optional_parameters
555        )
556
557        # TODO double check the current link policy
558        # The apps to link with this model, by default ilastik.
559        if links is None:
560            links = []
561        links.append("ilastik/ilastik")
562        # add the notebook link, if available
563        if notebook_link is not None:
564            links.append(notebook_link)
565        kwargs.update({"links": links})
566
567        if covers is not None:
568            kwargs["covers"] = covers
569
570        model_description = spec.ModelDescr(
571            inputs=input_description,
572            outputs=output_description,
573            weights=weight_description,
574            config=config,
575            **kwargs,
576        )
577
578        save_bioimageio_package(model_description, output_path=output_path)
579
580    # Validate the model.
581    val_success = _validate_model(output_path)
582    if val_success:
583        print(f"The model was successfully exported to '{output_path}'.")
584    else:
585        warn(f"Validation of the bioimageio model exported to '{output_path}' has failed. " +
586             "You can use this model, but it will probably yield incorrect results.")
587    return val_success
588
589
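# Example: a sketch of a typical export call. The checkpoint path and all
# metadata values below are hypothetical placeholders:
#
#     export_bioimageio_model(
#         checkpoint="./checkpoints/my-boundary-model",
#         output_path="./my-boundary-model.zip",
#         name="my-boundary-model",
#         description="U-Net for boundary segmentation in EM.",
#         authors=[{"name": "Jane Doe"}],
#         tags=["unet", "segmentation", "electron-microscopy"],
#         license="MIT",
#         cite=get_default_citations(model="UNet2d", model_output="boundaries"),
#         input_optional_parameters=False,  # don't prompt for the missing optional metadata
#     )
#
# If input_data is not given, a batch from the trainer's validation loader is
# used as the test data for the exported model.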
590# TODO support bounding boxes
591def _load_data(path, key):
592    if key is None:
593        ext = os.path.splitext(path)[-1]
594        if ext == ".npy":
595            return np.load(path)
596        else:
597            return imageio.imread(path)
598    else:
599        return open_file(path, mode="r")[key][:]
600
601
602def main():
603    import argparse
604    parser = argparse.ArgumentParser(
605        description="Export a model trained with torch_em to the BioImage.IO model format. "
606        "The exported model can be run in any tool supporting BioImage.IO. "
607        "For details check out https://bioimage.io/#/."
608    )
609    parser.add_argument("-p", "--path", required=True,
610                        help="Path to the model checkpoint to export to the BioImage.IO model format.")
611    parser.add_argument("-d", "--data", required=True,
612                        help="Path to the test data to use for creating the exported model.")
613    parser.add_argument("-f", "--export_folder", required=True,
614                        help="Where to save the exported model. The exported model is stored as a zip in the folder.")
615    parser.add_argument("-k", "--key",
616                        help="The key for the test data. Required for container data formats like hdf5 or zarr.")
617    parser.add_argument("-n", "--name", help="The name of the exported model.")
618
619    args = parser.parse_args()
620    name = os.path.basename(args.path) if args.name is None else args.name
621    export_bioimageio_model(args.path, args.export_folder, _load_data(args.data, args.key), name=name)
622
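# Example: a sketch of the corresponding command line call (assuming the
# console script is installed under the name torch_em.export_bioimageio_model;
# the paths and the data key are hypothetical):
#
#     $ torch_em.export_bioimageio_model -p ./checkpoints/my-model \
#         -d ./test_data.h5 -k raw -f ./exported_model -n my-model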
623
624#
625# model import functionality
626#
627
628def _load_model(model_spec, device):
629    weight_spec = model_spec.weights.pytorch_state_dict
630    model = PytorchModelAdapter.get_network(weight_spec)
631    weight_file = weight_spec.source.path
632    if not os.path.exists(weight_file):
633        weight_file = os.path.join(model_spec.root, weight_file)
634    assert os.path.exists(weight_file), weight_file
635    state = torch.load(weight_file, map_location=device)
636    model.load_state_dict(state)
637    model.eval()
638    return model
639
640
641def _load_normalizer(model_spec):
642    inputs = model_spec.inputs[0]
643    preprocessing = inputs.preprocessing
644
645    # Filter out ensure dtype.
646    preprocessing = [preproc for preproc in preprocessing if preproc.id != "ensure_dtype"]
647    if len(preprocessing) == 0:
648        return None
649
650    ndim = len(inputs.axes) - 2
651    shape = inputs.shape
652    if hasattr(shape, "min"):
653        shape = shape.min
654
655    conf = preprocessing[0]
656    name = conf.id
657    spec_kwargs = conf.kwargs
658
659    def _get_axis(axes):
660        label_to_id = {"channel": 0, "z": 1, "y": 2, "x": 3} if ndim == 3 else\
661            {"channel": 0, "y": 1, "x": 2}
662        axis = tuple(label_to_id[ax] for ax in axes)
663
664        # Is the axis full? Then we don't need to specify it.
665        if len(axis) == ndim + 1:
666            return None
667
668        # Drop the channel axis if we have only a single channel.
669        # Because torch_em squeezes the channel axis in this case.
670        if shape[1] == 1:
671            axis = tuple(ax - 1 for ax in axis if ax > 0)
672        return axis
673
674    axis = _get_axis(spec_kwargs.get("axes", None))
675    if name == "zero_mean_unit_variance":
676        kwargs = {"axis": axis}
677        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)
678
679    elif name == "fixed_zero_mean_unit_variance":
680        kwargs = {"axis": axis, "mean": spec_kwargs["mean"], "std": spec_kwargs["std"]}
681        normalizer = functools.partial(torch_em.transform.raw.standardize, **kwargs)
682
683    elif name == "scale_linear":
684        min_ = -spec_kwargs["offset"]
685        max_ = 1. / spec_kwargs["gain"]
686        kwargs = {"axis": axis, "minval": min_, "maxval": max_}
687        normalizer = functools.partial(torch_em.transform.raw.normalize, **kwargs)
688
689    elif name == "scale_range":
690        assert spec_kwargs["mode"] == "per_sample"  # Can't parse the other modes right now.
691        lower, upper = spec_kwargs["min_percentile"], spec_kwargs["max_percentile"]
692        if np.isclose(lower, 0.0) and np.isclose(upper, 100.0):
693            normalizer = functools.partial(torch_em.transform.raw.normalize, axis=axis)
694        else:
695            kwargs = {"axis": axis, "lower": lower, "upper": upper}
696            normalizer = functools.partial(torch_em.transform.raw.normalize_percentile, **kwargs)
697
698    else:
699        msg = f"torch_em does not support the use of the biomageio preprocessing function {name}."
700        raise RuntimeError(msg)
701
702    return normalizer
703
704
705def import_bioimageio_model(spec_path, return_spec=False, device="cpu"):
706    model_spec = core.load_description(spec_path)
707
708    model = _load_model(model_spec, device=device)
709    normalizer = _load_normalizer(model_spec)
710
711    if return_spec:
712        return model, normalizer, model_spec
713    else:
714        return model, normalizer
715
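# Example: a sketch of importing an exported model and running prediction with
# the normalizer parsed from its preprocessing spec (the model path is a
# hypothetical placeholder, the input data is synthetic):
#
#     import numpy as np
#     import torch
#     model, normalizer = import_bioimageio_model("./my-boundary-model.zip")
#     image = np.random.rand(1, 1, 256, 256).astype("float32")
#     with torch.no_grad():
#         input_ = normalize_with_batch(image, normalizer)
#         prediction = model(torch.from_numpy(input_))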
716
717# TODO
718def import_trainer_from_bioimageio_model(spec_path):
719    pass
720
721
722# TODO: the weight conversion needs to be updated once the
723# corresponding functionality in bioimageio.core is updated
724#
725# Weight Conversion
726#
727
728
729def _convert_impl(spec_path, weight_name, converter, weight_type, **kwargs):
730    with tempfile.TemporaryDirectory() as tmp_dir:
731        weight_path = os.path.join(tmp_dir, weight_name)
732        model_spec = core.load_description(spec_path)
733        weight_descr = converter(model_spec, weight_path, **kwargs)
734        # TODO double check
735        setattr(model_spec.weights, weight_type, weight_descr)
736        save_bioimageio_package(model_spec, output_path=spec_path)
737
738
739def convert_to_onnx(spec_path, opset_version=12):
740    raise NotImplementedError
741    # converter = weight_converter.convert_weights_to_onnx
742    # _convert_impl(spec_path, "weights.onnx", converter, "onnx", opset_version=opset_version)
743    # return None
744
745
746def convert_to_torchscript(model_path):
747    raise NotImplementedError
748    # from bioimageio.core.weight_converter.torch._torchscript import convert_weights_to_torchscript
749
750    # weight_name = "weights-torchscript.pt"
751    # breakpoint()
752    # _convert_impl(model_path, weight_name, convert_weights_to_torchscript, "torchscript")
753
754    # # Check that we can load the converted weights.
755    # model_spec = core.load_description(model_path)
756    # weight_path = model_spec.weights.torchscript.weights
757    # try:
758    #     torch.jit.load(weight_path)
759    #     return None
760    # except Exception as e:
761    #     return e
762
763
764def add_weight_formats(model_path, additional_formats):
765    for add_format in additional_formats:
766        if add_format == "onnx":
767            ret = convert_to_onnx(model_path)
768        elif add_format == "torchscript":
769            ret = convert_to_torchscript(model_path)
770        else:  # fail for unknown formats, so that ret is always defined below
771            raise ValueError(f"Unknown weight format {add_format}.")
772        if ret is None:
773            print("Successfully added", add_format, "weights")
774        else:
775            warn(f"Added {add_format} weights, but got exception {ret} when loading the weights again.")
776
777
778def convert_main():
779    import argparse
780    parser = argparse.ArgumentParser(
781        description="Convert the weights of a bioimageio model from the native pytorch format to onnx or torchscript."
782    )
783    parser.add_argument("-f", "--model_folder", required=True,
784                        help="")
785    parser.add_argument("-w", "--weight_format", required=True,
786                        help="")
787    args = parser.parse_args()
788    weight_format = args.weight_format
789    assert weight_format in ("onnx", "torchscript")
790    if weight_format == "onnx":
791        convert_to_onnx(args.model_folder)
792    else:
793        convert_to_torchscript(args.model_folder)
794
795
796#
797# Misc Functionality
798#
799
800def export_parser_helper():
801    parser = argparse.ArgumentParser()
802    parser.add_argument("-c", "--checkpoint", required=True)
803    parser.add_argument("-i", "--input", required=True)
804    parser.add_argument("-o", "--output", required=True)
805    parser.add_argument("-a", "--affs_to_bd", default=0, type=int)
806    parser.add_argument("-f", "--additional_formats", type=str, nargs="+")
807    return parser
808
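# Example: a sketch of using the parser helper in a small export script; the
# -a/--affs_to_bd and -f/--additional_formats flags are meant to toggle
# affinity-to-boundary post-processing and extra weight formats:
#
#     args = export_parser_helper().parse_args()
#     export_bioimageio_model(args.checkpoint, args.output, _load_data(args.input, None))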
809
810def get_mws_config(offsets, config=None):
811    mws_config = {"offsets": offsets}
812    if config is None:
813        config = {"mws": mws_config}
814    else:
815        assert isinstance(config, dict)
816        config["mws"] = mws_config
817    return config
818
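# Example: attaching mutex watershed metadata to an exported model via the
# custom config field. The offsets below are a typical (made-up) 2d affinity
# layout; the other export arguments are omitted:
#
#     offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]
#     export_bioimageio_model(..., config=get_mws_config(offsets))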
819
820def get_shallow2deep_config(rf_path, config=None):
821    if os.path.isdir(rf_path):
822        rf_path = glob(os.path.join(rf_path, "*.pkl"))[0]
823    assert os.path.exists(rf_path), rf_path
824    with open(rf_path, "rb") as f:
825        rf = pickle.load(f)
826    shallow2deep_config = {
827        "ndim": rf.feature_ndim,
828        "features": rf.feature_config,
829    }
830    if config is None:
831        config = {"shallow2deep": shallow2deep_config}
832    else:
833        assert isinstance(config, dict)
834        config["shallow2deep"] = shallow2deep_config
835    return config
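# Example: a sketch of deriving the shallow2deep config from a pickled random
# forest and attaching it to an export (the pickle path is a hypothetical
# placeholder; the forest needs the feature_ndim and feature_config attributes):
#
#     config = get_shallow2deep_config("./checkpoints/rf.pkl")
#     export_bioimageio_model(..., config=config)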