torch_em.model.unet
import numpy as np
import torch
import torch.nn as nn


#
# Model Internal Post-processing
#
# Note: these are mainly for bioimage.io models, where postprocessing has to be done
# inside of the model unless it's defined in the general spec


# TODO think about more (multicut-friendly) boundary postprocessing
# e.g. max preserving smoothing: bd = np.maximum(bd, gaussian(bd, sigma=1))
class AccumulateChannels(nn.Module):
    """Reduce a range of channels to a single channel with mean, min or max.

    Args:
        invariant_channels: (start, stop) range of channels that is passed through
            unchanged and placed before the reduced channel, or None to return only
            the reduced channel.
        accumulate_channels: (start, stop) range of channels that is reduced.
        accumulator: name of the reduction, one of "mean", "min" or "max".
    """
    def __init__(
        self,
        invariant_channels,
        accumulate_channels,
        accumulator
    ):
        super().__init__()
        self.invariant_channels = invariant_channels
        self.accumulate_channels = accumulate_channels
        assert accumulator in ("mean", "min", "max")
        self.accumulator = getattr(torch, accumulator)

    def _accumulate(self, x, c0, c1):
        res = self.accumulator(x[:, c0:c1], dim=1, keepdim=True)
        # torch.min / torch.max with a `dim` return (values, indices); torch.mean returns a tensor
        if not torch.is_tensor(res):
            res = res.values
        assert torch.is_tensor(res)
        return res

    def forward(self, x):
        if self.invariant_channels is None:
            c0, c1 = self.accumulate_channels
            return self._accumulate(x, c0, c1)
        else:
            i0, i1 = self.invariant_channels
            c0, c1 = self.accumulate_channels
            return torch.cat([x[:, i0:i1], self._accumulate(x, c0, c1)], dim=1)


def affinities_to_boundaries(aff_channels, accumulator="max"):
    """Reduce the affinity channel range `aff_channels` to a single boundary channel."""
    return AccumulateChannels(None, aff_channels, accumulator)


def affinities_with_foreground_to_boundaries(aff_channels, fg_channel=(0, 1), accumulator="max"):
    """Keep the channel range `fg_channel` and append the reduced affinity channels."""
    return AccumulateChannels(fg_channel, aff_channels, accumulator)


def affinities_to_boundaries2d():
    return affinities_to_boundaries((0, 2))


def affinities_with_foreground_to_boundaries2d():
    return affinities_with_foreground_to_boundaries((1, 3))


def affinities_to_boundaries3d():
    return affinities_to_boundaries((0, 3))


def affinities_with_foreground_to_boundaries3d():
    return affinities_with_foreground_to_boundaries((1, 4))


def affinities_to_boundaries_anisotropic():
    return AccumulateChannels(None, (1, 3), "max")


# maps the names accepted by `UNetBase(postprocessing=...)` to factories
POSTPROCESSING = {
    "affinities_to_boundaries_anisotropic": affinities_to_boundaries_anisotropic,
    "affinities_to_boundaries2d": affinities_to_boundaries2d,
    "affinities_with_foreground_to_boundaries2d": affinities_with_foreground_to_boundaries2d,
    "affinities_to_boundaries3d": affinities_to_boundaries3d,
    "affinities_with_foreground_to_boundaries3d": affinities_with_foreground_to_boundaries3d,
}


#
# Base Implementations
#

class UNetBase(nn.Module):
    """Generic U-Net assembled from an encoder, a base (bottleneck) block and a decoder.

    Args:
        encoder: module with `__len__`, `in_channels` and a `return_outputs` flag; called as
            `encoder(x)` and expected to return `(x, encoder_outputs)`.
        base: module applied to the encoder output.
        decoder: module with `__len__`, `out_channels` and a `return_outputs` flag; called as
            `decoder(x, encoder_inputs=...)`.
        out_conv: None (return the decoder output), a single module applied to the decoder
            output, or an `nn.ModuleList` with one (possibly None) module per decoder level,
            in which case a list of side outputs is returned.
        final_activation: None, an `nn.Module`, or the name of an activation class in `torch.nn`.
        postprocessing: None, an `nn.Module`, or a key of `POSTPROCESSING`.
        check_shape: validate that the input's spatial shape is divisible by the
            total downsampling factor before running the network.

    Raises:
        ValueError: on mismatching encoder / decoder / out_conv depths, or an unknown
            activation or postprocessing name.
    """
    def __init__(
        self,
        encoder,
        base,
        decoder,
        out_conv=None,
        final_activation=None,
        postprocessing=None,
        check_shape=True,
    ):
        super().__init__()
        if len(encoder) != len(decoder):
            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")

        self.encoder = encoder
        self.base = base
        self.decoder = decoder

        if out_conv is None:
            self.return_decoder_outputs = False
            self._out_channels = self.decoder.out_channels
        elif isinstance(out_conv, nn.ModuleList):
            if len(out_conv) != len(self.decoder):
                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
            self.return_decoder_outputs = True
            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
        else:
            self.return_decoder_outputs = False
            self._out_channels = out_conv.out_channels
        self.out_conv = out_conv
        self.check_shape = check_shape
        self.final_activation = self._get_activation(final_activation)
        self.postprocessing = self._get_postprocessing(postprocessing)

    @property
    def in_channels(self):
        return self.encoder.in_channels

    @property
    def out_channels(self):
        return self._out_channels

    @property
    def depth(self):
        return len(self.encoder)

    def _get_activation(self, activation):
        """Resolve None, a module instance, or a `torch.nn` class name to an activation module."""
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")
        return return_activation()

    def _get_postprocessing(self, postprocessing):
        """Resolve None, a module instance, or a `POSTPROCESSING` key to a postprocessing module."""
        if postprocessing is None:
            return None
        elif isinstance(postprocessing, nn.Module):
            return postprocessing
        elif postprocessing in POSTPROCESSING:
            return POSTPROCESSING[postprocessing]()
        else:
            raise ValueError(f"Invalid postprocessing: {postprocessing}")

    # load encoder / decoder / base states for pretraining
    def load_encoder_state(self, state):
        self.encoder.load_state_dict(state)

    def load_decoder_state(self, state):
        self.decoder.load_state_dict(state)

    def load_base_state(self, state):
        self.base.load_state_dict(state)

    def _apply_default(self, x):
        """Run the network and return only the full-resolution output."""
        self.encoder.return_outputs = True
        self.decoder.return_outputs = False

        x, encoder_out = self.encoder(x)
        x = self.base(x)
        x = self.decoder(x, encoder_inputs=encoder_out[::-1])

        if self.out_conv is not None:
            x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        if self.postprocessing is not None:
            x = self.postprocessing(x)

        return x

    def _apply_with_side_outputs(self, x):
        """Run the network and return one output per decoder level, full resolution first."""
        self.encoder.return_outputs = True
        self.decoder.return_outputs = True

        x, encoder_out = self.encoder(x)
        x = self.base(x)
        x = self.decoder(x, encoder_inputs=encoder_out[::-1])

        # zip pairs each decoder output with its out_conv and stops at the shorter sequence.
        # A level without an out_conv passes its own decoder output through unchanged.
        x = [xx if conv is None else conv(xx) for xx, conv in zip(x, self.out_conv)]
        if self.final_activation is not None:
            x = [self.final_activation(xx) for xx in x]

        if self.postprocessing is not None:
            x = [self.postprocessing(xx) for xx in x]

        # we reverse the list to have the full shape output as first element
        return x[::-1]

    def _check_shape(self, x):
        """Raise ValueError if the spatial shape of `x` is not divisible by 2**depth."""
        spatial_shape = tuple(x.shape)[2:]
        depth = len(self.encoder)
        factor = [2**depth] * len(spatial_shape)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)

    def forward(self, x):
        # cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference
        # x = x.float()
        # getattr with a default so instances lacking the attribute (presumably older
        # serialized models - TODO confirm) still perform the check
        if getattr(self, "check_shape", True):
            self._check_shape(x)
        if self.return_decoder_outputs:
            return self._apply_with_side_outputs(x)
        else:
            return self._apply_default(x)


def _update_conv_kwargs(kwargs, scale_factor):
    """Return conv kwargs with a kernel adapted to an anisotropic `scale_factor`.

    For a scalar or isotropic scale factor `kwargs` is returned unchanged. Otherwise the kernel
    size / padding become 1 / 0 along axes with scale factor 1. A new dict is returned, so the
    input (which callers may share between levels) is never modified.
    """
    # if the scale factor is a scalar or all entries are the same we don't need to update the kwargs
    if isinstance(scale_factor, int) or scale_factor.count(scale_factor[0]) == len(scale_factor):
        return kwargs

    kernel_size = kwargs.get("kernel_size", 3)
    padding = kwargs.get("padding", 1)

    # bail out if kernel size or padding aren't scalars, because it's
    # unclear what to do in this case
    if not (isinstance(kernel_size, int) and isinstance(padding, int)):
        return kwargs

    kernel_size = tuple(1 if factor == 1 else kernel_size for factor in scale_factor)
    padding = tuple(0 if factor == 1 else padding for factor in scale_factor)
    return {**kwargs, "kernel_size": kernel_size, "padding": padding}


class Encoder(nn.Module):
    """U-Net encoder: a conv block followed by a pooling layer per level.

    Args:
        features: channel counts from the input to the deepest level, length depth + 1.
        scale_factors: pooling factor per level (scalar or per-axis), length depth.
        conv_block_impl: callable `(in_channels, out_channels, **kwargs) -> nn.Module`.
        pooler_impl: callable `(scale_factor) -> nn.Module`.
        anisotropic_kernel: adapt kernel size / padding to anisotropic scale factors.
        **conv_block_kwargs: forwarded to every `conv_block_impl` call.

    Raises:
        ValueError: if `len(features) != len(scale_factors) + 1`.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        pooler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")

        # one independent kwargs dict per level, so per-level updates cannot leak between levels
        conv_kwargs = [dict(conv_block_kwargs) for _ in scale_factors]
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.poolers = nn.ModuleList(
            [pooler_impl(factor) for factor in scale_factors]
        )
        self.return_outputs = True

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        return len(self.blocks)

    def forward(self, x):
        """Return the pooled deepest features, plus the per-level pre-pooling features if `return_outputs`."""
        encoder_out = []
        for block, pooler in zip(self.blocks, self.poolers):
            x = block(x)
            encoder_out.append(x)
            x = pooler(x)

        if self.return_outputs:
            return x, encoder_out
        else:
            return x


class Decoder(nn.Module):
    """U-Net decoder: upsample, concatenate the matching encoder features, convolve.

    Args:
        features: channel counts from the decoder input to its output, length depth + 1.
        scale_factors: upsampling factor per level (scalar or per-axis), length depth.
        conv_block_impl: callable `(in_channels, out_channels, **kwargs) -> nn.Module`.
        sampler_impl: callable `(scale_factor, in_channels, out_channels) -> nn.Module`.
        anisotropic_kernel: adapt kernel size / padding to anisotropic scale factors.
        **conv_block_kwargs: forwarded to every `conv_block_impl` call.

    Raises:
        ValueError: if `len(features) != len(scale_factors) + 1`.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        sampler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")

        # one independent kwargs dict per level, so per-level updates cannot leak between levels
        conv_kwargs = [dict(conv_block_kwargs) for _ in scale_factors]
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.samplers = nn.ModuleList(
            [sampler_impl(factor, inc, outc) for factor, inc, outc
             in zip(scale_factors, features[:-1], features[1:])]
        )
        self.return_outputs = False

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        return len(self.blocks)

    # FIXME this prevents traces from being valid for other input sizes, need to find
    # a solution to traceable cropping
    def _crop(self, x, shape):
        """Center-crop every dimension of `x` to the corresponding entry of `shape`."""
        shape_diff = [(xsh - sh) // 2 for xsh, sh in zip(x.shape, shape)]
        crop = tuple([slice(sd, xsh - sd) for sd, xsh in zip(shape_diff, x.shape)])
        return x[crop]
        # # Implementation with torch.narrow, does not fix the tracing warnings!
        # for dim, (sh, sd) in enumerate(zip(shape, shape_diff)):
        #     x = torch.narrow(x, dim, sd, sh)
        # return x

    def _concat(self, x1, x2):
        return torch.cat([x1, self._crop(x2, x1.shape)], dim=1)

    def forward(self, x, encoder_inputs):
        """Decode `x` using the skip connections `encoder_inputs` (deepest first).

        Returns the final output, or, if `return_outputs` is set, the list of per-level
        outputs followed by the final output (which is therefore also the last level's entry).
        """
        if len(encoder_inputs) != len(self.blocks):
            raise ValueError(f"Invalid number of encoder_inputs: expect {len(self.blocks)}, got {len(encoder_inputs)}")

        decoder_out = []
        for block, sampler, from_encoder in zip(self.blocks, self.samplers, encoder_inputs):
            x = sampler(x)
            x = block(self._concat(x, from_encoder))
            decoder_out.append(x)

        if self.return_outputs:
            return decoder_out + [x]
        else:
            return x


def get_norm_layer(norm, dim, channels, n_groups=32):
    """Create the normalization layer named `norm` for `channels` channels, or None if `norm` is None.

    `dim` selects the 2d or 3d variant; `n_groups` (capped at `channels`) is only used for GroupNorm.
    """
    if norm is None:
        return None
    if norm == "InstanceNorm":
        kwargs = {"affine": True, "track_running_stats": True, "momentum": 0.01}
        return nn.InstanceNorm2d(channels, **kwargs) if dim == 2 else nn.InstanceNorm3d(channels, **kwargs)
    elif norm == "OldDefault":
        return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels)
    elif norm == "GroupNorm":
        return nn.GroupNorm(min(n_groups, channels), channels)
    elif norm == "BatchNorm":
        return nn.BatchNorm2d(channels) if dim == 2 else nn.BatchNorm3d(channels)
    else:
        raise ValueError(
            f"Invalid norm: expect one of 'InstanceNorm', 'OldDefault', 'BatchNorm' or 'GroupNorm', got {norm}"
        )


class ConvBlock(nn.Module):
    """Two convolutions with ReLU, each preceded by a normalization layer unless `norm` is None."""
    def __init__(self, in_channels, out_channels, dim,
                 kernel_size=3, padding=1, norm="InstanceNorm"):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv = nn.Conv2d if dim == 2 else nn.Conv3d

        if norm is None:
            self.block = nn.Sequential(
                conv(in_channels, out_channels,
                     kernel_size=kernel_size, padding=padding),
                nn.ReLU(inplace=True),
                conv(out_channels, out_channels,
                     kernel_size=kernel_size, padding=padding),
                nn.ReLU(inplace=True)
            )
        else:
            self.block = nn.Sequential(
                get_norm_layer(norm, dim, in_channels),
                conv(in_channels, out_channels,
                     kernel_size=kernel_size, padding=padding),
                nn.ReLU(inplace=True),
                get_norm_layer(norm, dim, out_channels),
                conv(out_channels, out_channels,
                     kernel_size=kernel_size, padding=padding),
                nn.ReLU(inplace=True)
            )

    def forward(self, x):
        return self.block(x)


# interpolation modes for which `align_corners` may be passed to `interpolate`
_ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")


class Upsampler(nn.Module):
    """Interpolate by `scale_factor` with `mode`, then map channels with a 1x1 convolution."""
    def __init__(self, scale_factor,
                 in_channels, out_channels,
                 dim, mode):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

        conv = nn.Conv2d if dim == 2 else nn.Conv3d
        self.conv = conv(in_channels, out_channels, 1)

    def forward(self, x):
        # torch rejects an explicit align_corners for non-interpolating modes such as "nearest"
        align_corners = False if self.mode in _ALIGN_CORNERS_MODES else None
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor,
                                      mode=self.mode, align_corners=align_corners)
        x = self.conv(x)
        return x


#
# 2d unet implementations
#

class ConvBlock2d(ConvBlock):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels, dim=2, **kwargs)


class Upsampler2d(Upsampler):
    def __init__(self, scale_factor,
                 in_channels, out_channels,
                 mode="bilinear"):
        super().__init__(scale_factor, in_channels, out_channels,
                         dim=2, mode=mode)


class UNet2d(UNetBase):
    """2d U-Net with `depth` levels, isotropic pooling by 2 and `initial_features * gain**level` channels.

    With `return_side_outputs`, `out_channels` may be a list with one entry per level (or a scalar /
    None that is repeated) and the model returns one output per decoder level.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        depth=4,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock2d,
        pooler_impl=nn.MaxPool2d,
        sampler_impl=Upsampler2d,
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]

        if return_side_outputs:
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(
                    f"Expected {depth} output channels for the side outputs, got {len(out_channels)}"
                )
            out_conv = nn.ModuleList(
                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=pooler_impl,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=sampler_impl,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain,
                **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
                            "sampler_impl": sampler_impl, "postprocessing": postprocessing, **conv_block_kwargs}


#
# 3d unet implementations
#

class ConvBlock3d(ConvBlock):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels, dim=3, **kwargs)
class Upsampler3d(Upsampler):
    """3d upsampler: interpolation (trilinear by default) followed by a 1x1x1 convolution."""
    def __init__(self, scale_factor,
                 in_channels, out_channels,
                 mode="trilinear"):
        super().__init__(scale_factor, in_channels, out_channels,
                         dim=3, mode=mode)


class AnisotropicUNet(UNetBase):
    """3d U-Net whose pooling / upsampling factor can differ per level and per axis.

    Args:
        scale_factors: one scale factor per level; each is either a scalar (isotropic)
            or a sequence with one factor per spatial axis. The depth is `len(scale_factors)`.
        anisotropic_kernel: adapt conv kernel size / padding to anisotropic scale factors.
        Other arguments as for `UNet2d`.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factors,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock3d,
        anisotropic_kernel=False,  # TODO benchmark which option is better and set as default
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        depth = len(scale_factors)
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]

        if return_side_outputs:
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(
                    f"Expected {depth} output channels for the side outputs, got {len(out_channels)}"
                )
            out_conv = nn.ModuleList(
                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=nn.MaxPool3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=Upsampler3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain,
                **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
                            "postprocessing": postprocessing, **conv_block_kwargs}

    def _check_shape(self, x):
        """Raise ValueError unless `x` is 3d and divisible by the total per-axis downsampling."""
        spatial_shape = tuple(x.shape)[2:]
        scale_factors = self.init_kwargs.get("scale_factors", [[2, 2, 2]]*len(self.encoder))
        # A level's scale factor may be a scalar (isotropic pooling) or a per-axis sequence;
        # expand scalars so that the per-axis product below works for both.
        per_axis = [[sf] * 3 if np.ndim(sf) == 0 else sf for sf in scale_factors]
        factor = [int(np.prod([sf[i] for sf in per_axis])) for i in range(3)]
        if len(spatial_shape) != len(factor):
            msg = f"Invalid shape for U-Net: dimensions don't agree {len(spatial_shape)} != {len(factor)}"
            raise ValueError(msg)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)


class UNet3d(AnisotropicUNet):
    """3d U-Net with `depth` levels and isotropic pooling by 2 at every level."""
    def __init__(
        self,
        in_channels,
        out_channels,
        depth=4,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock3d,
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        scale_factors = depth * [2]
        super().__init__(in_channels, out_channels, scale_factors,
                         initial_features=initial_features, gain=gain,
                         final_activation=final_activation,
                         return_side_outputs=return_side_outputs,
                         anisotropic_kernel=False,
                         postprocessing=postprocessing,
                         conv_block_impl=conv_block_impl,
                         check_shape=check_shape,
                         **conv_block_kwargs)
        # overrides the parent's init_kwargs so that they match this constructor's signature
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "postprocessing": postprocessing, **conv_block_kwargs}
class AccumulateChannels(nn.Module):
    """Collapse a channel range into one channel via "mean", "min" or "max".

    If `invariant_channels` is a (start, stop) pair, those channels are copied through
    and the reduced channel is concatenated after them. If it is None, only the reduced
    channel is returned.
    """
    def __init__(
        self,
        invariant_channels,
        accumulate_channels,
        accumulator
    ):
        super().__init__()
        self.invariant_channels = invariant_channels
        self.accumulate_channels = accumulate_channels
        assert accumulator in ("mean", "min", "max")
        reduce_fn = getattr(torch, accumulator)
        self.accumulator = reduce_fn

    def _accumulate(self, x, c0, c1):
        reduced = self.accumulator(x[:, c0:c1], dim=1, keepdim=True)
        # min/max along a dim yield a (values, indices) named tuple, mean yields a tensor
        reduced = reduced if torch.is_tensor(reduced) else reduced.values
        assert torch.is_tensor(reduced)
        return reduced

    def forward(self, x):
        if self.invariant_channels is not None:
            keep_start, keep_stop = self.invariant_channels
            start, stop = self.accumulate_channels
            return torch.cat([x[:, keep_start:keep_stop], self._accumulate(x, start, stop)], dim=1)
        start, stop = self.accumulate_channels
        return self._accumulate(x, start, stop)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and their parameters will also be converted when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
training (bool): whether this module is in training or evaluation mode.
def __init__(
    self,
    invariant_channels,
    accumulate_channels,
    accumulator
):
    """Configure which channels to keep and which to reduce.

    Args:
        invariant_channels: (start, stop) channel range copied through unchanged, or None.
        accumulate_channels: (start, stop) channel range reduced to one channel.
        accumulator: one of "mean", "min" or "max"; the torch function of that name is used.
    """
    super().__init__()
    self.invariant_channels = invariant_channels
    self.accumulate_channels = accumulate_channels
    assert accumulator in ("mean", "min", "max")
    # e.g. "max" -> torch.max
    reduce_fn = getattr(torch, accumulator)
    self.accumulator = reduce_fn
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Return the reduced channel, preceded by the invariant channels when they are set."""
    if self.invariant_channels is not None:
        keep_start, keep_stop = self.invariant_channels
        start, stop = self.accumulate_channels
        return torch.cat([x[:, keep_start:keep_stop], self._accumulate(x, start, stop)], dim=1)
    start, stop = self.accumulate_channels
    return self._accumulate(x, start, stop)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class UNetBase(nn.Module):
    """Generic U-Net assembled from an encoder, a base (bottleneck) block and a decoder.

    Args:
        encoder: module with `__len__`, `in_channels` and a `return_outputs` flag; called as
            `encoder(x)` and expected to return `(x, encoder_outputs)`.
        base: module applied to the encoder output.
        decoder: module with `__len__`, `out_channels` and a `return_outputs` flag; called as
            `decoder(x, encoder_inputs=...)`.
        out_conv: None (return the decoder output), a single module applied to the decoder
            output, or an `nn.ModuleList` with one (possibly None) module per decoder level,
            in which case a list of side outputs is returned.
        final_activation: None, an `nn.Module`, or the name of an activation class in `torch.nn`.
        postprocessing: None, an `nn.Module`, or a key of `POSTPROCESSING`.
        check_shape: validate that the input's spatial shape is divisible by 2**depth.

    Raises:
        ValueError: on mismatching encoder / decoder / out_conv depths, or an unknown
            activation or postprocessing name.
    """
    def __init__(
        self,
        encoder,
        base,
        decoder,
        out_conv=None,
        final_activation=None,
        postprocessing=None,
        check_shape=True,
    ):
        super().__init__()
        if len(encoder) != len(decoder):
            raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")

        self.encoder = encoder
        self.base = base
        self.decoder = decoder

        if out_conv is None:
            self.return_decoder_outputs = False
            self._out_channels = self.decoder.out_channels
        elif isinstance(out_conv, nn.ModuleList):
            if len(out_conv) != len(self.decoder):
                raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
            self.return_decoder_outputs = True
            self._out_channels = [None if conv is None else conv.out_channels for conv in out_conv]
        else:
            self.return_decoder_outputs = False
            self._out_channels = out_conv.out_channels
        self.out_conv = out_conv
        self.check_shape = check_shape
        self.final_activation = self._get_activation(final_activation)
        self.postprocessing = self._get_postprocessing(postprocessing)

    @property
    def in_channels(self):
        return self.encoder.in_channels

    @property
    def out_channels(self):
        return self._out_channels

    @property
    def depth(self):
        return len(self.encoder)

    def _get_activation(self, activation):
        """Resolve None, a module instance, or a `torch.nn` class name to an activation module."""
        return_activation = None
        if activation is None:
            return None
        if isinstance(activation, nn.Module):
            return activation
        if isinstance(activation, str):
            return_activation = getattr(nn, activation, None)
        if return_activation is None:
            raise ValueError(f"Invalid activation: {activation}")
        return return_activation()

    def _get_postprocessing(self, postprocessing):
        """Resolve None, a module instance, or a `POSTPROCESSING` key to a postprocessing module."""
        if postprocessing is None:
            return None
        elif isinstance(postprocessing, nn.Module):
            return postprocessing
        elif postprocessing in POSTPROCESSING:
            return POSTPROCESSING[postprocessing]()
        else:
            raise ValueError(f"Invalid postprocessing: {postprocessing}")

    # load encoder / decoder / base states for pretraining
    def load_encoder_state(self, state):
        self.encoder.load_state_dict(state)

    def load_decoder_state(self, state):
        self.decoder.load_state_dict(state)

    def load_base_state(self, state):
        self.base.load_state_dict(state)

    def _apply_default(self, x):
        """Run the network and return only the full-resolution output."""
        self.encoder.return_outputs = True
        self.decoder.return_outputs = False

        x, encoder_out = self.encoder(x)
        x = self.base(x)
        x = self.decoder(x, encoder_inputs=encoder_out[::-1])

        if self.out_conv is not None:
            x = self.out_conv(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        if self.postprocessing is not None:
            x = self.postprocessing(x)

        return x

    def _apply_with_side_outputs(self, x):
        """Run the network and return one output per decoder level, full resolution first."""
        self.encoder.return_outputs = True
        self.decoder.return_outputs = True

        x, encoder_out = self.encoder(x)
        x = self.base(x)
        x = self.decoder(x, encoder_inputs=encoder_out[::-1])

        # zip pairs each decoder output with its out_conv and stops at the shorter sequence.
        # A level without an out_conv passes its own decoder output through unchanged.
        x = [xx if conv is None else conv(xx) for xx, conv in zip(x, self.out_conv)]
        if self.final_activation is not None:
            x = [self.final_activation(xx) for xx in x]

        if self.postprocessing is not None:
            x = [self.postprocessing(xx) for xx in x]

        # we reverse the list to have the full shape output as first element
        return x[::-1]

    def _check_shape(self, x):
        """Raise ValueError if the spatial shape of `x` is not divisible by 2**depth."""
        spatial_shape = tuple(x.shape)[2:]
        depth = len(self.encoder)
        factor = [2**depth] * len(spatial_shape)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)

    def forward(self, x):
        # cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference
        # x = x.float()
        # getattr with a default so instances lacking the attribute (presumably older
        # serialized models - TODO confirm) still perform the check
        if getattr(self, "check_shape", True):
            self._check_shape(x)
        if self.return_decoder_outputs:
            return self._apply_with_side_outputs(x)
        else:
            return self._apply_default(x)
def __init__(
    self,
    encoder,
    base,
    decoder,
    out_conv=None,
    final_activation=None,
    postprocessing=None,
    check_shape=True,
):
    """Assemble the U-Net from its parts and determine how outputs are produced.

    Raises:
        ValueError: if encoder and decoder depths differ, if `out_conv` is an
            `nn.ModuleList` whose length differs from the decoder depth, or if the
            activation / postprocessing cannot be resolved.
    """
    super().__init__()
    if len(encoder) != len(decoder):
        raise ValueError(f"Incompatible depth of encoder (depth={len(encoder)}) and decoder (depth={len(decoder)})")

    self.encoder = encoder
    self.base = base
    self.decoder = decoder

    # A ModuleList of output convs (one per decoder level) switches on side outputs.
    per_level_heads = isinstance(out_conv, nn.ModuleList)
    if per_level_heads and len(out_conv) != len(self.decoder):
        raise ValueError(f"Invalid length of out_conv, expected {len(decoder)}, got {len(out_conv)}")
    self.return_decoder_outputs = per_level_heads

    if out_conv is None:
        channels = self.decoder.out_channels
    elif per_level_heads:
        channels = [None if head is None else head.out_channels for head in out_conv]
    else:
        channels = out_conv.out_channels
    self._out_channels = channels

    self.out_conv = out_conv
    self.check_shape = check_shape
    self.final_activation = self._get_activation(final_activation)
    self.postprocessing = self._get_postprocessing(postprocessing)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Optionally validate the input shape, then run either the single-output or the side-output path."""
    # cast input data to float, hotfix for modelzoo deployment issues, leaving it here for reference
    # x = x.float()
    # getattr with a default so instances lacking the attribute (presumably older
    # serialized models - TODO confirm) still perform the check
    should_check = getattr(self, "check_shape", True)
    if should_check:
        self._check_shape(x)
    run = self._apply_with_side_outputs if self.return_decoder_outputs else self._apply_default
    return run(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class Encoder(nn.Module):
    """U-Net encoder: a conv block followed by a pooling layer per level.

    Args:
        features: channel counts from the input to the deepest level, length depth + 1.
        scale_factors: pooling factor per level (scalar or per-axis), length depth.
        conv_block_impl: callable `(in_channels, out_channels, **kwargs) -> nn.Module`.
        pooler_impl: callable `(scale_factor) -> nn.Module`.
        anisotropic_kernel: adapt kernel size / padding to anisotropic scale factors
            via `_update_conv_kwargs`.
        **conv_block_kwargs: forwarded to every `conv_block_impl` call.

    Raises:
        ValueError: if `len(features) != len(scale_factors) + 1`.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        pooler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            raise ValueError(f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}")

        # one independent kwargs dict per level, so a per-level update (which may modify the
        # dict in place) cannot leak into the other levels
        conv_kwargs = [dict(conv_block_kwargs) for _ in scale_factors]
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.poolers = nn.ModuleList(
            [pooler_impl(factor) for factor in scale_factors]
        )
        self.return_outputs = True

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        return len(self.blocks)

    def forward(self, x):
        """Return the pooled deepest features, plus the per-level pre-pooling features if `return_outputs`."""
        encoder_out = []
        for block, pooler in zip(self.blocks, self.poolers):
            x = block(x)
            encoder_out.append(x)
            x = pooler(x)

        if self.return_outputs:
            return x, encoder_out
        else:
            return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call to()
, etc.
As per the example above, an __init__()
call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(
    self,
    features,
    scale_factors,
    conv_block_impl,
    pooler_impl,
    anisotropic_kernel=False,
    **conv_block_kwargs
):
    """Build one conv block and one pooling layer per encoder level.

    Raises:
        ValueError: If ``len(features) != len(scale_factors) + 1``.
    """
    super().__init__()
    if len(features) != len(scale_factors) + 1:
        # f-string, so the message reports the actual lengths.
        raise ValueError(
            f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}"
        )

    conv_kwargs = [conv_block_kwargs] * len(scale_factors)
    if anisotropic_kernel:
        # Per-level kwargs derived from the scale factor by _update_conv_kwargs (defined elsewhere).
        conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                       for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

    self.blocks = nn.ModuleList(
        [conv_block_impl(inc, outc, **kwargs)
         for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
    )
    self.poolers = nn.ModuleList(
        [pooler_impl(factor) for factor in scale_factors]
    )
    # When True, forward also returns the per-level block outputs.
    self.return_outputs = True

    self.in_channels = features[0]
    self.out_channels = features[-1]
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Apply every encoder level (block, then pooler) to ``x``.

    Returns ``(x, skips)`` when ``self.return_outputs`` is set, otherwise only
    ``x``; ``skips`` collects each block output before pooling, first level first.
    """
    skips = []
    for conv_block, pool in zip(self.blocks, self.poolers):
        x = conv_block(x)
        skips.append(x)
        x = pool(x)
    return (x, skips) if self.return_outputs else x
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this, since the former takes care of running the
registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class Decoder(nn.Module):
    """Upsampling path of a U-Net.

    Every level upsamples its input with a sampler, concatenates the result with
    the center-cropped skip tensor along the channel axis and applies a
    convolutional block.

    Args:
        features: Channel counts; ``features[0]`` is the number of input channels,
            ``features[i + 1]`` the output channels of level ``i``. Must satisfy
            ``len(features) == len(scale_factors) + 1``.
        scale_factors: One upsampling factor per level, passed to ``sampler_impl``.
        conv_block_impl: Callable ``(in_channels, out_channels, **kwargs)`` that
            returns the convolutional block of one level.
        sampler_impl: Callable ``(scale_factor, in_channels, out_channels)`` that
            returns the upsampling layer of one level.
        anisotropic_kernel: If True, the per-level block kwargs are derived from
            the scale factor with ``_update_conv_kwargs`` (defined elsewhere).
        **conv_block_kwargs: Extra keyword arguments forwarded to every block.

    Raises:
        ValueError: If the lengths of ``features`` and ``scale_factors`` do not match.
    """
    def __init__(
        self,
        features,
        scale_factors,
        conv_block_impl,
        sampler_impl,
        anisotropic_kernel=False,
        **conv_block_kwargs
    ):
        super().__init__()
        if len(features) != len(scale_factors) + 1:
            # f-string, so the message reports the actual lengths.
            raise ValueError(
                f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}"
            )

        conv_kwargs = [conv_block_kwargs] * len(scale_factors)
        if anisotropic_kernel:
            conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                           for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

        self.blocks = nn.ModuleList(
            [conv_block_impl(inc, outc, **kwargs)
             for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
        )
        self.samplers = nn.ModuleList(
            [sampler_impl(factor, inc, outc) for factor, inc, outc
             in zip(scale_factors, features[:-1], features[1:])]
        )
        # When True, forward returns all level outputs instead of only the last one.
        self.return_outputs = False

        self.in_channels = features[0]
        self.out_channels = features[-1]

    def __len__(self):
        """Return the number of levels."""
        return len(self.blocks)

    # FIXME this prevents traces from being valid for other input sizes, need to find
    # a solution to traceable cropping.
    # Note: an implementation with torch.narrow(x, dim, offset, size) does not fix the
    # tracing warnings either.
    def _crop(self, x, shape):
        """Center-crop ``x`` to ``shape`` (compared dimension by dimension)."""
        offsets = [(xsh - sh) // 2 for xsh, sh in zip(x.shape, shape)]
        # Take exactly ``sh`` elements starting at the offset. The end ``size - offset``
        # would keep one element too many whenever the size difference is odd.
        crop = tuple(slice(off, off + sh) for off, sh in zip(offsets, shape))
        return x[crop]

    def _concat(self, x1, x2):
        """Concatenate ``x1`` with ``x2`` cropped to ``x1.shape`` along dim 1."""
        return torch.cat([x1, self._crop(x2, x1.shape)], dim=1)

    def forward(self, x, encoder_inputs):
        """Run all decoder levels.

        Args:
            x: Input of the first level.
            encoder_inputs: One skip tensor per level, in the order the levels run.

        Returns:
            The output of the last level or, if ``return_outputs`` is set, the list
            of all level outputs followed by the last output once more.

        Raises:
            ValueError: If the number of ``encoder_inputs`` does not match the depth.
        """
        if len(encoder_inputs) != len(self.blocks):
            raise ValueError(f"Invalid number of encoder_inputs: expect {len(self.blocks)}, got {len(encoder_inputs)}")

        decoder_out = []
        for block, sampler, from_encoder in zip(self.blocks, self.samplers, encoder_inputs):
            x = sampler(x)
            x = block(self._concat(x, from_encoder))
            decoder_out.append(x)

        if self.return_outputs:
            return decoder_out + [x]
        else:
            return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class
must be made before assignment on the child.
Attributes: training (bool): whether this module is in training or evaluation mode.
def __init__(
    self,
    features,
    scale_factors,
    conv_block_impl,
    sampler_impl,
    anisotropic_kernel=False,
    **conv_block_kwargs
):
    """Build one sampler and one conv block per decoder level.

    Raises:
        ValueError: If ``len(features) != len(scale_factors) + 1``.
    """
    super().__init__()
    if len(features) != len(scale_factors) + 1:
        # f-string, so the message reports the actual lengths.
        raise ValueError(
            f"Incompatible number of features {len(features)} and scale_factors {len(scale_factors)}"
        )

    conv_kwargs = [conv_block_kwargs] * len(scale_factors)
    if anisotropic_kernel:
        # Per-level kwargs derived from the scale factor by _update_conv_kwargs (defined elsewhere).
        conv_kwargs = [_update_conv_kwargs(kwargs, scale_factor)
                       for kwargs, scale_factor in zip(conv_kwargs, scale_factors)]

    self.blocks = nn.ModuleList(
        [conv_block_impl(inc, outc, **kwargs)
         for inc, outc, kwargs in zip(features[:-1], features[1:], conv_kwargs)]
    )
    self.samplers = nn.ModuleList(
        [sampler_impl(factor, inc, outc) for factor, inc, outc
         in zip(scale_factors, features[:-1], features[1:])]
    )
    # When True, forward returns all level outputs instead of only the last one.
    self.return_outputs = False

    self.in_channels = features[0]
    self.out_channels = features[-1]
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x, encoder_inputs):
    """Run every decoder level: upsample, concatenate the skip tensor, convolve.

    Returns the last level's output, or (when ``self.return_outputs`` is set)
    all level outputs with the last one appended a second time.

    Raises:
        ValueError: If ``encoder_inputs`` does not have one entry per block.
    """
    if len(encoder_inputs) != len(self.blocks):
        raise ValueError(f"Invalid number of encoder_inputs: expect {len(self.blocks)}, got {len(encoder_inputs)}")

    level_outputs = []
    for conv_block, upsample, skip in zip(self.blocks, self.samplers, encoder_inputs):
        x = conv_block(self._concat(upsample(x), skip))
        level_outputs.append(x)

    if not self.return_outputs:
        return x
    return level_outputs + [x]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this, since the former takes care of running the
registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
def get_norm_layer(norm, dim, channels, n_groups=32):
    """Create a normalization layer by name.

    Args:
        norm: One of "InstanceNorm", "OldDefault", "GroupNorm", "BatchNorm", or None.
        dim: 2 selects the 2D variant; any other value selects the 3D variant
            (ignored for GroupNorm).
        channels: Number of channels to normalize.
        n_groups: Upper bound for the number of groups of GroupNorm.

    Returns:
        The normalization module, or None if ``norm`` is None.

    Raises:
        ValueError: If ``norm`` is not a supported name.
    """
    if norm is None:
        return None
    if norm == "InstanceNorm":
        # Affine instance norm that also tracks running statistics.
        kwargs = {"affine": True, "track_running_stats": True, "momentum": 0.01}
        return nn.InstanceNorm2d(channels, **kwargs) if dim == 2 else nn.InstanceNorm3d(channels, **kwargs)
    elif norm == "OldDefault":
        # Instance norm with PyTorch's default arguments.
        return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels)
    elif norm == "GroupNorm":
        # Never request more groups than there are channels.
        return nn.GroupNorm(min(n_groups, channels), channels)
    elif norm == "BatchNorm":
        return nn.BatchNorm2d(channels) if dim == 2 else nn.BatchNorm3d(channels)
    else:
        raise ValueError(
            f"Invalid norm: expect one of 'InstanceNorm', 'OldDefault', 'BatchNorm' or 'GroupNorm', got {norm}"
        )
class ConvBlock(nn.Module):
    """Two convolutions, each followed by a ReLU and optionally preceded by a norm layer.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of both convolutions.
        dim: 2 selects ``nn.Conv2d``; any other value selects ``nn.Conv3d``.
        kernel_size: Kernel size of both convolutions.
        padding: Padding of both convolutions.
        norm: Normalization name passed to ``get_norm_layer``, or None to skip normalization.
    """
    def __init__(self, in_channels, out_channels, dim,
                 kernel_size=3, padding=1, norm="InstanceNorm"):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv_cls = nn.Conv2d if dim == 2 else nn.Conv3d

        # [norm ->] conv -> ReLU, first on the input channels, then on the output channels.
        # Layers are created in the same order as before, so Sequential indices
        # (and thus state_dict keys) stay identical.
        layers = []
        for n_in in (in_channels, out_channels):
            if norm is not None:
                layers.append(get_norm_layer(norm, dim, n_in))
            layers.append(conv_cls(n_in, out_channels, kernel_size=kernel_size, padding=padding))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class
must be made before assignment on the child.
Attributes: training (bool): whether this module is in training or evaluation mode.
def __init__(self, in_channels, out_channels, dim,
             kernel_size=3, padding=1, norm="InstanceNorm"):
    """Build ``[norm ->] conv -> ReLU`` twice (input channels, then output channels)."""
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    conv_cls = nn.Conv2d if dim == 2 else nn.Conv3d

    # Same creation order as the explicit Sequential, so indices/state_dict keys match.
    layers = []
    for n_in in (in_channels, out_channels):
        if norm is not None:
            layers.append(get_norm_layer(norm, dim, n_in))
        layers.append(conv_cls(n_in, out_channels, kernel_size=kernel_size, padding=padding))
        layers.append(nn.ReLU(inplace=True))
    self.block = nn.Sequential(*layers)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this, since the former takes care of running the
registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class Upsampler(nn.Module):
    """Interpolation-based upsampling followed by a 1x1 convolution.

    Args:
        scale_factor: Scale factor passed to ``nn.functional.interpolate``.
        in_channels: Number of input channels.
        out_channels: Number of output channels of the 1x1 convolution.
        dim: 2 selects ``nn.Conv2d``; any other value selects ``nn.Conv3d``.
        mode: Interpolation mode passed to ``nn.functional.interpolate``.
    """
    def __init__(self, scale_factor,
                 in_channels, out_channels,
                 dim, mode):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

        conv = nn.Conv2d if dim == 2 else nn.Conv3d
        self.conv = conv(in_channels, out_channels, 1)

    def forward(self, x):
        """Interpolate ``x`` by ``scale_factor``, then apply the 1x1 convolution."""
        # ``interpolate`` only accepts ``align_corners`` for the interpolating modes;
        # passing it for e.g. mode="nearest" raises a ValueError.
        align_corners = False if self.mode in ("linear", "bilinear", "bicubic", "trilinear") else None
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor,
                                      mode=self.mode, align_corners=align_corners)
        x = self.conv(x)
        return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class
must be made before assignment on the child.
Attributes: training (bool): whether this module is in training or evaluation mode.
def __init__(self, scale_factor, in_channels, out_channels, dim, mode):
    """Store the interpolation settings and create the 1x1 output convolution."""
    super().__init__()
    self.mode = mode
    self.scale_factor = scale_factor
    # 1x1 convolution mapping in_channels to out_channels after interpolation.
    conv_cls = nn.Conv2d if dim == 2 else nn.Conv3d
    self.conv = conv_cls(in_channels, out_channels, 1)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Interpolate ``x`` by ``self.scale_factor``, then apply ``self.conv``."""
    # ``interpolate`` only accepts ``align_corners`` for the interpolating modes;
    # passing it for e.g. mode="nearest" raises a ValueError.
    align_corners = False if self.mode in ("linear", "bilinear", "bicubic", "trilinear") else None
    x = nn.functional.interpolate(x, scale_factor=self.scale_factor,
                                  mode=self.mode, align_corners=align_corners)
    x = self.conv(x)
    return x
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this, since the former takes care of running the
registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class ConvBlock2d(ConvBlock):
    """``ConvBlock`` with ``dim`` fixed to 2 (uses ``nn.Conv2d``)."""
    def __init__(self, in_channels, out_channels, **kwargs):
        # ``dim`` is the third positional parameter of ConvBlock.
        super().__init__(in_channels, out_channels, 2, **kwargs)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class
must be made before assignment on the child.
Attributes: training (bool): whether this module is in training or evaluation mode.
def __init__(self, in_channels, out_channels, **kwargs):
    """Forward to ``ConvBlock`` with ``dim`` fixed to 2."""
    # ``dim`` is the third positional parameter of ConvBlock.
    super().__init__(in_channels, out_channels, 2, **kwargs)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class Upsampler2d(Upsampler):
    """``Upsampler`` with ``dim`` fixed to 2; interpolates bilinearly by default."""
    def __init__(self, scale_factor, in_channels, out_channels, mode="bilinear"):
        # Upsampler's positional order: (scale_factor, in_channels, out_channels, dim, mode).
        super().__init__(scale_factor, in_channels, out_channels, 2, mode)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class
must be made before assignment on the child.
Attributes: training (bool): whether this module is in training or evaluation mode.
def __init__(self, scale_factor, in_channels, out_channels, mode="bilinear"):
    """Forward to ``Upsampler`` with ``dim`` fixed to 2."""
    # Upsampler's positional order: (scale_factor, in_channels, out_channels, dim, mode).
    super().__init__(scale_factor, in_channels, out_channels, 2, mode)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class UNet2d(UNetBase):
    """2D U-Net assembled from an ``Encoder``, a base block and a ``Decoder``.

    Level ``i`` has ``initial_features * gain ** i`` features and every level
    down-/upsamples by a factor of 2.

    Args:
        in_channels: Number of input channels.
        out_channels: Output channels of the 1x1 output convolution; if None, no
            output convolution is created. With ``return_side_outputs``, either a
            single value (used for every level) or one value per level.
        depth: Number of encoder/decoder levels.
        initial_features: Number of features of the first level.
        gain: Feature multiplier between consecutive levels.
        final_activation: Forwarded to ``UNetBase``.
        return_side_outputs: If True, create one 1x1 output convolution per decoder level.
        conv_block_impl: Conv block class used in encoder, base and decoder.
        pooler_impl: Pooling class of the encoder, called with the scale factor.
        sampler_impl: Upsampling class of the decoder.
        postprocessing: Forwarded to ``UNetBase``.
        check_shape: Forwarded to ``UNetBase``.
        **conv_block_kwargs: Extra keyword arguments for every conv block.

    Raises:
        ValueError: If ``return_side_outputs`` is set and the number of
            ``out_channels`` entries differs from ``depth``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        depth=4,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock2d,
        pooler_impl=nn.MaxPool2d,
        sampler_impl=Upsampler2d,
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        # The decoder starts from the base output (initial_features * gain ** depth features).
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
        scale_factors = depth * [2]

        if return_side_outputs:
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(
                    f"Invalid number of out_channels for side outputs: expected {depth}, got {len(out_channels)}"
                )
            out_conv = nn.ModuleList(
                [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=pooler_impl,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=sampler_impl,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain,
                **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        # Constructor arguments (presumably used to re-create the model); every
        # argument is recorded, including check_shape, so they round-trip.
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
                            "sampler_impl": sampler_impl, "postprocessing": postprocessing,
                            "check_shape": check_shape, **conv_block_kwargs}
def __init__(
    self,
    in_channels,
    out_channels,
    depth=4,
    initial_features=32,
    gain=2,
    final_activation=None,
    return_side_outputs=False,
    conv_block_impl=ConvBlock2d,
    pooler_impl=nn.MaxPool2d,
    sampler_impl=Upsampler2d,
    postprocessing=None,
    check_shape=True,
    **conv_block_kwargs,
):
    """Build the 2D U-Net encoder, base, decoder and output convolution(s).

    Raises:
        ValueError: If ``return_side_outputs`` is set and the number of
            ``out_channels`` entries differs from ``depth``.
    """
    features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
    # The decoder starts from the base output (initial_features * gain ** depth features).
    features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]
    scale_factors = depth * [2]

    if return_side_outputs:
        if isinstance(out_channels, int) or out_channels is None:
            out_channels = [out_channels] * depth
        if len(out_channels) != depth:
            raise ValueError(
                f"Invalid number of out_channels for side outputs: expected {depth}, got {len(out_channels)}"
            )
        out_conv = nn.ModuleList(
            [nn.Conv2d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
        )
    else:
        out_conv = None if out_channels is None else nn.Conv2d(features_decoder[-1], out_channels, 1)

    super().__init__(
        encoder=Encoder(
            features=features_encoder,
            scale_factors=scale_factors,
            conv_block_impl=conv_block_impl,
            pooler_impl=pooler_impl,
            **conv_block_kwargs
        ),
        decoder=Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=conv_block_impl,
            sampler_impl=sampler_impl,
            **conv_block_kwargs
        ),
        base=conv_block_impl(
            features_encoder[-1], features_encoder[-1] * gain,
            **conv_block_kwargs
        ),
        out_conv=out_conv,
        final_activation=final_activation,
        postprocessing=postprocessing,
        check_shape=check_shape,
    )
    # Constructor arguments (presumably used to re-create the model); every
    # argument is recorded, including check_shape, so they round-trip.
    self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                        "initial_features": initial_features, "gain": gain,
                        "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                        "conv_block_impl": conv_block_impl, "pooler_impl": pooler_impl,
                        "sampler_impl": sampler_impl, "postprocessing": postprocessing,
                        "check_shape": check_shape, **conv_block_kwargs}
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Inherited Members
- UNetBase
- encoder
- base
- decoder
- out_conv
- check_shape
- final_activation
- postprocessing
- in_channels
- out_channels
- depth
- load_encoder_state
- load_decoder_state
- load_base_state
- forward
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class ConvBlock3d(ConvBlock):
    """``ConvBlock`` with ``dim`` fixed to 3 (uses ``nn.Conv3d``)."""
    def __init__(self, in_channels, out_channels, **kwargs):
        # ``dim`` is the third positional parameter of ConvBlock.
        super().__init__(in_channels, out_channels, 3, **kwargs)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class
must be made before assignment on the child.
Attributes: training (bool): whether this module is in training or evaluation mode.
def __init__(self, in_channels, out_channels, **kwargs):
    """Forward to ``ConvBlock`` with ``dim`` fixed to 3."""
    # ``dim`` is the third positional parameter of ConvBlock.
    super().__init__(in_channels, out_channels, 3, **kwargs)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class Upsampler3d(Upsampler):
    """``Upsampler`` with ``dim`` fixed to 3; interpolates trilinearly by default."""
    def __init__(self, scale_factor, in_channels, out_channels, mode="trilinear"):
        # Upsampler's positional order: (scale_factor, in_channels, out_channels, dim, mode).
        super().__init__(scale_factor, in_channels, out_channels, 3, mode)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Example model: two 2d convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, scale_factor, in_channels, out_channels, mode="trilinear"):
    """Build a 3d upsampler: ``dim`` is pinned to 3 and ``mode`` defaults to ``"trilinear"``."""
    spatial_dim = 3
    super().__init__(
        scale_factor,
        in_channels,
        out_channels,
        dim=spatial_dim,
        mode=mode,
    )
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class AnisotropicUNet(UNetBase):
    """3d U-Net whose downsampling factor can be chosen per level (and per axis).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` means no output convolution.
            With ``return_side_outputs=True`` this may be a single value (reused for every
            level) or a sequence with one entry per level.
        scale_factors: One scale factor per level; ``len(scale_factors)`` is the depth.
            The shape check indexes each entry per axis, so entries are expected to be
            sequences of three ints; a plain int is treated as the same factor on all axes.
        initial_features: Number of features of the first level.
        gain: Multiplicative increase of the number of features per level.
        final_activation: Forwarded to ``UNetBase``.
        return_side_outputs: If True, add a 1x1x1 output convolution for every decoder level.
        conv_block_impl: Constructor used for the convolutional blocks.
        anisotropic_kernel: Forwarded to ``Encoder`` and ``Decoder``.
        postprocessing: Forwarded to ``UNetBase``.
        check_shape: Forwarded to ``UNetBase``.
        **conv_block_kwargs: Forwarded to every convolutional block.

    Raises:
        ValueError: If ``return_side_outputs`` is set and the number of output channel
            values does not match the depth.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factors,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock3d,
        anisotropic_kernel=False,  # TODO benchmark which option is better and set as default
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        depth = len(scale_factors)
        features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
        # The decoder mirrors the encoder, starting from the deepest (widest) level.
        features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]

        if return_side_outputs:
            if isinstance(out_channels, int) or out_channels is None:
                out_channels = [out_channels] * depth
            if len(out_channels) != depth:
                raise ValueError(
                    f"Invalid number of out_channels for side outputs: expected {depth}, got {len(out_channels)}"
                )
            out_conv = nn.ModuleList(
                [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
            )
        else:
            out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)

        super().__init__(
            encoder=Encoder(
                features=features_encoder,
                scale_factors=scale_factors,
                conv_block_impl=conv_block_impl,
                pooler_impl=nn.MaxPool3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            decoder=Decoder(
                features=features_decoder,
                scale_factors=scale_factors[::-1],
                conv_block_impl=conv_block_impl,
                sampler_impl=Upsampler3d,
                anisotropic_kernel=anisotropic_kernel,
                **conv_block_kwargs
            ),
            base=conv_block_impl(
                features_encoder[-1], features_encoder[-1] * gain,
                **conv_block_kwargs
            ),
            out_conv=out_conv,
            final_activation=final_activation,
            postprocessing=postprocessing,
            check_shape=check_shape,
        )
        # Keep every constructor argument (including check_shape) so the model can be
        # re-created from init_kwargs with identical settings.
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
                            "postprocessing": postprocessing, "check_shape": check_shape, **conv_block_kwargs}

    def _check_shape(self, x):
        """Raise ValueError unless ``x`` has 3 spatial dims divisible by the total downsampling per axis."""
        spatial_shape = tuple(x.shape)[2:]
        scale_factors = self.init_kwargs.get("scale_factors", [[2, 2, 2]] * len(self.encoder))
        # A scalar scale factor means the same factor along every axis; broadcast it so it
        # can be indexed per axis below (indexing an int directly raises TypeError).
        per_axis = [[sf] * 3 if isinstance(sf, (int, np.integer)) else sf for sf in scale_factors]
        factor = [int(np.prod([sf[i] for sf in per_axis])) for i in range(3)]
        if len(spatial_shape) != len(factor):
            msg = f"Invalid shape for U-Net: dimensions don't agree {len(spatial_shape)} != {len(factor)}"
            raise ValueError(msg)
        if any(sh % fac != 0 for sh, fac in zip(spatial_shape, factor)):
            msg = f"Invalid shape for U-Net: {spatial_shape} is not divisible by {factor}"
            raise ValueError(msg)
def __init__(
    self,
    in_channels,
    out_channels,
    scale_factors,
    initial_features=32,
    gain=2,
    final_activation=None,
    return_side_outputs=False,
    conv_block_impl=ConvBlock3d,
    anisotropic_kernel=False,  # TODO benchmark which option is better and set as default
    postprocessing=None,
    check_shape=True,
    **conv_block_kwargs,
):
    """Build the encoder, base, decoder and output convolution(s) of the anisotropic 3d U-Net.

    Raises:
        ValueError: If ``return_side_outputs`` is set and the number of output channel
            values does not match ``len(scale_factors)``.
    """
    depth = len(scale_factors)
    features_encoder = [in_channels] + [initial_features * gain ** i for i in range(depth)]
    # The decoder mirrors the encoder, starting from the deepest (widest) level.
    features_decoder = [initial_features * gain ** i for i in range(depth + 1)][::-1]

    if return_side_outputs:
        if isinstance(out_channels, int) or out_channels is None:
            out_channels = [out_channels] * depth
        if len(out_channels) != depth:
            raise ValueError(
                f"Invalid number of out_channels for side outputs: expected {depth}, got {len(out_channels)}"
            )
        out_conv = nn.ModuleList(
            [nn.Conv3d(feat, outc, 1) for feat, outc in zip(features_decoder[1:], out_channels)]
        )
    else:
        out_conv = None if out_channels is None else nn.Conv3d(features_decoder[-1], out_channels, 1)

    super().__init__(
        encoder=Encoder(
            features=features_encoder,
            scale_factors=scale_factors,
            conv_block_impl=conv_block_impl,
            pooler_impl=nn.MaxPool3d,
            anisotropic_kernel=anisotropic_kernel,
            **conv_block_kwargs
        ),
        decoder=Decoder(
            features=features_decoder,
            scale_factors=scale_factors[::-1],
            conv_block_impl=conv_block_impl,
            sampler_impl=Upsampler3d,
            anisotropic_kernel=anisotropic_kernel,
            **conv_block_kwargs
        ),
        base=conv_block_impl(
            features_encoder[-1], features_encoder[-1] * gain,
            **conv_block_kwargs
        ),
        out_conv=out_conv,
        final_activation=final_activation,
        postprocessing=postprocessing,
        check_shape=check_shape,
    )
    # Keep every constructor argument (including check_shape) so the model can be
    # re-created from init_kwargs with identical settings.
    self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "scale_factors": scale_factors,
                        "initial_features": initial_features, "gain": gain,
                        "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                        "conv_block_impl": conv_block_impl, "anisotropic_kernel": anisotropic_kernel,
                        "postprocessing": postprocessing, "check_shape": check_shape, **conv_block_kwargs}
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Inherited Members
- UNetBase
- encoder
- base
- decoder
- out_conv
- check_shape
- final_activation
- postprocessing
- in_channels
- out_channels
- depth
- load_encoder_state
- load_decoder_state
- load_base_state
- forward
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class UNet3d(AnisotropicUNet):
    """3d U-Net with the same scale factor (2) at each of its ``depth`` levels.

    Thin wrapper around ``AnisotropicUNet`` with ``scale_factors = depth * [2]`` and
    ``anisotropic_kernel=False``; all other arguments are forwarded unchanged.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        depth=4,
        initial_features=32,
        gain=2,
        final_activation=None,
        return_side_outputs=False,
        conv_block_impl=ConvBlock3d,
        postprocessing=None,
        check_shape=True,
        **conv_block_kwargs,
    ):
        scale_factors = depth * [2]
        super().__init__(in_channels, out_channels, scale_factors,
                         initial_features=initial_features, gain=gain,
                         final_activation=final_activation,
                         return_side_outputs=return_side_outputs,
                         anisotropic_kernel=False,
                         postprocessing=postprocessing,
                         conv_block_impl=conv_block_impl,
                         check_shape=check_shape,
                         **conv_block_kwargs)
        # Overrides the parent's init_kwargs so that the model is re-created through this
        # class (with ``depth`` instead of ``scale_factors``). check_shape is stored too, so a
        # re-created model keeps the caller's setting instead of reverting to the default.
        self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                            "initial_features": initial_features, "gain": gain,
                            "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                            "conv_block_impl": conv_block_impl, "postprocessing": postprocessing,
                            "check_shape": check_shape, **conv_block_kwargs}
def __init__(
    self,
    in_channels,
    out_channels,
    depth=4,
    initial_features=32,
    gain=2,
    final_activation=None,
    return_side_outputs=False,
    conv_block_impl=ConvBlock3d,
    postprocessing=None,
    check_shape=True,
    **conv_block_kwargs,
):
    """Build the U-Net by delegating to the parent with ``scale_factors = depth * [2]``."""
    scale_factors = depth * [2]
    super().__init__(in_channels, out_channels, scale_factors,
                     initial_features=initial_features, gain=gain,
                     final_activation=final_activation,
                     return_side_outputs=return_side_outputs,
                     anisotropic_kernel=False,
                     postprocessing=postprocessing,
                     conv_block_impl=conv_block_impl,
                     check_shape=check_shape,
                     **conv_block_kwargs)
    # Overrides the parent's init_kwargs (``depth`` instead of ``scale_factors``); check_shape
    # is stored too, so a model re-created from init_kwargs keeps the caller's setting.
    self.init_kwargs = {"in_channels": in_channels, "out_channels": out_channels, "depth": depth,
                        "initial_features": initial_features, "gain": gain,
                        "final_activation": final_activation, "return_side_outputs": return_side_outputs,
                        "conv_block_impl": conv_block_impl, "postprocessing": postprocessing,
                        "check_shape": check_shape, **conv_block_kwargs}
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Inherited Members
- UNetBase
- encoder
- base
- decoder
- out_conv
- check_shape
- final_activation
- postprocessing
- in_channels
- out_channels
- depth
- load_encoder_state
- load_decoder_state
- load_base_state
- forward
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile