ResNet18MultiViewImageCondition

CLASS cleandiffuser.nn_diffusion.ResNet18MultiViewImageCondition(image_sz: int, in_channel: int, emb_dim: int, n_views: int, act_fn: Callable[[], nn.Module] = nn.ReLU, use_group_norm: bool = True, group_channels: int = 16, use_spatial_softmax: bool = True, dropout: float = 0.0) [SOURCE]

A ResNet18 encoder for image conditions. It encodes each input image into a fixed-size embedding. The implementation is adapted from DiffusionPolicy. Compared to the original ResNet18, BatchNorm2d is replaced with GroupNorm, and the average pooling layer is replaced with a SpatialSoftmax. The multi-view version uses a separate ResNet18 network for each view.

Parameters:

  • image_sz (int): Size of the input image. The image is assumed to be square.
  • in_channel (int): Number of input channels, e.g., 3 for RGB images.
  • emb_dim (int): Dimension of the output embedding.
  • n_views (int): Number of views.
  • act_fn (Callable[[], nn.Module]): Activation function to use in the network. Default is ReLU.
  • use_group_norm (bool): Whether to use GroupNorm instead of BatchNorm. Default is True.
  • group_channels (int): Number of channels per group in GroupNorm. Default is 16.
  • use_spatial_softmax (bool): Whether to use SpatialSoftmax instead of average pooling. Default is True.
  • dropout (float): Condition dropout rate, i.e., the probability that each condition embedding is zeroed out. Default is 0.0.
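
As a sketch (not taken from the documentation), a two-view encoder for 84x84 RGB images might be constructed as follows. The image size, embedding dimension, and dropout value are illustrative assumptions, and the import path follows the class signature above:

    import torch.nn as nn
    from cleandiffuser.nn_diffusion import ResNet18MultiViewImageCondition

    # Illustrative values: two 84x84 RGB camera views, each encoded into a 256-dim embedding.
    encoder = ResNet18MultiViewImageCondition(
        image_sz=84,
        in_channel=3,
        emb_dim=256,
        n_views=2,
        act_fn=nn.ReLU,            # default activation
        use_group_norm=True,       # GroupNorm instead of BatchNorm2d
        group_channels=16,
        use_spatial_softmax=True,  # SpatialSoftmax instead of average pooling
        dropout=0.1,               # assumed 10% condition dropout
    )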

forward(condition: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor

Parameters:

  • condition (torch.Tensor): The condition tensor of shape (b, v, n, c, h, w) or (b, v, c, h, w), where b is the batch size, v is the number of views, n is the sequence length, and (c, h, w) is the shape of each image.
  • mask (Optional[torch.Tensor]): The mask tensor applied to the condition. Default is None, which means no mask.

Returns:

  • torch.Tensor: The output tensor of shape (b, v, n, emb_dim) or (b, v, emb_dim). During training, each element in the batch is zeroed out with probability dropout (condition dropout).
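
Continuing the sketch above, a forward pass with dummy tensors illustrates the expected input and output shapes; the batch size, sequence length, and image size are assumptions:

    import torch

    # encoder is the ResNet18MultiViewImageCondition constructed in the sketch above.
    condition = torch.randn(8, 2, 4, 3, 84, 84)       # (b, v, n, c, h, w)
    emb = encoder(condition)                          # expected shape: (8, 2, 4, 256)

    condition_single = torch.randn(8, 2, 3, 84, 84)   # (b, v, c, h, w)
    emb_single = encoder(condition_single)            # expected shape: (8, 2, 256)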