torch_em.model.resnet3d
# This file implements 3d resnets, based on the implementations from torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor


__all__ = [
    "ResNet3d",
    "resnet3d_18",
    "resnet3d_34",
    "resnet3d_50",
    "resnet3d_101",
    "resnet3d_152",
    "resnext3d_50_32x4d",
    "resnext3d_101_32x8d",
    "resnext3d_101_64x4d",
    "wide_resnet3d_50_2",
    "wide_resnet3d_101_2",
]


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: Any) -> None:
    """Set `kwargs[param]` to `new_value` without silently overriding a conflicting user value.

    Local equivalent of the private torchvision helper of the same (misspelled) name in
    `torchvision.models._utils`, so that this module does not depend on a private torchvision API.

    Args:
        kwargs: The keyword arguments, updated in place.
        param: The name of the parameter to set.
        new_value: The value the parameter must have.

    Raises:
        ValueError: If `param` is already in `kwargs` with a value different from `new_value`.
    """
    if param in kwargs:
        if kwargs[param] != new_value:
            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
    else:
        kwargs[param] = new_value


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv3d:
    """@private
    """
    # 3x3x3 convolution without bias (a norm layer follows it).
    # padding == dilation keeps the spatial size unchanged when stride == 1.
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
    """@private
    """
    # 1x1x1 convolution without bias: changes the number of channels and, with stride > 1, downsamples.
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """@private
    """
    # Two 3x3x3 convolutions with a residual connection; the output has `planes * expansion` channels.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the main branch changed resolution or channel count.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """@private
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        # Inner width of the block; grows with base_width (wide ResNets) and groups (ResNeXt).
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when the main branch changed resolution or channel count.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet3d(nn.Module):
    """@private
    """
    # ResNet for 3d inputs: stem (conv + norm + relu + maxpool), four residual stages,
    # global average pooling and a linear layer mapping to `out_channels` outputs.
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        in_channels: int,
        out_channels: int,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        stride_conv1: bool = True,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        self._norm_layer = norm_layer

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Running channel count; updated by _make_layer as the stages are built.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        # stride_conv1=False keeps the full input resolution in the first convolution.
        self.conv1 = nn.Conv3d(
            in_channels, self.inplanes, kernel_size=7, stride=2 if stride_conv1 else 1, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, out_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        # Build one residual stage of `blocks` blocks; only the first block may change the
        # resolution (stride) or the channel count, so only it gets a downsample shortcut.
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation: the resolution is kept, the receptive field still grows.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Any,
    progress: bool,
    **kwargs: Any,
) -> ResNet3d:
    """Build a `ResNet3d` and optionally load pretrained weights into it.

    Args:
        block: The residual block type.
        layers: The number of blocks in each of the four stages.
        weights: Optional weights object exposing `meta["categories"]` and `get_state_dict(progress=...)`.
        progress: Forwarded to `weights.get_state_dict`.
        kwargs: Keyword arguments for `ResNet3d`.

    Returns:
        The 3D ResNet.

    Raises:
        ValueError: If `out_channels` conflicts with the number of categories of `weights`.
    """
    if weights is not None:
        # The output size has to match the pretrained head. ResNet3d calls this parameter
        # `out_channels` (torchvision's `num_classes` is not accepted by ResNet3d).
        _ovewrite_named_param(kwargs, "out_channels", len(weights.meta["categories"]))

    model = ResNet3d(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


def resnet3d_18(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a residual network for 3d data with 18 layers.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.

    Returns:
        The 3D ResNet.
    """
    return _resnet(
        BasicBlock, [2, 2, 2, 2], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnet3d_34(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a residual network for 3d data with 34 layers.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.

    Returns:
        The 3D ResNet.
    """
    return _resnet(
        BasicBlock, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnet3d_50(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a residual network for 3d data with 50 layers.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.

    Returns:
        The 3D ResNet.
    """
    return _resnet(
        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnet3d_101(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a residual network for 3d data with 101 layers.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.

    Returns:
        The 3D ResNet.
    """
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnet3d_152(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a residual network for 3d data with 152 layers.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.

    Returns:
        The 3D ResNet.
    """
    return _resnet(
        Bottleneck, [3, 8, 36, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnext3d_50_32x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a residual network for 3d data with 50 layers and ResNeXt layer design.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.
            `groups` and `width_per_group` are fixed to 32 and 4.

    Returns:
        The 3D ResNeXt.

    Raises:
        ValueError: If `groups` or `width_per_group` are passed with conflicting values.
    """
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(
        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnext3d_101_32x8d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a residual network for 3d data with 101 layers and ResNeXt layer design.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.
            `groups` and `width_per_group` are fixed to 32 and 8.

    Returns:
        The 3D ResNeXt.

    Raises:
        ValueError: If `groups` or `width_per_group` are passed with conflicting values.
    """
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def resnext3d_101_64x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a residual network for 3d data with 101 layers and ResNeXt layer design.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.
            `groups` and `width_per_group` are fixed to 64 and 4.

    Returns:
        The 3D ResNeXt.

    Raises:
        ValueError: If `groups` or `width_per_group` are passed with conflicting values.
    """
    _ovewrite_named_param(kwargs, "groups", 64)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def wide_resnet3d_50_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a wide residual network for 3d data with 50 layers.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.
            `width_per_group` is fixed to 128 (twice the default).

    Returns:
        The wide 3D ResNet.

    Raises:
        ValueError: If `width_per_group` is passed with a conflicting value.
    """
    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(
        Bottleneck, [3, 4, 6, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )


def wide_resnet3d_101_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Get a wide residual network for 3d data with 101 layers.

    The implementation of this network is based on torchvision:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Additional keyword arguments for the ResNet.
            `width_per_group` is fixed to 128 (twice the default).

    Returns:
        The wide 3D ResNet.

    Raises:
        ValueError: If `width_per_group` is passed with a conflicting value.
    """
    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(
        Bottleneck, [3, 4, 23, 3], weights=None, progress=False,
        in_channels=in_channels, out_channels=out_channels, **kwargs
    )
def resnet3d_18(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build an 18-layer residual network for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The 3D ResNet.
    """
    blocks_per_stage = [2, 2, 2, 2]  # two BasicBlocks in each of the four stages
    return _resnet(
        block=BasicBlock,
        layers=blocks_per_stage,
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a residual network for 3d data with 18 layers.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The 3D ResNet.
def resnet3d_34(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 34-layer residual network for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The 3D ResNet.
    """
    blocks_per_stage = [3, 4, 6, 3]  # BasicBlocks per stage
    return _resnet(
        block=BasicBlock,
        layers=blocks_per_stage,
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a residual network for 3d data with 34 layers.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The 3D ResNet.
def resnet3d_50(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 50-layer residual network for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The 3D ResNet.
    """
    blocks_per_stage = [3, 4, 6, 3]  # Bottleneck blocks per stage
    return _resnet(
        block=Bottleneck,
        layers=blocks_per_stage,
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a residual network for 3d data with 50 layers.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The 3D ResNet.
def resnet3d_101(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 101-layer residual network for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The 3D ResNet.
    """
    blocks_per_stage = [3, 4, 23, 3]  # Bottleneck blocks per stage
    return _resnet(
        block=Bottleneck,
        layers=blocks_per_stage,
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a residual network for 3d data with 101 layers.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The 3D ResNet.
def resnet3d_152(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 152-layer residual network for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The 3D ResNet.
    """
    blocks_per_stage = [3, 8, 36, 3]  # Bottleneck blocks per stage
    return _resnet(
        block=Bottleneck,
        layers=blocks_per_stage,
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a residual network for 3d data with 152 layers.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The 3D ResNet.
def resnext3d_50_32x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 50-layer ResNeXt (32 groups, width 4 per group) for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The 3D ResNext.
    """
    # Pin the grouped-convolution settings, in the same order as before.
    for name, value in (("groups", 32), ("width_per_group", 4)):
        _ovewrite_named_param(kwargs, name, value)
    return _resnet(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a residual network for 3d data with 50 layers and ResNeXt layer design.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The 3D ResNeXt.
def resnext3d_101_32x8d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 101-layer ResNeXt (32 groups, width 8 per group) for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The 3D ResNext.
    """
    # Pin the grouped-convolution settings, in the same order as before.
    for name, value in (("groups", 32), ("width_per_group", 8)):
        _ovewrite_named_param(kwargs, name, value)
    return _resnet(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a residual network for 3d data with 101 layers and ResNeXt layer design.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The 3D ResNeXt.
def resnext3d_101_64x4d(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 101-layer ResNeXt (64 groups, width 4 per group) for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The 3D ResNext.
    """
    # Pin the grouped-convolution settings, in the same order as before.
    for name, value in (("groups", 64), ("width_per_group", 4)):
        _ovewrite_named_param(kwargs, name, value)
    return _resnet(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a residual network for 3d data with 101 layers and ResNeXt layer design.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The 3D ResNeXt.
def wide_resnet3d_50_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 50-layer wide residual network for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The wide 3D ResNet.
    """
    widened = 2 * 64  # twice the default per-group width
    _ovewrite_named_param(kwargs, "width_per_group", widened)
    return _resnet(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a wide residual network for 3d data with 50 layers.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The wide 3D ResNet.
def wide_resnet3d_101_2(in_channels: int, out_channels: int, **kwargs: Any) -> ResNet3d:
    """Build a 101-layer wide residual network for 3d data.

    Follows the torchvision ResNet implementation:
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

    Args:
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        kwargs: Further keyword arguments forwarded to the ResNet.

    Returns:
        The wide 3D ResNet.
    """
    widened = 2 * 64  # twice the default per-group width
    _ovewrite_named_param(kwargs, "width_per_group", widened)
    return _resnet(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        weights=None,
        progress=False,
        in_channels=in_channels,
        out_channels=out_channels,
        **kwargs,
    )
Get a wide residual network for 3d data with 101 layers.
The implementation of this network is based on torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
Arguments:
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- kwargs: Additional keyword arguments for the ResNet.
Returns:
The wide 3D ResNet.