ResNet18MultiViewImageCondition
CLASS cleandiffuser.nn_diffusion.ResNet18MultiViewImageCondition(image_sz: int, in_channel: int, emb_dim: int, n_views: int, act_fn: Callable[[], nn.Module] = nn.ReLU, use_group_norm: bool = True, group_channels: int = 16, use_spatial_softmax: bool = True, dropout: float = 0.0)
A ResNet18 encoder for image conditions. It encodes the input image into a fixed-size embedding. The implementation is adapted from DiffusionPolicy. Compared to the original ResNet18, we replace BatchNorm2d with GroupNorm, and use a SpatialSoftmax instead of an average pooling layer. The multi-view version uses a separate ResNet18 network for each view.
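To illustrate what a spatial softmax computes in place of average pooling, here is a minimal sketch. It is not the library's implementation: the function name spatial_softmax is hypothetical, and the version adapted from DiffusionPolicy may differ (for example, by adding a learned keypoint projection or a temperature). The sketch turns each channel of a feature map into a probability distribution over pixel positions and returns the expected (x, y) coordinate per channel.

```python
import torch
import torch.nn.functional as F


def spatial_softmax(features: torch.Tensor) -> torch.Tensor:
    """Map (B, C, H, W) feature maps to (B, C * 2) expected (x, y) coordinates.

    Hypothetical helper for illustration only.
    """
    b, c, h, w = features.shape
    # Softmax over all spatial positions of each channel.
    attention = F.softmax(features.reshape(b, c, h * w), dim=-1)
    # Pixel coordinate grid, normalized to [-1, 1] along each axis.
    ys = torch.linspace(-1.0, 1.0, h, dtype=features.dtype, device=features.device)
    xs = torch.linspace(-1.0, 1.0, w, dtype=features.dtype, device=features.device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    # Expected coordinate of each channel under its attention distribution.
    expected_x = (attention * grid_x.reshape(1, 1, -1)).sum(dim=-1)
    expected_y = (attention * grid_y.reshape(1, 1, -1)).sum(dim=-1)
    return torch.stack([expected_x, expected_y], dim=-1).reshape(b, c * 2)


feature_maps = torch.randn(4, 8, 6, 6)
keypoints = spatial_softmax(feature_maps)  # two coordinates per channel
```

Unlike average pooling, which discards where a feature fires, this output keeps a coarse location for every channel.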
Parameters:
- image_sz (int): Size of the input image. The image is assumed to be square.
- in_channel (int): Number of input channels. 3 for RGB images.
- emb_dim (int): Dimension of the output embedding.
- n_views (int): Number of views.
- act_fn (Callable[[], nn.Module]): Activation function to use in the network. Default is ReLU.
- use_group_norm (bool): Whether to use GroupNorm instead of BatchNorm. Default is True.
- group_channels (int): Number of channels per group in GroupNorm. Default is 16.
- use_spatial_softmax (bool): Whether to use SpatialSoftmax instead of average pooling. Default is True.
- dropout (float): Condition dropout rate, i.e. the probability that a batch element's output embedding is set to zeros. Default is 0.0.
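The interaction between use_group_norm and group_channels can be sketched as follows. This is an assumption-laden illustration, not the library's code: the helper replace_batchnorm_with_groupnorm is hypothetical, and it assumes the number of groups is derived as num_features // group_channels, which matches the description of group_channels as "channels per group".

```python
import torch
import torch.nn as nn


def replace_batchnorm_with_groupnorm(module: nn.Module, group_channels: int = 16) -> nn.Module:
    """Recursively swap every BatchNorm2d for a GroupNorm over the same channels.

    Hypothetical helper. Assumes each channel count is divisible by the
    resulting number of groups (true for ResNet18 widths 64/128/256/512).
    """
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            num_groups = max(child.num_features // group_channels, 1)
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            replace_batchnorm_with_groupnorm(child, group_channels)
    return module


# A small stand-in network; a real ResNet18 would be handled the same way.
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128)),
)
block = replace_batchnorm_with_groupnorm(block, group_channels=16)
output = block(torch.randn(2, 3, 8, 8))
```

GroupNorm statistics do not depend on the batch dimension, which is one common reason to prefer it when batches are small or when training and inference batch sizes differ.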
forward(condition: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor
Parameters:
- condition (torch.Tensor): The context tensor in shape (b, v, n, c, h, w) or (b, v, c, h, w). Here, v is the number of different views, n is the sequence length, and (c, h, w) is the shape of the image.
- mask (Optional[torch.Tensor]): The mask tensor. Default is None. None means no mask.
Returns:
- torch.Tensor: The output tensor in shape (b, v, n, emb_dim) or (b, v, emb_dim). Each element in the batch is set to zeros with probability dropout.
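The shape handling and per-sample condition dropout described above can be sketched with a small stand-in module. Everything here is hypothetical: TinyMultiViewEncoder uses a tiny CNN per view rather than a ResNet18, it leaves out the mask argument, and it assumes dropout is applied only in training mode, which the text above does not state.

```python
import torch
import torch.nn as nn


class TinyMultiViewEncoder(nn.Module):
    """Hypothetical stand-in: one small CNN per view instead of a full ResNet18."""

    def __init__(self, in_channel: int, emb_dim: int, n_views: int, dropout: float = 0.0):
        super().__init__()
        self.dropout = dropout
        # A separate encoder for each view, so weights are not shared across views.
        self.encoders = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channel, 8, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(8, emb_dim),
            )
            for _ in range(n_views)
        ])

    def forward(self, condition: torch.Tensor) -> torch.Tensor:
        has_seq = condition.dim() == 6
        if not has_seq:
            condition = condition.unsqueeze(2)  # (b, v, 1, c, h, w)
        b, v, n = condition.shape[:3]
        per_view = []
        for i, encoder in enumerate(self.encoders):
            # Fold the sequence axis into the batch axis for 2D convolutions.
            frames = condition[:, i].reshape(b * n, *condition.shape[3:])
            per_view.append(encoder(frames).reshape(b, n, -1))
        emb = torch.stack(per_view, dim=1)  # (b, v, n, emb_dim)
        # Assumption: zero whole batch elements, and only while training.
        if self.training and self.dropout > 0.0:
            keep = (torch.rand(b, device=emb.device) >= self.dropout).to(emb.dtype)
            emb = emb * keep.reshape(b, 1, 1, 1)
        return emb if has_seq else emb.squeeze(2)


encoder = TinyMultiViewEncoder(in_channel=3, emb_dim=16, n_views=2).eval()
with_sequence = encoder(torch.randn(4, 2, 5, 3, 32, 32))   # shape (4, 2, 5, 16)
without_sequence = encoder(torch.randn(4, 2, 3, 32, 32))   # shape (4, 2, 16)
```

Dropping the whole embedding of a batch element, rather than individual features, is the usual way to train a conditional model that can also run unconditionally.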