torch_em.util.prediction

from concurrent import futures
from copy import deepcopy

import nifty.tools as nt
import numpy as np
import torch
from tqdm import tqdm

from ..transform.raw import standardize


def predict_with_padding(model, input_, min_divisible, device, with_channels=False, prediction_function=None):
    """Run prediction with padding for a model that can only deal with
    inputs divisible by specific factors.

    Arguments:
        model [torch.nn.Module]: the model
        input_ [np.ndarray]: the input data
        min_divisible [tuple]: the divisible shape factors
            (e.g. (16, 16) for a model that needs inputs divisible by at least 16 pixels)
        device [str, torch.device]: the device of the model
        with_channels [bool]: whether the input data contains channels (default: False)
        prediction_function [callable]: a wrapper function for prediction to enable custom prediction procedures
            (default: None)
    Returns:
        np.ndarray: the output of the model
    """
    if with_channels:
        assert len(min_divisible) + 1 == input_.ndim, f"{min_divisible}, {input_.ndim}"
        min_divisible_ = (1,) + min_divisible
    else:
        assert len(min_divisible) == input_.ndim
        min_divisible_ = min_divisible

    if any(sh % md != 0 for sh, md in zip(input_.shape, min_divisible_)):
        pad_width = tuple(
            (0, 0 if sh % md == 0 else md - sh % md)
            for sh, md in zip(input_.shape, min_divisible_)
        )
        crop_padding = tuple(slice(0, sh) for sh in input_.shape)
        input_ = np.pad(input_, pad_width, mode="reflect")
    else:
        crop_padding = None

    ndim = input_.ndim
    ndim_model = 1 + ndim if with_channels else 2 + ndim

    # add batch (and channel) axes to match the dimensionality expected by the model
    expand_dim = (None,) * (ndim_model - ndim)
    with torch.no_grad():
        model_input = torch.from_numpy(input_[expand_dim]).to(device)
        output = model(model_input) if prediction_function is None else prediction_function(model, model_input)
        output = output.cpu().numpy()

    # crop the padding from the prediction
    if crop_padding is not None:
        crop_padding = (slice(None),) * (output.ndim - len(crop_padding)) + crop_padding
        output = output[crop_padding]

    return output


def _load_block(input_, offset, block_shape, halo, padding_mode="reflect", with_channels=False):
    shape = input_.shape
    if with_channels:
        shape = shape[1:]

    starts = [off - ha for off, ha in zip(offset, halo)]
    stops = [off + bs + ha for off, bs, ha in zip(offset, block_shape, halo)]

    # we pad the input volume if necessary
    pad_left = None
    pad_right = None

    # check for padding to the left
    if any(start < 0 for start in starts):
        pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
        starts = [max(0, start) for start in starts]

    # check for padding to the right
    if any(stop > shape[i] for i, stop in enumerate(stops)):
        pad_right = tuple(stop - shape[i] if stop > shape[i] else 0 for i, stop in enumerate(stops))
        stops = [min(shape[i], stop) for i, stop in enumerate(stops)]

    bb = tuple(slice(start, stop) for start, stop in zip(starts, stops))
    if with_channels:
        data = input_[(slice(None),) + bb]
    else:
        data = input_[bb]

    ndim = len(shape)
    # pad if necessary
    if pad_left is not None or pad_right is not None:
        pad_left = (0,) * ndim if pad_left is None else pad_left
        pad_right = (0,) * ndim if pad_right is None else pad_right
        pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
        if with_channels:
            pad_width = ((0, 0),) + pad_width
        data = np.pad(data, pad_width, mode=padding_mode)

        # extend the bounding box for downstream
        bb = tuple(
            slice(b.start - pl, b.stop + pr)
            for b, pl, pr in zip(bb, pad_left, pad_right)
        )

    return data, bb


# TODO half precision prediction
def predict_with_halo(
    input_,
    model,
    gpu_ids,
    block_shape,
    halo,
    output=None,
    preprocess=standardize,
    postprocess=None,
    with_channels=False,
    skip_block=None,
    mask=None,
    disable_tqdm=False,
    tqdm_desc="predict with halo",
    prediction_function=None,
    roi=None,
):
    """Run block-wise network prediction with halo.

    Arguments:
        input_ [arraylike] - the input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
        model [nn.Module] - the network
        gpu_ids [list[int or string]] - list of gpu ids used for prediction
        block_shape [tuple] - shape of the inner block used for prediction
        halo [tuple] - shape of the halo used for prediction
        output [arraylike or list[tuple[arraylike, slice]]] - output data, will be allocated if None is passed.
            Instead of a single output, this can also be a list of outputs and the corresponding channels.
            (default: None)
        preprocess [callable] - function to preprocess input data before passing it to the network.
            (default: standardize)
        postprocess [callable] - function to postprocess the network predictions (default: None)
        with_channels [bool] - whether the input has a channel axis (default: False)
        skip_block [callable] - function to evaluate whether a given input block should be skipped (default: None)
        mask [arraylike] - elements outside the mask will be ignored in the prediction (default: None)
        disable_tqdm [bool] - flag to disable the tqdm output (e.g. if the function is called multiple times)
        tqdm_desc [str] - description shown by the tqdm output
        prediction_function [callable] - a wrapper function for prediction to enable custom prediction procedures.
            (default: None)
        roi [tuple[slice]] - a region of interest for which to run the prediction (default: None)
    """
    # copy the model to each device, unless it is on that device already
    devices = [torch.device(gpu) for gpu in gpu_ids]
    models = [
        (model if next(model.parameters()).device == device else deepcopy(model).to(device), device)
        for device in devices
    ]

    n_workers = len(gpu_ids)
    shape = input_.shape
    if with_channels:
        shape = shape[1:]
    ndim = len(shape)
    assert len(block_shape) == len(halo) == ndim

    if roi is None:
        blocking = nt.blocking([0] * ndim, shape, block_shape)
    else:
        assert len(roi) == ndim
        blocking_start = [0 if ro.start is None else ro.start for ro in roi]
        blocking_stop = [sh if ro.stop is None else ro.stop for ro, sh in zip(roi, shape)]
        blocking = nt.blocking(blocking_start, blocking_stop, block_shape)

    if output is None:
        n_out = models[0][0].out_channels
        output = np.zeros((n_out,) + shape, dtype="float32")

    def predict_block(block_id):
        # distribute the blocks round-robin over the available devices
        worker_id = block_id % n_workers
        net, device = models[worker_id]

        with torch.no_grad():
            block = blocking.getBlock(block_id)
            offset = list(block.begin)
            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))

            if mask is not None:
                mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
                # cast to bool so that ~mask_block inverts the mask correctly
                mask_block = mask_block[inner_bb].astype("bool")
                if mask_block.sum() == 0:
                    return

            inp, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)

            if skip_block is not None and skip_block(inp):
                return

            if preprocess is not None:
                inp = preprocess(inp)

            # add (channel) and batch axis
            expand_dims = np.s_[None] if with_channels else np.s_[None, None]
            inp = torch.from_numpy(inp[expand_dims]).to(device)

            prediction = net(inp) if prediction_function is None else prediction_function(net, inp)

            # allow for list of tensors
            try:
                prediction = prediction.cpu().numpy().squeeze(0)
            except AttributeError:
                prediction = prediction[0]
                prediction = prediction.cpu().numpy().squeeze(0)

            if postprocess is not None:
                prediction = postprocess(prediction)

            # crop the halo from the prediction
            if prediction.ndim == ndim + 1:
                inner_bb = (slice(None),) + inner_bb
            prediction = prediction[inner_bb]

            if mask is not None:
                if prediction.ndim == ndim + 1:
                    mask_block = np.concatenate(prediction.shape[0] * [mask_block[None]], axis=0)
                prediction[~mask_block] = 0

            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
            if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                for out, channel_slice in output:
                    this_bb = bb if out.ndim == ndim else (slice(None),) + bb
                    out[this_bb] = prediction[channel_slice]
            else:  # we only have a single output array
                if output.ndim == ndim + 1:
                    bb = (slice(None),) + bb
                output[bb] = prediction

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_workers) as tp:
        list(tqdm(tp.map(predict_block, range(n_blocks)), total=n_blocks, disable=disable_tqdm, desc=tqdm_desc))

    return output
def predict_with_padding(model, input_, min_divisible, device, with_channels=False, prediction_function=None):

Run prediction with padding for a model that can only deal with inputs divisible by specific factors.

Arguments:
  • model [torch.nn.Module]: the model
  • input_ [np.ndarray]: the input data
  • min_divisible [tuple]: the divisible shape factors (e.g. (16, 16) for a model that needs inputs divisible by at least 16 pixels)
  • device [str, torch.device]: the device of the model
  • with_channels [bool]: whether the input data contains channels (default: False)
  • prediction_function [callable]: a wrapper function for prediction to enable custom prediction procedures (default: None)
Returns:

np.ndarray: the output of the model
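
For illustration, a minimal usage sketch (the single-convolution model below is a placeholder, not part of torch_em; a typical real use case is a UNet whose downsampling requires input shapes divisible by a fixed factor):

import numpy as np
import torch

from torch_em.util.prediction import predict_with_padding

# placeholder model; a real example would be a UNet that needs
# input shapes divisible by 16 due to its downsampling steps
model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)
model.eval()

# (97, 103) is not divisible by 16, so the input is reflection-padded
# to (112, 112) and the prediction is cropped back afterwards
image = np.random.rand(97, 103).astype("float32")
prediction = predict_with_padding(model, image, min_divisible=(16, 16), device="cpu")
print(prediction.shape)  # (1, 1, 97, 103) - batch and channel axes are kept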

def predict_with_halo(input_, model, gpu_ids, block_shape, halo, output=None, preprocess=standardize, postprocess=None, with_channels=False, skip_block=None, mask=None, disable_tqdm=False, tqdm_desc='predict with halo', prediction_function=None, roi=None):

Run block-wise network prediction with halo.

Arguments:
  • input_ [arraylike] - the input data, can be a numpy array, a hdf5/zarr/z5py dataset or similar
  • model [nn.Module] - the network
  • gpu_ids [list[int or string]] - list of gpu ids used for prediction
  • block_shape [tuple] - shape of the inner block used for prediction
  • halo [tuple] - shape of the halo used for prediction
  • output [arraylike or list[tuple[arraylike, slice]]] - output data, will be allocated if None is passed. Instead of a single output, this can also be a list of outputs and the corresponding channels. (default: None)
  • preprocess [callable] - function to preprocess input data before passing it to the network. (default: standardize)
  • postprocess [callable] - function to postprocess the network predictions (default: None)
  • with_channels [bool] - whether the input has a channel axis (default: False)
  • skip_block [callable] - function to evaluate whether a given input block should be skipped (default: None)
  • mask [arraylike] - elements outside the mask will be ignored in the prediction (default: None)
  • disable_tqdm [bool] - flag to disable the tqdm output (e.g. if the function is called multiple times)
  • tqdm_desc [str] - description shown by the tqdm output
  • prediction_function [callable] - a wrapper function for prediction to enable custom prediction procedures (default: None)
  • roi [tuple[slice]] - a region of interest for which to run the prediction (default: None)
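
For illustration, a minimal usage sketch (the single-convolution model and the random volume are placeholders, not part of torch_em; any nn.Module with an out_channels attribute works the same way):

import numpy as np
import torch

from torch_em.util.prediction import predict_with_halo

# stand-in network: a single 3d convolution keeps the spatial shape and
# exposes out_channels, which predict_with_halo uses to allocate the output
model = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)
model.eval()

volume = np.random.rand(64, 128, 128).astype("float32")

# predict in blocks of (32, 64, 64) with a halo of (8, 16, 16) on each side;
# "cpu" is used so the sketch runs anywhere, a multi-GPU setup would pass
# e.g. gpu_ids=[0, 1] to distribute the blocks over two devices
prediction = predict_with_halo(
    volume, model, gpu_ids=["cpu"], block_shape=(32, 64, 64), halo=(8, 16, 16)
)
print(prediction.shape)  # (1, 64, 128, 128)

To write different channels to separate arrays, output can instead be a list of (array, channel_slice) pairs, e.g. [(foreground, np.s_[0]), (boundaries, np.s_[1])] for a two-channel network.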