Module src.representation_learning.model_cl
Classes
class CL (in_channels=5, h_dim=128, projection_dim=32)
-
Expand source code
class CL(nn.Module):
    """
    Contrastive Learning (CL) model using a customizable encoder and projector network.
    This model follows the SimCLR framework, where the input image is augmented to create two views,
    which are then encoded and projected to a latent space for contrastive loss calculation.

    Attributes:
        encoder (nn.Module): The feature extractor network.
        projector (nn.Sequential): The projection head to map features to latent space.
        h_dim (int): Dimension of the hidden representation from the encoder.
        base_size (int): The base image size after transformation.

    Methods:
        forward(x): Computes the latent representations of two augmented views.
        get_latent(x): Returns the latent representation from the encoder without projection.
    """

    def __init__(self, in_channels=5, h_dim=128, projection_dim=32):
        """
        Initializes the CL model.

        Parameters:
            in_channels (int): Number of input channels for the encoder (e.g., number of image channels).
            h_dim (int): Dimension of the encoder's output features.
            projection_dim (int): Dimension of the projected latent space.
        """
        super(CL, self).__init__()

        # Encoder network to extract features from input images
        self.encoder = Encoder(input_channels=in_channels, output_features=h_dim)
        self.h_dim = h_dim
        self.base_size = 75

        # Projector network to map the encoder output to the latent space
        self.projector = nn.Sequential(
            nn.Linear(h_dim, h_dim, bias=False),          # Linear layer without bias
            nn.ReLU(),                                    # Non-linear activation
            nn.Linear(h_dim, projection_dim, bias=False)  # Final projection layer
        )

    # Forward method to compute latent representations
    def forward(self, x):
        """
        Performs a forward pass through the CL model.

        Parameters:
            x (torch.Tensor): Input batch of images.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                - z_i: Projected representation of the first augmented view.
                - z_j: Projected representation of the second augmented view.
                - h_i: Encoder output for the first augmented view.
                - h_j: Encoder output for the second augmented view.
        """
        # Generate two augmented versions of the input
        transform = self.simclr_transform()
        x_i = transform(x)
        x_j = transform(x)

        # Encode both augmented views
        h_i = self.encoder(x_i)
        h_j = self.encoder(x_j)

        # Project the encoded features to latent space
        z_i = self.projector(h_i)
        z_j = self.projector(h_j)

        return z_i, z_j, h_i, h_j

    # Method to get latent representation without projection
    def get_latent(self, x):
        """
        Returns the latent representation of the input image using the encoder.

        Parameters:
            x (torch.Tensor): Input batch of images.

        Returns:
            torch.Tensor: Latent representation from the encoder.
        """
        return self.encoder(x)

    @staticmethod
    def loss(z_i, z_j, temperature):
        """
        Computes the contrastive loss using normalized embeddings from two views (z_i and z_j).
        The loss follows the InfoNCE formulation commonly used in contrastive learning frameworks.

        Parameters:
            z_i (torch.Tensor): Embeddings from the first view of the batch, of shape (N, D),
                where N is the batch size and D is the embedding dimension.
            z_j (torch.Tensor): Embeddings from the second view of the batch, of shape (N, D).
            temperature (float): Temperature scaling parameter for contrastive loss.

        Returns:
            torch.Tensor: The contrastive loss as a single scalar tensor.
        """
        # Get the batch size
        N = z_i.size(0)

        # Concatenate the embeddings from both views along the batch dimension (2N, D)
        z = torch.cat((z_i, z_j), dim=0)

        # Normalize the concatenated embeddings along the feature dimension
        z_normed = F.normalize(z, dim=1)

        # Compute the cosine similarity matrix (2N, 2N)
        cosine_similarity_matrix = torch.matmul(z_normed, z_normed.T)

        # Create ground-truth labels for positive pairs
        labels = torch.cat([torch.arange(N) for i in range(2)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        labels = labels.to(z.device)

        # Remove self-similarity from the similarity matrix (diagonal elements)
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(z.device)
        labels = labels[~mask].view(labels.shape[0], -1)
        cosine_similarity_matrix = cosine_similarity_matrix[~mask].view(cosine_similarity_matrix.shape[0], -1)

        # Extract positive and negative similarities
        positives = cosine_similarity_matrix[labels.bool()].view(labels.shape[0], -1)
        negatives = cosine_similarity_matrix[~labels.bool()].view(labels.shape[0], -1)

        # Concatenate positives and negatives for the logits
        logits = torch.cat([positives, negatives], dim=1)

        # Create target labels indicating positive pairs (0th column is positive)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(z.device)

        # Apply temperature scaling
        logits = logits / temperature

        # Compute the cross-entropy loss
        loss = F.cross_entropy(logits, labels)

        return loss

    def simclr_transform(self):
        """Constructs the SimCLR data transformation pipeline."""
        transformations = []

        # Adjust the color jitter parameters as needed; these are optimized for changes seen in our data
        color_jitter = CustomColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)
        transformations.append(transforms.RandomApply([color_jitter], p=0.5))  # apply color jitter with 50% probability

        transformations.append(transforms.RandomRotation(degrees=180))  # rotate anywhere between -180 and 180 degrees
        transformations.append(transforms.RandomHorizontalFlip(p=0.5))  # flip horizontally with 50% probability
        transformations.append(transforms.RandomVerticalFlip(p=0.5))    # flip vertically with 50% probability

        # Translate by up to 20% in both x and y, since the cells are not always perfectly centered
        affine = transforms.RandomAffine(degrees=0, translate=(0.2, 0.2))
        transformations.append(transforms.RandomApply([affine], p=0.5))  # apply affine transformation with 50% probability

        # OPTIONAL TRANSFORMATIONS
        # erode_dilate = RandomErodeDilateTransform(kernel_size=5, iterations=1)
        # transformations.append(transforms.RandomApply([erode_dilate], p=0.5))
        # transformations.append(ZeroMask(p=0.5))
        # transformations.append(OnesMask(p=0.5))

        # Blur with a kernel size of 3 and a sigma between 0.1 and 3.0
        blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 3.0))
        transformations.append(transforms.RandomApply([blur], p=0.75))  # apply blur with 75% probability for robustness to noise

        # Crop to a random size between 50% and 100% of the original to account for variance in cell size within a class
        random_crop = transforms.RandomResizedCrop(size=self.base_size, scale=(0.5, 1.0))
        transformations.append(transforms.RandomApply([random_crop], p=0.5))  # apply random crop with 50% probability

        # OPTIONAL TRANSFORMATIONS
        # if self.config.use_cutout:
        #     transformations.append(Cutout(n_holes=1, length=32))
        # if self.config.use_guassian_noise:
        #     transformations.append(GaussianNoise(mean=0.0, std=0.1))

        # Combine all the transformations into a single transform
        data_transforms = transforms.Compose(transformations)
        return data_transforms
Contrastive Learning (CL) model using a customizable encoder and projector network. This model follows the SimCLR framework, where the input image is augmented to create two views, which are then encoded and projected to a latent space for contrastive loss calculation.
Attributes
encoder
:nn.Module
- The feature extractor network.
projector
:nn.Sequential
- The projection head to map features to latent space.
h_dim
:int
- Dimension of the hidden representation from the encoder.
base_size
:int
- The base image size after transformation.
Methods
forward(x)
- Computes the latent representations of two augmented views.
get_latent(x)
- Returns the latent representation from the encoder without projection.
Initializes the CL model.
Parameters
in_channels
:int
- Number of input channels for the encoder (e.g., number of image channels).
h_dim
:int
- Dimension of the encoder's output features.
projection_dim
:int
- Dimension of the projected latent space.
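For illustration, a minimal usage sketch, assuming 5-channel 75 x 75 inputs (matching the default in_channels=5 and base_size=75) and assuming the project root is importable; the batch size and tensor names here are hypothetical:

import torch
from src.representation_learning.model_cl import CL  # assumed import path

model = CL(in_channels=5, h_dim=128, projection_dim=32)

x = torch.rand(16, 5, 75, 75)   # hypothetical batch of 16 five-channel 75x75 images
z_i, z_j, h_i, h_j = model(x)   # forward() augments x into two views internally

print(z_i.shape, h_i.shape)     # expected: torch.Size([16, 32]) torch.Size([16, 128])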
Ancestors
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Static methods
def loss(z_i, z_j, temperature)
-
Expand source code
@staticmethod
def loss(z_i, z_j, temperature):
    """
    Computes the contrastive loss using normalized embeddings from two views (z_i and z_j).
    The loss follows the InfoNCE formulation commonly used in contrastive learning frameworks.

    Parameters:
        z_i (torch.Tensor): Embeddings from the first view of the batch, of shape (N, D),
            where N is the batch size and D is the embedding dimension.
        z_j (torch.Tensor): Embeddings from the second view of the batch, of shape (N, D).
        temperature (float): Temperature scaling parameter for contrastive loss.

    Returns:
        torch.Tensor: The contrastive loss as a single scalar tensor.
    """
    # Get the batch size
    N = z_i.size(0)

    # Concatenate the embeddings from both views along the batch dimension (2N, D)
    z = torch.cat((z_i, z_j), dim=0)

    # Normalize the concatenated embeddings along the feature dimension
    z_normed = F.normalize(z, dim=1)

    # Compute the cosine similarity matrix (2N, 2N)
    cosine_similarity_matrix = torch.matmul(z_normed, z_normed.T)

    # Create ground-truth labels for positive pairs
    labels = torch.cat([torch.arange(N) for i in range(2)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(z.device)

    # Remove self-similarity from the similarity matrix (diagonal elements)
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(z.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    cosine_similarity_matrix = cosine_similarity_matrix[~mask].view(cosine_similarity_matrix.shape[0], -1)

    # Extract positive and negative similarities
    positives = cosine_similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    negatives = cosine_similarity_matrix[~labels.bool()].view(labels.shape[0], -1)

    # Concatenate positives and negatives for the logits
    logits = torch.cat([positives, negatives], dim=1)

    # Create target labels indicating positive pairs (0th column is positive)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(z.device)

    # Apply temperature scaling
    logits = logits / temperature

    # Compute the cross-entropy loss
    loss = F.cross_entropy(logits, labels)

    return loss
Computes the contrastive loss using normalized embeddings from two views (z_i and z_j). The loss follows the InfoNCE formulation commonly used in contrastive learning frameworks.
Parameters
z_i
:torch.Tensor
- Embeddings from the first view of the batch, of shape (N, D), where N is the batch size and D is the embedding dimension.
z_j
:torch.Tensor
- Embeddings from the second view of the batch, of shape (N, D).
temperature
:float
- Temperature scaling parameter for the contrastive loss.
Returns
torch.Tensor
- The contrastive loss as a single scalar tensor.
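Since loss is a static method, it can be exercised on its own; a small sketch with arbitrary (hypothetical) batch size, embedding dimension, and temperature:

import torch

N, D = 8, 32
z_i = torch.randn(N, D)   # e.g., projector output for view 1
z_j = torch.randn(N, D)   # e.g., projector output for view 2

loss = CL.loss(z_i, z_j, temperature=0.5)
print(loss)               # scalar tensor; smaller when each z_i[k] aligns with its positive z_j[k]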
Methods
def forward(self, x) ‑> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
-
Expand source code
def forward(self, x):
    """
    Performs a forward pass through the CL model.

    Parameters:
        x (torch.Tensor): Input batch of images.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            - z_i: Projected representation of the first augmented view.
            - z_j: Projected representation of the second augmented view.
            - h_i: Encoder output for the first augmented view.
            - h_j: Encoder output for the second augmented view.
    """
    # Generate two augmented versions of the input
    transform = self.simclr_transform()
    x_i = transform(x)
    x_j = transform(x)

    # Encode both augmented views
    h_i = self.encoder(x_i)
    h_j = self.encoder(x_j)

    # Project the encoded features to latent space
    z_i = self.projector(h_i)
    z_j = self.projector(h_j)

    return z_i, z_j, h_i, h_j
Performs a forward pass through the CL model.
Parameters
x
:torch.Tensor
- Input batch of images.
Returns
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- z_i: Projected representation of the first augmented view.
- z_j: Projected representation of the second augmented view.
- h_i: Encoder output for the first augmented view.
- h_j: Encoder output for the second augmented view.
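Combining forward and loss, a minimal (hypothetical) training step might look like the following; the optimizer, learning rate, batch size, and temperature are placeholder choices, not values prescribed by this module:

import torch

model = CL(in_channels=5, h_dim=128, projection_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # placeholder optimizer and learning rate

batch = torch.rand(32, 5, 75, 75)   # hypothetical batch of 5-channel images

model.train()
z_i, z_j, _, _ = model(batch)       # encoder outputs h_i, h_j are not needed for the loss
loss = CL.loss(z_i, z_j, temperature=0.5)

optimizer.zero_grad()
loss.backward()
optimizer.step()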
def get_latent(self, x)
-
Expand source code
def get_latent(self, x):
    """
    Returns the latent representation of the input image using the encoder.

    Parameters:
        x (torch.Tensor): Input batch of images.

    Returns:
        torch.Tensor: Latent representation from the encoder.
    """
    return self.encoder(x)
Returns the latent representation of the input image using the encoder.
Parameters
x
:torch.Tensor
- Input batch of images.
Returns
torch.Tensor
- Latent representation from the encoder.
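For downstream use of the learned features (e.g., clustering or a linear classifier), get_latent skips both the augmentations and the projection head; a small sketch, assuming inference without gradients:

import torch

model = CL()   # default in_channels=5, h_dim=128, projection_dim=32
model.eval()
with torch.no_grad():
    h = model.get_latent(torch.rand(16, 5, 75, 75))  # hypothetical batch
print(h.shape)  # expected: torch.Size([16, 128]), i.e. (batch_size, h_dim)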
def simclr_transform(self)
-
Expand source code
def simclr_transform(self):
    """Constructs the SimCLR data transformation pipeline."""
    transformations = []

    # Adjust the color jitter parameters as needed; these are optimized for changes seen in our data
    color_jitter = CustomColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)
    transformations.append(transforms.RandomApply([color_jitter], p=0.5))  # apply color jitter with 50% probability

    transformations.append(transforms.RandomRotation(degrees=180))  # rotate anywhere between -180 and 180 degrees
    transformations.append(transforms.RandomHorizontalFlip(p=0.5))  # flip horizontally with 50% probability
    transformations.append(transforms.RandomVerticalFlip(p=0.5))    # flip vertically with 50% probability

    # Translate by up to 20% in both x and y, since the cells are not always perfectly centered
    affine = transforms.RandomAffine(degrees=0, translate=(0.2, 0.2))
    transformations.append(transforms.RandomApply([affine], p=0.5))  # apply affine transformation with 50% probability

    # OPTIONAL TRANSFORMATIONS
    # erode_dilate = RandomErodeDilateTransform(kernel_size=5, iterations=1)
    # transformations.append(transforms.RandomApply([erode_dilate], p=0.5))
    # transformations.append(ZeroMask(p=0.5))
    # transformations.append(OnesMask(p=0.5))

    # Blur with a kernel size of 3 and a sigma between 0.1 and 3.0
    blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 3.0))
    transformations.append(transforms.RandomApply([blur], p=0.75))  # apply blur with 75% probability for robustness to noise

    # Crop to a random size between 50% and 100% of the original to account for variance in cell size within a class
    random_crop = transforms.RandomResizedCrop(size=self.base_size, scale=(0.5, 1.0))
    transformations.append(transforms.RandomApply([random_crop], p=0.5))  # apply random crop with 50% probability

    # OPTIONAL TRANSFORMATIONS
    # if self.config.use_cutout:
    #     transformations.append(Cutout(n_holes=1, length=32))
    # if self.config.use_guassian_noise:
    #     transformations.append(GaussianNoise(mean=0.0, std=0.1))

    # Combine all the transformations into a single transform
    data_transforms = transforms.Compose(transformations)
    return data_transforms
Constructs the SimCLR data transformation pipeline.
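Because forward builds a fresh pipeline on each call, the same transform can also be pulled out and applied manually, e.g. to inspect the two stochastic views; a sketch, assuming tensor inputs of shape (N, 5, 75, 75):

import torch

model = CL()
transform = model.simclr_transform()

x = torch.rand(4, 5, 75, 75)   # hypothetical batch
view_1 = transform(x)          # each call samples a different combination of augmentations
view_2 = transform(x)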
class Encoder (input_channels, output_features)
-
Expand source code
class Encoder(nn.Module):
    """
    Encoder network for extracting latent features from input images.

    This encoder consists of multiple convolutional layers, each followed by batch normalization,
    ReLU activation, and pooling layers. The final representation is obtained via adaptive average
    pooling and a fully connected layer.

    Attributes:
        conv1, conv2, conv3, conv4 (nn.Conv2d): Convolutional layers for feature extraction.
        bn1, bn2, bn3, bn4 (nn.BatchNorm2d): Batch normalization layers for stabilizing training.
        pool (nn.MaxPool2d): Max pooling layer to reduce spatial dimensions.
        adap_pool (nn.AdaptiveAvgPool2d): Adaptive average pooling to generate fixed-size output.
        fc (nn.Linear): Fully connected layer to generate the final latent representation.

    Methods:
        forward(x): Forward pass through the encoder to generate feature representations.
    """

    def __init__(self, input_channels, output_features):
        """
        Initializes the encoder model.

        Parameters:
            input_channels (int): Number of input channels (e.g., number of color channels in an image).
            output_features (int): Number of output features (embedding dimension).
        """
        super(Encoder, self).__init__()

        # First convolutional block: Conv -> BN -> ReLU -> Pool
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=(3, 3), stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        # Second convolutional block: Conv -> BN -> ReLU -> Pool
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Third convolutional block: Conv -> BN -> ReLU -> Pool
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # Fourth convolutional block: Conv -> BN -> ReLU -> Pool
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Pooling layers for downsampling
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)  # Max pooling to reduce spatial dimensions
        self.adap_pool = nn.AdaptiveAvgPool2d((1, 1))                 # Adaptive pooling to output 1x1 spatial dimensions

        # Fully connected layer to produce final latent representation
        self.fc = nn.Linear(256, output_features)

    def forward(self, x):
        """
        Forward pass through the encoder.

        Parameters:
            x (torch.Tensor): Input tensor of shape (batch_size, input_channels, height, width).

        Returns:
            torch.Tensor: Latent feature representation of shape (batch_size, output_features).
        """
        # First block: Conv -> BN -> ReLU -> Pool
        x = F.relu(self.bn1(self.conv1(x)))  # 75 x 75 x 5 -> 75 x 75 x 32
        x = self.pool(x)                     # 75 x 75 x 32 -> 37 x 37 x 32

        # Second block: Conv -> BN -> ReLU -> Pool
        x = F.relu(self.bn2(self.conv2(x)))  # 37 x 37 x 32 -> 37 x 37 x 64
        x = self.pool(x)                     # 37 x 37 x 64 -> 18 x 18 x 64

        # Third block: Conv -> BN -> ReLU -> Pool
        x = F.relu(self.bn3(self.conv3(x)))  # 18 x 18 x 64 -> 18 x 18 x 128
        x = self.pool(x)                     # 18 x 18 x 128 -> 9 x 9 x 128

        # Fourth block: Conv -> BN -> ReLU -> Pool
        x = F.relu(self.bn4(self.conv4(x)))  # 9 x 9 x 128 -> 9 x 9 x 256
        x = self.pool(x)                     # 9 x 9 x 256 -> 4 x 4 x 256

        # Adaptive average pooling to 1x1
        x = self.adap_pool(x)                # 4 x 4 x 256 -> 1 x 1 x 256

        # Flatten the spatial dimensions and pass through the fully connected layer
        x = torch.flatten(x, 1)              # Flatten the 1x1x256 to a 256-dimensional vector
        x = self.fc(x)                       # Final feature representation (batch_size, output_features)

        return x
Encoder network for extracting latent features from input images. This encoder consists of multiple convolutional layers, each followed by batch normalization, ReLU activation, and pooling layers. The final representation is obtained via adaptive average pooling and a fully connected layer.
Attributes
conv1, conv2, conv3, conv4
:nn.Conv2d
- Convolutional layers for feature extraction.
bn1, bn2, bn3, bn4
:nn.BatchNorm2d
- Batch normalization layers for stabilizing training.
pool
:nn.MaxPool2d
- Max pooling layer to reduce spatial dimensions.
adap_pool
:nn.AdaptiveAvgPool2d
- Adaptive average pooling to generate fixed-size output.
fc
:nn.Linear
- Fully connected layer to generate the final latent representation.
Methods
forward(x)
- Forward pass through the encoder to generate feature representations.
Initializes the encoder model.
Parameters
input_channels
:int
- Number of input channels (e.g., number of color channels in an image).
output_features
:int
- Number of output features (embedding dimension).
Ancestors
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Methods
def forward(self, x) ‑> torch.Tensor
-
Expand source code
def forward(self, x):
    """
    Forward pass through the encoder.

    Parameters:
        x (torch.Tensor): Input tensor of shape (batch_size, input_channels, height, width).

    Returns:
        torch.Tensor: Latent feature representation of shape (batch_size, output_features).
    """
    # First block: Conv -> BN -> ReLU -> Pool
    x = F.relu(self.bn1(self.conv1(x)))  # 75 x 75 x 5 -> 75 x 75 x 32
    x = self.pool(x)                     # 75 x 75 x 32 -> 37 x 37 x 32

    # Second block: Conv -> BN -> ReLU -> Pool
    x = F.relu(self.bn2(self.conv2(x)))  # 37 x 37 x 32 -> 37 x 37 x 64
    x = self.pool(x)                     # 37 x 37 x 64 -> 18 x 18 x 64

    # Third block: Conv -> BN -> ReLU -> Pool
    x = F.relu(self.bn3(self.conv3(x)))  # 18 x 18 x 64 -> 18 x 18 x 128
    x = self.pool(x)                     # 18 x 18 x 128 -> 9 x 9 x 128

    # Fourth block: Conv -> BN -> ReLU -> Pool
    x = F.relu(self.bn4(self.conv4(x)))  # 9 x 9 x 128 -> 9 x 9 x 256
    x = self.pool(x)                     # 9 x 9 x 256 -> 4 x 4 x 256

    # Adaptive average pooling to 1x1
    x = self.adap_pool(x)                # 4 x 4 x 256 -> 1 x 1 x 256

    # Flatten the spatial dimensions and pass through the fully connected layer
    x = torch.flatten(x, 1)              # Flatten the 1x1x256 to a 256-dimensional vector
    x = self.fc(x)                       # Final feature representation (batch_size, output_features)

    return x
Forward pass through the encoder.
Parameters
x
:torch.Tensor
- Input tensor of shape (batch_size, input_channels, height, width).
Returns
torch.Tensor
- Latent feature representation of shape (batch_size, output_features).
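The shape progression documented in the source comments can be checked directly; a minimal sketch, assuming 5-channel 75 x 75 inputs (sufficiently large other sizes also work, since the adaptive average pooling fixes the pre-flatten size at 1 x 1 x 256):

import torch
from src.representation_learning.model_cl import Encoder  # assumed import path

encoder = Encoder(input_channels=5, output_features=128)

x = torch.rand(8, 5, 75, 75)   # hypothetical batch
h = encoder(x)
print(h.shape)                 # expected: torch.Size([8, 128])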