Module src.representation_learning.cl_transforms

Classes

class CustomColorJitter (brightness=0, contrast=0, saturation=0, hue=0)
Expand source code
class CustomColorJitter:
    """
    Applies color jitter to the first 4 channels of each image in a batch.
    Every channel after the fourth (e.g. a mask channel) is passed through
    without modification.

    Attributes:
        color_jitter (transforms.ColorJitter):
            Transformation to apply brightness, contrast, saturation, and hue changes.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        """
        Initializes the color jitter transformation.

        The four arguments are forwarded positionally to
        ``transforms.ColorJitter``.

        Parameters:
            brightness (float): Brightness factor.
            contrast (float): Contrast factor.
            saturation (float): Saturation factor.
            hue (float): Hue factor.
        """
        self.color_jitter = transforms.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, imgs):
        """
        Applies color jitter to each image in the batch.

        ``color_jitter`` is invoked once per jittered channel, on a
        ``[1, H, W]`` slice of the image. The input batch is not modified;
        a new tensor is returned.

        Parameters:
            imgs (Tensor): Batch of images with shape [B, C, H, W].

        Returns:
            Tensor: Batch of jittered images with the same shape as ``imgs``.
                Channels 0-3 are jittered; channels 4 and above are copied
                unchanged.
        """
        num_jittered = 4  # only the leading channels receive color jitter
        batch_jittered = []
        for img in imgs:
            # Jitter each leading channel separately as a single-channel image.
            planes = [
                self.color_jitter(img[c].unsqueeze(0)).squeeze(0)
                for c in range(min(img.shape[0], num_jittered))
            ]
            # Keep every remaining channel (mask or otherwise) untouched, so
            # inputs with more than 5 channels do not silently lose data.
            planes.extend(img[num_jittered:].unbind(0))
            batch_jittered.append(torch.stack(planes, dim=0))
        if not batch_jittered:
            # torch.stack rejects an empty list; an empty batch stays empty.
            return imgs.clone()
        return torch.stack(batch_jittered, dim=0)

Applies color jitter to the first 4 channels of each image in a batch. Retains the mask channel (if present) without modification.

Attributes

color_jitter : transforms.ColorJitter
Transformation to apply brightness, contrast, saturation, and hue changes.

Initializes the color jitter transformation.

Parameters

brightness (float): Brightness factor.
contrast (float): Contrast factor.
saturation (float): Saturation factor.
hue (float): Hue factor.

class Cutout (n_holes, length)
Expand source code
class Cutout(object):
    """
    Applies random square cutouts (holes) to each image in a batch.

    Attributes:
        n_holes (int): Number of holes to cut out from each image.
        length (int): Length of each square hole.
    """

    def __init__(self, n_holes, length):
        """
        Initializes the cutout transformation.

        Parameters:
            n_holes (int): Number of square holes.
            length (int): Length of each hole's side.
        """
        self.n_holes = n_holes
        self.length = length

    def __call__(self, imgs):
        """
        Applies cutout to each image in the batch.

        Each image gets its own mask; the same mask is applied to all of that
        image's channels. Hole centers are drawn with ``np.random.randint``,
        and holes are clipped at the image border. Each hole spans
        ``length // 2`` pixels on either side of its center.

        The batch is modified IN PLACE and the same tensor is returned.

        Parameters:
            imgs (Tensor): Batch of images with shape [B, C, H, W].

        Returns:
            Tensor: Batch of images with cutout applied (the input object).
        """
        batch_size, _, h, w = imgs.size()
        half = self.length // 2
        for i in range(batch_size):
            # Build the mask in the batch's own dtype and device so the
            # in-place multiply also works for non-float32 (e.g. uint8) input.
            mask = torch.ones((h, w), dtype=imgs.dtype, device=imgs.device)
            for _ in range(self.n_holes):
                y, x = np.random.randint(h), np.random.randint(w)
                y1 = int(np.clip(y - half, 0, h))
                x1 = int(np.clip(x - half, 0, w))
                y2 = int(np.clip(y + half, 0, h))
                x2 = int(np.clip(x + half, 0, w))
                mask[y1:y2, x1:x2] = 0
            # The [H, W] mask broadcasts across the channel dimension.
            imgs[i] *= mask
        return imgs

    def __repr__(self):
        return f"{self.__class__.__name__}(n_holes={self.n_holes}, length={self.length})"

Applies random square cutouts (holes) to each image in a batch.

Attributes

n_holes : int
Number of holes to cut out from each image.
length : int
Length of each square hole.

Initializes the cutout transformation.

Parameters

n_holes (int): Number of square holes.
length (int): Length of each hole's side.

class GaussianNoise (mean=0.0, std=1.0)
Expand source code
class GaussianNoise(object):
    """
    Perturbs a batch of images with additive Gaussian noise.

    Attributes:
        mean (float): Mean of the Gaussian noise.
        std (float): Standard deviation of the Gaussian noise.
    """

    def __init__(self, mean=0., std=1.):
        """
        Stores the noise distribution parameters.

        Parameters:
            mean (float): Mean of the Gaussian noise.
            std (float): Standard deviation of the Gaussian noise.
        """
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """
        Returns the input with freshly sampled Gaussian noise added.

        A standard-normal sample with the shape of ``tensor`` is scaled by
        ``std`` and shifted by ``mean``. The input is not modified.

        Parameters:
            tensor (Tensor): Batch of images with shape [B, C, H, W].

        Returns:
            Tensor: Batch of images with noise added.
        """
        # randn_like already allocates on the input's device.
        standard_normal = torch.randn_like(tensor)
        scaled_noise = standard_normal * self.std + self.mean
        return tensor + scaled_noise

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

Adds Gaussian noise to each image in a batch.

Attributes

mean : float
Mean of the Gaussian noise.
std : float
Standard deviation of the Gaussian noise.
class OnesMask (p=0.5)
Expand source code
class OnesMask:
    """
    With probability ``p`` per sample, overwrites the mask channel with ones.

    Attributes:
        p (float): Probability of setting the mask to ones.
    """

    def __init__(self, p=0.5):
        """
        Stores the per-sample probability.

        Parameters:
            p (float): Probability of setting the mask to ones.
        """
        self.p = p

    def __call__(self, tensor):
        """
        Fills channel index 4 with ones for randomly selected samples.

        One ``random.random()`` draw is made per sample; the sample is
        selected when the draw is below ``p``. The batch is modified in place
        and the same tensor is returned.

        Parameters:
            tensor (Tensor): Batch of images with shape [B, 5, H, W].

        Returns:
            Tensor: Batch with ones in the mask channel of selected samples.
        """
        num_samples = tensor.shape[0]
        for idx in range(num_samples):
            selected = random.random() < self.p
            if selected:
                # Scalar assignment fills the whole [H, W] mask plane.
                tensor[idx, 4] = 1
        return tensor

Randomly sets the mask channel of an image to ones.

Attributes

p : float
Probability of setting the mask to ones.
class RandomErodeDilateTransform (kernel_size=5, iterations=1)
Expand source code
class RandomErodeDilateTransform:
    """
    Randomly erodes or dilates the mask channel of every sample in a batch.

    Attributes:
        kernel_size (int): Side length of the elliptical structuring element.
        iterations (int): Number of erosion/dilation iterations.
    """

    def __init__(self, kernel_size=5, iterations=1):
        """
        Stores the morphology settings.

        Parameters:
            kernel_size (int): Side length of the elliptical structuring element.
            iterations (int): Number of erosion/dilation iterations.
        """
        self.kernel_size = kernel_size
        self.iterations = iterations

    def __call__(self, tensor):
        """
        Applies dilation or erosion, chosen at random per sample, to channel 4.

        The mask plane is scaled by 255 and cast to uint8 for OpenCV, then
        converted back to float32 and divided by 255. One ``random.random()``
        draw per sample picks dilation (draw below 0.5) or erosion. The batch
        is modified in place and the same tensor is returned.

        Parameters:
            tensor (Tensor): Batch of images with shape [B, 5, H, W].

        Returns:
            Tensor: Batch of processed images.
        """
        ksize = (self.kernel_size, self.kernel_size)
        structuring_element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize)
        for sample in range(tensor.shape[0]):
            # NOTE(review): presumably mask values lie in [0, 1]; anything
            # outside that range will not survive the uint8 cast - confirm.
            mask_u8 = (tensor[sample, 4].cpu().numpy() * 255).astype(np.uint8)
            morph_op = cv2.dilate if random.random() < 0.5 else cv2.erode
            morphed = morph_op(mask_u8, structuring_element, iterations=self.iterations)
            restored = morphed.astype(np.float32) / 255
            tensor[sample, 4] = torch.from_numpy(restored).to(tensor.device)
        return tensor

Applies random erosion or dilation to the mask channel of a batch.

Attributes

kernel_size : int
Size of the circular kernel.
iterations : int
Number of erosion/dilation iterations.
class ZeroMask (p=0.5)
Expand source code
class ZeroMask:
    """
    With probability ``p`` per sample, overwrites the mask channel with zeros.

    Attributes:
        p (float): Probability of setting the mask to zero.
    """

    def __init__(self, p=0.5):
        """
        Stores the per-sample probability.

        Parameters:
            p (float): Probability of setting the mask to zero.
        """
        self.p = p

    def __call__(self, tensor):
        """
        Clears channel index 4 for randomly selected samples.

        One ``random.random()`` draw is made per sample; the sample is
        selected when the draw is below ``p``. The batch is modified in place
        and the same tensor is returned.

        Parameters:
            tensor (Tensor): Batch of images with shape [B, 5, H, W].

        Returns:
            Tensor: Batch with zeroed masks for the selected samples.
        """
        num_samples = tensor.shape[0]
        for idx in range(num_samples):
            selected = random.random() < self.p
            if selected:
                # Scalar assignment fills the whole [H, W] mask plane.
                tensor[idx, 4] = 0
        return tensor

Randomly sets the mask channel of an image to zero.

Attributes

p : float
Probability of setting the mask to zero.