torch_em.model.resnet3d
# this file implements 3d resnets, based on the implementations from torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor

# from torchvision.models._api import WeightsEnum
# from torchvision.utils import _log_api_usage_once
# NOTE: `_ovewrite_named_param` was originally imported from `torchvision.models._utils`. That import is disabled,
# which left the resnext / wide_resnet factories raising a NameError; a local implementation is defined below.


__all__ = [
    "ResNet3d",
    "resnet3d_18",
    "resnet3d_34",
    "resnet3d_50",
    "resnet3d_101",
    "resnet3d_152",
    "resnext3d_50_32x4d",
    "resnext3d_101_32x8d",
    "resnext3d_101_64x4d",
    "wide_resnet3d_50_2",
    "wide_resnet3d_101_2",
]


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: Any) -> None:
    """Fix ``kwargs[param]`` to ``new_value`` without silently overriding a conflicting caller value.

    Local replacement for the helper of the same name from ``torchvision.models._utils`` (the original,
    misspelled name is kept so call sites stay unchanged).

    Args:
        kwargs: Keyword arguments that will be forwarded to the model constructor; modified in place.
        param: Name of the parameter to fix.
        new_value: Value the parameter must have.

    Raises:
        ValueError: If ``kwargs`` already contains ``param`` with a different value.
    """
    if param in kwargs:
        if kwargs[param] != new_value:
            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
    else:
        kwargs[param] = new_value


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv3d:
    """3x3x3 convolution without bias; padding equals dilation, so stride 1 preserves the spatial size."""
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
    """1x1x1 convolution without bias."""
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Residual block with two 3x3x3 convolutions (used by ResNet3d-18 and -34).

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels (``expansion`` is 1 for this block).
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input so the shortcut matches the residual branch's output.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must be 1.
        norm_layer: Factory for the normalization layers; defaults to ``nn.BatchNorm3d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv-norm-relu, conv-norm, add the (optionally downsampled) input, then relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1x1 reduce, 3x3x3 (strided / grouped / dilated), 1x1x1 expand by 4.

    Args:
        inplanes: Number of input channels.
        planes: Bottleneck channels; the block outputs ``planes * expansion`` channels.
        stride: Stride of the 3x3x3 convolution.
        downsample: Optional module applied to the input so the shortcut matches the residual branch's output.
        groups: Number of groups of the 3x3x3 convolution.
        base_width: Width per group; the inner width is ``int(planes * base_width / 64) * groups``.
        dilation: Dilation of the 3x3x3 convolution.
        norm_layer: Factory for the normalization layers; defaults to ``nn.BatchNorm3d``.
    """

    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the three conv-norm stages, add the (optionally downsampled) input, then relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet3d(nn.Module):
    """3D ResNet: a 7x7x7 stem with max-pooling, four residual stages, global average pooling and a linear head.

    Args:
        block: Residual block class, ``BasicBlock`` or ``Bottleneck``.
        layers: Number of blocks in each of the four stages.
        in_channels: Number of channels of the input volume.
        out_channels: Number of outputs of the final linear layer.
        zero_init_residual: Zero-initialize the weight of the last norm layer of every residual branch.
        groups: Number of convolution groups used by every block.
        width_per_group: Width per group used by every block.
        replace_stride_with_dilation: Three flags for stages 2-4; a True entry replaces that stage's
            stride-2 downsampling by dilation.
        norm_layer: Factory for the normalization layers; defaults to ``nn.BatchNorm3d``.
        stride_conv1: If True the stem convolution uses stride 2, otherwise stride 1.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        in_channels: int,
        out_channels: int,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        stride_conv1: bool = True,
    ) -> None:
        super().__init__()
        # _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        self._norm_layer = norm_layer

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Running channel count; `_make_layer` updates it after building each stage.
        self.inplanes = 64
        # Running dilation; grows in stages where striding is replaced by dilation.
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the stride-2 downsampling of stages 2-4 with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv3d(
            in_channels, self.inplanes, kernel_size=7, stride=2 if stride_conv1 else 1, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, out_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` blocks; updates ``self.inplanes`` and, if ``dilate``, ``self.dilation``."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # The shortcut needs a projection whenever the first block changes resolution or channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Only the first block strides / projects; it keeps the dilation from before this stage.
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything after the batch dimension, giving (batch, 512 * expansion) for the linear head.
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Run the network; returns a tensor of shape (batch, out_channels)."""
        return self._forward_impl(x)


def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Any,
    progress: bool,
    **kwargs: Any,
) -> ResNet3d:
    """Build a ``ResNet3d`` and optionally load pretrained weights into it.

    Args:
        block: Residual block class.
        layers: Number of blocks per stage.
        weights: None, or an object providing ``meta["categories"]`` and ``get_state_dict(progress=...)``.
            Every factory in this module passes None.
        progress: Forwarded to ``weights.get_state_dict``.
        kwargs: Keyword arguments for ``ResNet3d``.
    """
    if weights is not None:
        # ResNet3d names its head size `out_channels` (torchvision's `num_classes` would be rejected by the
        # constructor), and it has to match the number of categories of the pretrained weights.
        _ovewrite_named_param(kwargs, "out_channels", len(weights.meta["categories"]))

    model = ResNet3d(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


def resnet3d_18(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D ResNet-18 (BasicBlock, [2, 2, 2, 2]); ``kwargs`` are forwarded to ``ResNet3d``."""
    return _resnet(
        BasicBlock, [2, 2, 2, 2], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnet3d_34(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D ResNet-34 (BasicBlock, [3, 4, 6, 3]); ``kwargs`` are forwarded to ``ResNet3d``."""
    return _resnet(
        BasicBlock, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnet3d_50(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D ResNet-50 (Bottleneck, [3, 4, 6, 3]); ``kwargs`` are forwarded to ``ResNet3d``."""
    return _resnet(
        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnet3d_101(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D ResNet-101 (Bottleneck, [3, 4, 23, 3]); ``kwargs`` are forwarded to ``ResNet3d``."""
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnet3d_152(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D ResNet-152 (Bottleneck, [3, 8, 36, 3]); ``kwargs`` are forwarded to ``ResNet3d``."""
    return _resnet(
        Bottleneck, [3, 8, 36, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnext3d_50_32x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D ResNeXt-50 32x4d (Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4).

    Raises:
        ValueError: If ``kwargs`` passes a conflicting ``groups`` or ``width_per_group``.
    """
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(
        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnext3d_101_32x8d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D ResNeXt-101 32x8d (Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8).

    Raises:
        ValueError: If ``kwargs`` passes a conflicting ``groups`` or ``width_per_group``.
    """
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnext3d_101_64x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D ResNeXt-101 64x4d (Bottleneck, [3, 4, 23, 3], groups=64, width_per_group=4).

    Raises:
        ValueError: If ``kwargs`` passes a conflicting ``groups`` or ``width_per_group``.
    """
    _ovewrite_named_param(kwargs, "groups", 64)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def wide_resnet3d_50_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D Wide ResNet-50-2 (Bottleneck, [3, 4, 6, 3], width_per_group=128).

    Raises:
        ValueError: If ``kwargs`` passes a conflicting ``width_per_group``.
    """
    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(
        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def wide_resnet3d_101_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """3D Wide ResNet-101-2 (Bottleneck, [3, 4, 23, 3], width_per_group=128).

    Raises:
        ValueError: If ``kwargs`` passes a conflicting ``width_per_group``.
    """
    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )
class ResNet3d(nn.Module):
    """3D ResNet: a 7x7x7 stem with max-pooling, four residual stages, global average pooling and a linear head.

    Args:
        block: Residual block class, ``BasicBlock`` or ``Bottleneck``.
        layers: Number of blocks in each of the four stages (indices 0-3 are read).
        in_channels: Number of channels of the input volume.
        out_channels: Number of outputs of the final linear layer.
        zero_init_residual: Zero-initialize the weight of the last norm layer of every residual branch.
        groups: Number of convolution groups used by every block.
        width_per_group: Width per group used by every block (passed as ``base_width``).
        replace_stride_with_dilation: Three flags for stages 2-4; a True entry replaces that stage's
            stride-2 downsampling by dilation.
        norm_layer: Factory for the normalization layers; defaults to ``nn.BatchNorm3d``.
        stride_conv1: If True the stem convolution uses stride 2, otherwise stride 1.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have exactly three entries.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        in_channels: int,
        out_channels: int,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        stride_conv1: bool = True,
    ) -> None:
        super().__init__()
        # _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        self._norm_layer = norm_layer

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Running channel count; `_make_layer` updates it after building each stage.
        self.inplanes = 64
        # Running dilation; grows in stages where striding is replaced by dilation.
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the stride-2 downsampling of stages 2-4 with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv3d(
            in_channels, self.inplanes, kernel_size=7, stride=2 if stride_conv1 else 1, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        # The stages are built in order because `_make_layer` mutates self.inplanes / self.dilation.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, out_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` blocks; updates ``self.inplanes`` and, if ``dilate``, ``self.dilation``."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade this stage's stride for dilation so the spatial resolution is kept.
            self.dilation *= stride
            stride = 1
        # The shortcut needs a projection whenever the first block changes resolution or channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Only the first block strides / projects; it keeps the dilation from before this stage.
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything after the batch dimension before the linear head.
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Run the network; the result has one row per batch entry and ``out_channels`` columns."""
        return self._forward_impl(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Example network: two stacked 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        # The parent constructor must run before any submodule is assigned.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 + ReLU, then conv2 + ReLU."""
        features = F.relu(self.conv1(x))
        activated = F.relu(self.conv2(features))
        return activated
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    in_channels: int,
    out_channels: int,
    zero_init_residual: bool = False,
    groups: int = 1,
    width_per_group: int = 64,
    replace_stride_with_dilation: Optional[List[bool]] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    stride_conv1: bool = True,
) -> None:
    """Build the stem, the four residual stages and the linear head, then initialize the weights.

    Args:
        block: Residual block class, ``BasicBlock`` or ``Bottleneck``.
        layers: Number of blocks in each of the four stages (indices 0-3 are read).
        in_channels: Number of channels of the input volume.
        out_channels: Number of outputs of the final linear layer.
        zero_init_residual: Zero-initialize the weight of the last norm layer of every residual branch.
        groups: Number of convolution groups used by every block.
        width_per_group: Width per group used by every block (passed as ``base_width``).
        replace_stride_with_dilation: Three flags for stages 2-4; a True entry replaces that stage's
            stride-2 downsampling by dilation.
        norm_layer: Factory for the normalization layers; defaults to ``nn.BatchNorm3d``.
        stride_conv1: If True the stem convolution uses stride 2, otherwise stride 1.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have exactly three entries.
    """
    super().__init__()
    # _log_api_usage_once(self)
    if norm_layer is None:
        norm_layer = nn.BatchNorm3d
    self._norm_layer = norm_layer

    self.in_channels = in_channels
    self.out_channels = out_channels

    # Running channel count; `_make_layer` updates it after building each stage.
    self.inplanes = 64
    # Running dilation; grows in stages where striding is replaced by dilation.
    self.dilation = 1
    if replace_stride_with_dilation is None:
        # each element in the tuple indicates if we should replace
        # the stride-2 downsampling of stages 2-4 with a dilated convolution instead
        replace_stride_with_dilation = [False, False, False]
    if len(replace_stride_with_dilation) != 3:
        raise ValueError(
            "replace_stride_with_dilation should be None "
            f"or a 3-element tuple, got {replace_stride_with_dilation}"
        )
    self.groups = groups
    self.base_width = width_per_group
    self.conv1 = nn.Conv3d(
        in_channels, self.inplanes, kernel_size=7, stride=2 if stride_conv1 else 1, padding=3, bias=False
    )
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
    # The stages are built in order because `_make_layer` mutates self.inplanes / self.dilation.
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
    self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
    self.fc = nn.Linear(512 * block.expansion, out_channels)

    for m in self.modules():
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    # Zero-initialize the last BN in each residual branch,
    # so that the residual branch starts with zeros, and each residual block behaves like an identity.
    # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    if zero_init_residual:
        for m in self.modules():
            if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
            elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
def resnext3d_50_32x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Construct a 3D ResNeXt-50 32x4d (Bottleneck blocks, layers [3, 4, 6, 3], groups=32, width_per_group=4).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of outputs of the final linear layer.
        kwargs: Further keyword arguments for ``ResNet3d``; ``groups`` / ``width_per_group`` are only
            accepted if they equal 32 / 4.

    Raises:
        ValueError: If ``groups`` or ``width_per_group`` conflict with this architecture.
    """
    # The original called `_ovewrite_named_param`, whose torchvision import is commented out in this module
    # (NameError); pin the architecture-defining parameters inline instead.
    for param, value in (("groups", 32), ("width_per_group", 4)):
        given = kwargs.setdefault(param, value)
        if given != value:
            raise ValueError(f"The parameter '{param}' expected value {value} but got {given} instead.")
    return _resnet(
        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )
def resnext3d_101_32x8d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Construct a 3D ResNeXt-101 32x8d (Bottleneck blocks, layers [3, 4, 23, 3], groups=32, width_per_group=8).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of outputs of the final linear layer.
        kwargs: Further keyword arguments for ``ResNet3d``; ``groups`` / ``width_per_group`` are only
            accepted if they equal 32 / 8.

    Raises:
        ValueError: If ``groups`` or ``width_per_group`` conflict with this architecture.
    """
    # The original called `_ovewrite_named_param`, whose torchvision import is commented out in this module
    # (NameError); pin the architecture-defining parameters inline instead.
    for param, value in (("groups", 32), ("width_per_group", 8)):
        given = kwargs.setdefault(param, value)
        if given != value:
            raise ValueError(f"The parameter '{param}' expected value {value} but got {given} instead.")
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )
def resnext3d_101_64x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Construct a 3D ResNeXt-101 64x4d (Bottleneck blocks, layers [3, 4, 23, 3], groups=64, width_per_group=4).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of outputs of the final linear layer.
        kwargs: Further keyword arguments for ``ResNet3d``; ``groups`` / ``width_per_group`` are only
            accepted if they equal 64 / 4.

    Raises:
        ValueError: If ``groups`` or ``width_per_group`` conflict with this architecture.
    """
    # The original called `_ovewrite_named_param`, whose torchvision import is commented out in this module
    # (NameError); pin the architecture-defining parameters inline instead.
    for param, value in (("groups", 64), ("width_per_group", 4)):
        given = kwargs.setdefault(param, value)
        if given != value:
            raise ValueError(f"The parameter '{param}' expected value {value} but got {given} instead.")
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )
def wide_resnet3d_50_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Construct a 3D Wide ResNet-50-2 (Bottleneck blocks, layers [3, 4, 6, 3], width_per_group=128).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of outputs of the final linear layer.
        kwargs: Further keyword arguments for ``ResNet3d``; ``width_per_group`` is only accepted if it equals 128.

    Raises:
        ValueError: If ``width_per_group`` conflicts with this architecture.
    """
    # The original called `_ovewrite_named_param`, whose torchvision import is commented out in this module
    # (NameError); pin the architecture-defining parameter inline instead.
    width = 64 * 2
    given = kwargs.setdefault("width_per_group", width)
    if given != width:
        raise ValueError(f"The parameter 'width_per_group' expected value {width} but got {given} instead.")
    return _resnet(
        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )
def wide_resnet3d_101_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Construct a 3D Wide ResNet-101-2 (Bottleneck blocks, layers [3, 4, 23, 3], width_per_group=128).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of outputs of the final linear layer.
        kwargs: Further keyword arguments for ``ResNet3d``; ``width_per_group`` is only accepted if it equals 128.

    Raises:
        ValueError: If ``width_per_group`` conflicts with this architecture.
    """
    # The original called `_ovewrite_named_param`, whose torchvision import is commented out in this module
    # (NameError); pin the architecture-defining parameter inline instead.
    width = 64 * 2
    given = kwargs.setdefault("width_per_group", width)
    if given != width:
        raise ValueError(f"The parameter 'width_per_group' expected value {width} but got {given} instead.")
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )