ResNet18MultiViewImageCondition
CLASS cleandiffuser.nn_diffusion.ResNet18MultiViewImageCondition(image_sz: int, in_channel: int, emb_dim: int, n_views: int, act_fn: Callable[[], nn.Module] = nn.ReLU, use_group_norm: bool = True, group_channels: int = 16, use_spatial_softmax: bool = True, dropout: float = 0.0)
A ResNet18 encoder for image conditions. It encodes the input image into a fixed-size embedding. The implementation is adapted from DiffusionPolicy. Compared to the original ResNet18, we replace BatchNorm2d with GroupNorm and use SpatialSoftmax instead of the average pooling layer. The multi-view version uses a separate ResNet18 network for each view.
Parameters:
- image_sz (int): Size of the input image. The image is assumed to be square.
- in_channel (int): Number of input channels. 3 for RGB images.
- emb_dim (int): Dimension of the output embedding.
- n_views (int): Number of views.
- act_fn (Callable[[], nn.Module]): Activation function to use in the network. Default is ReLU.
- use_group_norm (bool): Whether to use GroupNorm instead of BatchNorm. Default is True.
- group_channels (int): Number of channels per group in GroupNorm. Default is 16.
- use_spatial_softmax (bool): Whether to use SpatialSoftmax instead of average pooling. Default is True.
- dropout (float): Condition dropout rate. Default is 0.0.
 
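A minimal instantiation sketch, assuming the import path in the signature above; the image size, embedding dimension, and view count below are illustrative values, not defaults:

```python
import torch
from cleandiffuser.nn_diffusion import ResNet18MultiViewImageCondition

# Hypothetical setup: two 84x84 RGB camera views, 256-dim embedding per view.
nn_condition = ResNet18MultiViewImageCondition(
    image_sz=84,               # square input images
    in_channel=3,              # RGB
    emb_dim=256,               # output embedding dimension
    n_views=2,                 # one ResNet18 per view
    use_group_norm=True,       # GroupNorm instead of BatchNorm2d
    use_spatial_softmax=True,  # SpatialSoftmax instead of average pooling
    dropout=0.1,               # condition dropout rate
)
```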
forward(condition: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor
Parameters:
- condition (torch.Tensor): The context tensor of shape `(b, v, n, c, h, w)` or `(b, v, c, h, w)`. Here, `v` is the number of views, `n` is the sequence length, and `(c, h, w)` is the shape of the image.
- mask (Optional[torch.Tensor]): The mask tensor. Default is None, meaning no mask.
 
Returns:
- torch.Tensor: The output tensor of shape `(b, v, n, emb_dim)` or `(b, v, emb_dim)`. Each element in the batch is zeroed out with probability `dropout`.
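A sketch of a forward pass, continuing the instantiation above; the batch size and sequence length are illustrative, and the expected output shapes follow the description of condition and the return value:

```python
# With a sequence dimension: (b, v, n, c, h, w)
condition = torch.randn(8, 2, 4, 3, 84, 84)
emb = nn_condition(condition)   # expected shape: (8, 2, 4, 256)

# Without a sequence dimension: (b, v, c, h, w)
condition = torch.randn(8, 2, 3, 84, 84)
emb = nn_condition(condition)   # expected shape: (8, 2, 256)
```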