torch_em.model.resnet3d

  1# This file implements 3D ResNets, based on the implementations from torchvision:
  2# https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
  3
  4from typing import Any, Callable, List, Optional, Type, Union
  5
  6import torch
  7import torch.nn as nn
  8from torch import Tensor
  9
 10# from torchvision.models._api import WeightsEnum
 11from torchvision.models._utils import _ovewrite_named_param  # needed by the resnext/wide factory functions below
 12# from torchvision.utils import _log_api_usage_once
 13
 14
 15__all__ = [
 16    "ResNet3d",
 17    "resnet3d_18",
 18    "resnet3d_34",
 19    "resnet3d_50",
 20    "resnet3d_101",
 21    "resnet3d_152",
 22    "resnext3d_50_32x4d",
 23    "resnext3d_101_32x8d",
 24    "resnext3d_101_64x4d",
 25    "wide_resnet3d_50_2",
 26    "wide_resnet3d_101_2",
 27]
 28
 29
 30def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv3d:
 31    """3x3 convolution with padding"""
 32    return nn.Conv3d(
 33        in_planes,
 34        out_planes,
 35        kernel_size=3,
 36        stride=stride,
 37        padding=dilation,
 38        groups=groups,
 39        bias=False,
 40        dilation=dilation,
 41    )
 42
 43
 44def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
 45    """1x1 convolution"""
 46    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 47
 48
 49class BasicBlock(nn.Module):
 50    expansion: int = 1
 51
 52    def __init__(
 53        self,
 54        inplanes: int,
 55        planes: int,
 56        stride: int = 1,
 57        downsample: Optional[nn.Module] = None,
 58        groups: int = 1,
 59        base_width: int = 64,
 60        dilation: int = 1,
 61        norm_layer: Optional[Callable[..., nn.Module]] = None,
 62    ) -> None:
 63        super().__init__()
 64        if norm_layer is None:
 65            norm_layer = nn.BatchNorm3d
 66        if groups != 1 or base_width != 64:
 67            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
 68        if dilation > 1:
 69            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
 70        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
 71        self.conv1 = conv3x3(inplanes, planes, stride)
 72        self.bn1 = norm_layer(planes)
 73        self.relu = nn.ReLU(inplace=True)
 74        self.conv2 = conv3x3(planes, planes)
 75        self.bn2 = norm_layer(planes)
 76        self.downsample = downsample
 77        self.stride = stride
 78
 79    def forward(self, x: Tensor) -> Tensor:
 80        identity = x
 81
 82        out = self.conv1(x)
 83        out = self.bn1(out)
 84        out = self.relu(out)
 85
 86        out = self.conv2(out)
 87        out = self.bn2(out)
 88
 89        if self.downsample is not None:
 90            identity = self.downsample(x)
 91
 92        out += identity
 93        out = self.relu(out)
 94
 95        return out
 96
 97
 98class Bottleneck(nn.Module):
 99    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
100    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
101    # according to "Deep residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
102    # This variant is also known as ResNet V1.5 and improves accuracy according to
103    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
104
105    expansion: int = 4
106
107    def __init__(
108        self,
109        inplanes: int,
110        planes: int,
111        stride: int = 1,
112        downsample: Optional[nn.Module] = None,
113        groups: int = 1,
114        base_width: int = 64,
115        dilation: int = 1,
116        norm_layer: Optional[Callable[..., nn.Module]] = None,
117    ) -> None:
118        super().__init__()
119        if norm_layer is None:
120            norm_layer = nn.BatchNorm3d
121        width = int(planes * (base_width / 64.0)) * groups
122        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
123        self.conv1 = conv1x1(inplanes, width)
124        self.bn1 = norm_layer(width)
125        self.conv2 = conv3x3(width, width, stride, groups, dilation)
126        self.bn2 = norm_layer(width)
127        self.conv3 = conv1x1(width, planes * self.expansion)
128        self.bn3 = norm_layer(planes * self.expansion)
129        self.relu = nn.ReLU(inplace=True)
130        self.downsample = downsample
131        self.stride = stride
132
133    def forward(self, x: Tensor) -> Tensor:
134        identity = x
135
136        out = self.conv1(x)
137        out = self.bn1(out)
138        out = self.relu(out)
139
140        out = self.conv2(out)
141        out = self.bn2(out)
142        out = self.relu(out)
143
144        out = self.conv3(out)
145        out = self.bn3(out)
146
147        if self.downsample is not None:
148            identity = self.downsample(x)
149
150        out += identity
151        out = self.relu(out)
152
153        return out
154
155
156class ResNet3d(nn.Module):
157    def __init__(
158        self,
159        block: Type[Union[BasicBlock, Bottleneck]],
160        layers: List[int],
161        in_channels: int,
162        out_channels: int,
163        zero_init_residual: bool = False,
164        groups: int = 1,
165        width_per_group: int = 64,
166        replace_stride_with_dilation: Optional[List[bool]] = None,
167        norm_layer: Optional[Callable[..., nn.Module]] = None,
168        stride_conv1: bool = True,
169    ) -> None:
170        super().__init__()
171        # _log_api_usage_once(self)
172        if norm_layer is None:
173            norm_layer = nn.BatchNorm3d
174        self._norm_layer = norm_layer
175
176        self.in_channels = in_channels
177        self.out_channels = out_channels
178
179        self.inplanes = 64
180        self.dilation = 1
181        if replace_stride_with_dilation is None:
182            # each element in the tuple indicates if we should replace
183            # the 2x2 stride with a dilated convolution instead
184            replace_stride_with_dilation = [False, False, False]
185        if len(replace_stride_with_dilation) != 3:
186            raise ValueError(
187                "replace_stride_with_dilation should be None "
188                f"or a 3-element tuple, got {replace_stride_with_dilation}"
189            )
190        self.groups = groups
191        self.base_width = width_per_group
192        self.conv1 = nn.Conv3d(
193            in_channels, self.inplanes, kernel_size=7, stride=2 if stride_conv1 else 1, padding=3, bias=False
194        )
195        self.bn1 = norm_layer(self.inplanes)
196        self.relu = nn.ReLU(inplace=True)
197        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
198        self.layer1 = self._make_layer(block, 64, layers[0])
199        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
200        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
201        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
202        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
203        self.fc = nn.Linear(512 * block.expansion, out_channels)
204
205        for m in self.modules():
206            if isinstance(m, nn.Conv3d):
207                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
208            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
209                nn.init.constant_(m.weight, 1)
210                nn.init.constant_(m.bias, 0)
211
212        # Zero-initialize the last BN in each residual branch,
213        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
214        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
215        if zero_init_residual:
216            for m in self.modules():
217                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
218                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
219                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
220                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
221
222    def _make_layer(
223        self,
224        block: Type[Union[BasicBlock, Bottleneck]],
225        planes: int,
226        blocks: int,
227        stride: int = 1,
228        dilate: bool = False,
229    ) -> nn.Sequential:
230        norm_layer = self._norm_layer
231        downsample = None
232        previous_dilation = self.dilation
233        if dilate:
234            self.dilation *= stride
235            stride = 1
236        if stride != 1 or self.inplanes != planes * block.expansion:
237            downsample = nn.Sequential(
238                conv1x1(self.inplanes, planes * block.expansion, stride),
239                norm_layer(planes * block.expansion),
240            )
241
242        layers = []
243        layers.append(
244            block(
245                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
246            )
247        )
248        self.inplanes = planes * block.expansion
249        for _ in range(1, blocks):
250            layers.append(
251                block(
252                    self.inplanes,
253                    planes,
254                    groups=self.groups,
255                    base_width=self.base_width,
256                    dilation=self.dilation,
257                    norm_layer=norm_layer,
258                )
259            )
260
261        return nn.Sequential(*layers)
262
263    def _forward_impl(self, x: Tensor) -> Tensor:
264        # See note [TorchScript super()]
265        x = self.conv1(x)
266        x = self.bn1(x)
267        x = self.relu(x)
268        x = self.maxpool(x)
269
270        x = self.layer1(x)
271        x = self.layer2(x)
272        x = self.layer3(x)
273        x = self.layer4(x)
274
275        x = self.avgpool(x)
276        x = torch.flatten(x, 1)
277        x = self.fc(x)
278
279        return x
280
281    def forward(self, x: Tensor) -> Tensor:
282        return self._forward_impl(x)
283
284
285def _resnet(
286    block: Type[Union[BasicBlock, Bottleneck]],
287    layers: List[int],
288    weights: Any,
289    progress: bool,
290    **kwargs: Any,
291) -> ResNet3d:
292    if weights is not None:
293        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
294
295    model = ResNet3d(block, layers, **kwargs)
296
297    if weights is not None:
298        model.load_state_dict(weights.get_state_dict(progress=progress))
299
300    return model
301
302
303def resnet3d_18(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
304    return _resnet(
305        BasicBlock, [2, 2, 2, 2], weights=None, progress=False,
306        in_channels=in_channels, out_channels=out_channels, **kwargs
307    )
308
309
310def resnet3d_34(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
311    return _resnet(
312        BasicBlock, [3, 4, 6, 3], weights=None, progress=False,
313        in_channels=in_channels, out_channels=out_channels, **kwargs
314    )
315
316
317def resnet3d_50(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
318    return _resnet(
319        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
320        in_channels=in_channels, out_channels=out_channels, **kwargs
321    )
322
323
324def resnet3d_101(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
325    return _resnet(
326        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
327        in_channels=in_channels, out_channels=out_channels, **kwargs
328    )
329
330
331def resnet3d_152(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
332    return _resnet(
333        Bottleneck, [3, 8, 36, 3], weights=None, progress=False,
334        in_channels=in_channels, out_channels=out_channels, **kwargs
335    )
336
337
338def resnext3d_50_32x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
339    _ovewrite_named_param(kwargs, "groups", 32)
340    _ovewrite_named_param(kwargs, "width_per_group", 4)
341    return _resnet(
342        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
343        in_channels=in_channels, out_channels=out_channels, **kwargs
344    )
345
346
347def resnext3d_101_32x8d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
348    _ovewrite_named_param(kwargs, "groups", 32)
349    _ovewrite_named_param(kwargs, "width_per_group", 8)
350    return _resnet(
351        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
352        in_channels=in_channels, out_channels=out_channels, **kwargs
353    )
354
355
356def resnext3d_101_64x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
357    _ovewrite_named_param(kwargs, "groups", 64)
358    _ovewrite_named_param(kwargs, "width_per_group", 4)
359    return _resnet(
360        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
361        in_channels=in_channels, out_channels=out_channels, **kwargs
362    )
363
364
365def wide_resnet3d_50_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
366    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
367    return _resnet(
368        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
369        in_channels=in_channels, out_channels=out_channels, **kwargs
370    )
371
372
373def wide_resnet3d_101_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
374    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
375    return _resnet(
376        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
377        in_channels=in_channels, out_channels=out_channels, **kwargs
378    )
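
A minimal usage sketch of the factory functions defined above (the input size, channel count, and class count are illustrative assumptions, not values from this module):

import torch

from torch_em.model.resnet3d import resnet3d_18

# Illustrative setup: single-channel volumes, two output classes.
model = resnet3d_18(in_channels=1, out_channels=2)
model.eval()

# A batch of two 64^3 volumes, laid out as (batch, channels, depth, height, width).
x = torch.randn(2, 1, 64, 64, 64)
with torch.no_grad():
    logits = model(x)

print(logits.shape)  # torch.Size([2, 2])

The output always has shape (batch, out_channels): the adaptive average pooling reduces the spatial dimensions to 1x1x1 before the fully connected layer.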
class ResNet3d(torch.nn.modules.module.Module):

A 3D ResNet classification network, adapted from torchvision's 2D ResNet implementation.

The architecture follows the standard ResNet layout: an initial 7x7x7 convolution (stride 2 by default, stride 1 if stride_conv1=False), normalization and ReLU, a 3x3x3 max pooling, four residual stages built from BasicBlock or Bottleneck blocks according to the layers configuration, adaptive average pooling down to 1x1x1, and a fully connected layer mapping to out_channels.

ResNet3d( block: Type[Union[torch_em.model.resnet3d.BasicBlock, torch_em.model.resnet3d.Bottleneck]], layers: List[int], in_channels: int, out_channels: int, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., torch.nn.modules.module.Module]] = None, stride_conv1: bool = True)
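
A minimal instantiation sketch using the constructor directly (the block choice, layer counts, channel numbers, and norm layer here are illustrative assumptions):

import torch
import torch.nn as nn

from torch_em.model.resnet3d import ResNet3d, Bottleneck

# ResNet-50-style stage configuration with instance norm instead of batch norm;
# stride_conv1=False keeps the initial 7x7x7 convolution at stride 1.
model = ResNet3d(
    block=Bottleneck,
    layers=[3, 4, 6, 3],
    in_channels=1,
    out_channels=10,
    norm_layer=nn.InstanceNorm3d,
    stride_conv1=False,
)

x = torch.randn(1, 1, 64, 64, 64)
print(model(x).shape)  # torch.Size([1, 10])

Note that norm_layer is always called as norm_layer(num_features), so any module with that constructor signature (e.g. nn.InstanceNorm3d) can be substituted for the default nn.BatchNorm3d.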

Instance attributes: in_channels, out_channels, inplanes, dilation, groups, base_width, conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc.
def forward(self, x: torch.Tensor) -> torch.Tensor:

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the forward pass must be defined within this function, call the Module instance itself rather than this method directly: the instance call runs any registered hooks, while a direct call to forward() silently skips them.
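
A small sketch of that difference (the model and hook below are illustrative): a forward hook registered on the module fires when the instance is called, but not when forward() is invoked directly.

import torch

from torch_em.model.resnet3d import resnet3d_18

model = resnet3d_18(in_channels=1, out_channels=2)

calls = []
handle = model.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.randn(1, 1, 32, 32, 32)
model(x)          # instance call: the hook fires, calls == [torch.Size([1, 2])]
model.forward(x)  # direct call: the hook machinery is bypassed, calls is unchanged
handle.remove()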

def resnet3d_18(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D ResNet-18: BasicBlock with layer configuration [2, 2, 2, 2].

def resnet3d_34(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D ResNet-34: BasicBlock with layer configuration [3, 4, 6, 3].

def resnet3d_50(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D ResNet-50: Bottleneck with layer configuration [3, 4, 6, 3].

def resnet3d_101(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D ResNet-101: Bottleneck with layer configuration [3, 4, 23, 3].

def resnet3d_152(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D ResNet-152: Bottleneck with layer configuration [3, 8, 36, 3].

def resnext3d_50_32x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D ResNeXt-50 (32x4d): Bottleneck with layers [3, 4, 6, 3], groups=32, width_per_group=4.

def resnext3d_101_32x8d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D ResNeXt-101 (32x8d): Bottleneck with layers [3, 4, 23, 3], groups=32, width_per_group=8.

def resnext3d_101_64x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D ResNeXt-101 (64x4d): Bottleneck with layers [3, 4, 23, 3], groups=64, width_per_group=4.

def wide_resnet3d_50_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D Wide ResNet-50-2: Bottleneck with layers [3, 4, 6, 3], width_per_group=128.

def wide_resnet3d_101_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    3D Wide ResNet-101-2: Bottleneck with layers [3, 4, 23, 3], width_per_group=128.
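
A short sketch of the variant factories (channel and class counts are illustrative); because **kwargs is forwarded to the ResNet3d constructor, options such as zero_init_residual or stride_conv1 can be passed through:

import torch

from torch_em.model.resnet3d import resnext3d_50_32x4d, wide_resnet3d_50_2

# ResNeXt-50 (32x4d): groups=32 and width_per_group=4 are set by the factory.
resnext = resnext3d_50_32x4d(in_channels=1, out_channels=5, zero_init_residual=True)

# Wide ResNet-50-2 with the initial convolution kept at stride 1.
wide = wide_resnet3d_50_2(in_channels=1, out_channels=5, stride_conv1=False)

x = torch.randn(1, 1, 48, 48, 48)
print(resnext(x).shape)  # torch.Size([1, 5])
print(wide(x).shape)     # torch.Size([1, 5])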