Module src.utils.utils
Functions
def channels_to_bgr(image, blue_index, green_index, red_index)
Expand source code
def channels_to_bgr(image, blue_index, green_index, red_index): """ Convert image channels to BGR 3-color format for visualization. Parameters: image (np.ndarray): Input image with shape (H, W, C) or (1, H, W, C). blue_index (list): Indices for blue channels. green_index (list): Indices for green channels. red_index (list): Indices for red channels. Returns: np.ndarray: Image in BGR format with combined channels. """ if len(image.shape) == 3: image = image[np.newaxis, ...] # Add batch dimension if not present # Initialize BGR array bgr = np.zeros((image.shape[0], image.shape[1], image.shape[2], 3), dtype='float') # Combine specified channels if len(blue_index) != 0: bgr[..., 0] = np.sum(image[..., blue_index], axis=-1) if len(green_index) != 0: bgr[..., 1] = np.sum(image[..., green_index], axis=-1) if len(red_index) != 0: bgr[..., 2] = np.sum(image[..., red_index], axis=-1) # Clip values to maximum allowed by dtype max_val = np.iinfo(image.dtype).max bgr[bgr > max_val] = max_val bgr = bgr.astype(image.dtype) return bgr
Convert image channels to BGR 3-color format for visualization.
Parameters
image (np.ndarray): Input image with shape (H, W, C) or (1, H, W, C).
blue_index (list): Indices for blue channels.
green_index (list): Indices for green channels.
red_index (list): Indices for red channels.
Returns
np.ndarray
- Image in BGR format with combined channels.
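Example
A minimal usage sketch with synthetic data; the channel-to-color mapping below (channels 0 and 1 as blue, 2 as green, 3 as red) is an assumption for illustration only:

from src.utils.utils import channels_to_bgr
import numpy as np

# Hypothetical 4-channel uint16 image with shape (H, W, C).
image = np.random.randint(0, 65535, size=(64, 64, 4), dtype=np.uint16)
bgr = channels_to_bgr(image, blue_index=[0, 1], green_index=[2], red_index=[3])
print(bgr.shape, bgr.dtype)  # (1, 64, 64, 3) uint16 -- a batch axis is added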
def get_data_loaders(data_path, batch_size=64)
Expand source code
def get_data_loaders(data_path, batch_size=64): """ Create a DataLoader from a given directory containing HDF5 files. Parameters: data_path (str): Path to the dataset directory. batch_size (int): Number of samples per batch. Returns: tuple: DataLoader, images, masks, labels. """ types = [d for d in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, d))] images_list, masks_list, labels_list = [], [], [] for label, t in enumerate(types): current_type_path = os.path.join(data_path, t) current_type_files = glob.glob(os.path.join(current_type_path, "*.hdf5")) class_images, class_masks = [], [] for file_path in current_type_files: with h5py.File(file_path, 'r') as f: imgs = np.array(f['images'][:]) msks = np.array(f['masks'][:]) class_images.append(imgs) class_masks.append(msks) class_images = np.concatenate(class_images, axis=0) class_masks = np.concatenate(class_masks, axis=0) class_masks[class_masks > 0] = 1 images_list.append(class_images) masks_list.append(class_masks) labels_list.append(np.full(len(class_images), label, dtype=np.int64)) images = np.concatenate(images_list, axis=0) masks = np.concatenate(masks_list, axis=0) labels = np.concatenate(labels_list, axis=0) dataset = CustomImageDataset(images, masks, labels) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) return dataloader, images, masks, labels
Create a DataLoader from a given directory containing HDF5 files.
Parameters
data_path (str): Path to the dataset directory.
batch_size (int): Number of samples per batch.
Returns
tuple
- DataLoader, images, masks, labels.
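Example
A usage sketch; the dataset root "data/cells" is hypothetical and assumes one subdirectory per class, each holding *.hdf5 files with 'images' and 'masks' datasets:

from src.utils.utils import get_data_loaders

# Hypothetical dataset root; subdirectory names define the class labels.
dataloader, images, masks, labels = get_data_loaders("data/cells", batch_size=32)
print(images.shape, masks.shape, labels.shape)
print(len(dataloader))  # number of batches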
def get_embeddings(model, dataloader, device)
Expand source code
def get_embeddings(model, dataloader, device): """ Extract embeddings from the encoder model for a given dataset. Parameters: model (torch.nn.Module): Trained model to generate embeddings. dataloader (DataLoader): DataLoader object to provide batches of images. device (str or torch.device): Device to perform computations. Returns: torch.Tensor: Concatenated embeddings from the entire dataset. """ model.eval() # Set model to evaluation mode embeddings = [] with torch.no_grad(): for x, _ in dataloader: x = x.to(device) # Move data to device embeddings.append(model.encoder(x).detach().cpu()) # Store embeddings # Concatenate all embedding tensors embeddings = torch.cat(embeddings) return embeddings
Extract embeddings from the encoder model for a given dataset.
Parameters
model (torch.nn.Module): Trained model to generate embeddings.
dataloader (DataLoader): DataLoader object to provide batches of images.
device (str or torch.device): Device to perform computations.
Returns
torch.Tensor
- Concatenated embeddings from the entire dataset.
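Example
A sketch that chains the helpers in this module; the checkpoint path "checkpoints/encoder.pt" and dataset root "data/cells" are assumptions:

from src.utils.utils import get_data_loaders, get_embeddings, load_model
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("checkpoints/encoder.pt", device).to(device)      # hypothetical path
dataloader, _, _, _ = get_data_loaders("data/cells", batch_size=64)  # hypothetical path
embeddings = get_embeddings(model, dataloader, device)
print(embeddings.shape)  # (num_samples, embedding_dim)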
def load_model(model_path, device)
Expand source code
def load_model(model_path, device): """ Load a pre-trained encoder model from a checkpoint file. Parameters: model_path (str): Path to the model checkpoint file. device (str or torch.device): Device to load the model onto. Returns: torch.nn.Module: Loaded model in evaluation mode. """ model = CL(in_channels=5, h_dim=128, projection_dim=64) # Instantiate the model state_dict = torch.load(model_path, map_location=device)['model_state_dict'] model.load_state_dict(state_dict) # Load model weights model.eval() # Set model to evaluation mode return model
Load a pre-trained encoder model from a checkpoint file.
Parameters
model_path (str): Path to the model checkpoint file.
device (str or torch.device): Device to load the model onto.
Returns
torch.nn.Module
- Loaded model in evaluation mode.
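Example
A minimal sketch, assuming a checkpoint saved with a 'model_state_dict' key (as the function expects); the path below is hypothetical:

from src.utils.utils import load_model

model = load_model("checkpoints/encoder.pt", device="cpu")
print(type(model).__name__)                        # CL (in_channels=5, h_dim=128, projection_dim=64)
print(sum(p.numel() for p in model.parameters()))  # total parameter count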
def save_5channel_tiffs(images, masks, outfiles)
Expand source code
def save_5channel_tiffs(images, masks, outfiles): """ Save each image as a separate 5-channel TIFF file. Parameters: images (np.ndarray): Array of images (n, h, w, c). masks (np.ndarray): Array of masks (n, h, w, 1). outfiles (list of str): List of output file paths. """ if images.ndim != 4 or masks.ndim != 4: raise ValueError("Expected 4D arrays for images and masks.") if images.shape[:3] != masks.shape[:3]: raise ValueError("Data and mask shapes do not match.") data = np.concatenate([images, masks], axis=-1) # Combine image and mask channels for i in range(data.shape[0]): img_5ch = np.transpose(data[i], (2, 0, 1)) # Convert to (c, h, w) tifffile.imwrite(outfiles[i], img_5ch, imagej=True, metadata={"axes": "CYX"})
Save each image as a separate 5-channel TIFF file.
Parameters
images (np.ndarray): Array of images (n, h, w, c).
masks (np.ndarray): Array of masks (n, h, w, 1).
outfiles (list of str): List of output file paths.
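Example
A sketch that writes one ImageJ-compatible TIFF per sample; the synthetic batch, output directory, and naming scheme are assumptions:

from src.utils.utils import save_5channel_tiffs
import numpy as np
import os

# Hypothetical batch of 8 four-channel images plus single-channel masks.
images = np.random.randint(0, 255, size=(8, 64, 64, 4), dtype=np.uint8)
masks = np.random.randint(0, 2, size=(8, 64, 64, 1), dtype=np.uint8)

os.makedirs("out/tiffs", exist_ok=True)
outfiles = [os.path.join("out/tiffs", f"sample_{i:04d}.tiff") for i in range(images.shape[0])]
save_5channel_tiffs(images, masks, outfiles)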
def save_5channel_tiffs_single_file(images, masks, out_file)
Expand source code
def save_5channel_tiffs_single_file(images, masks, out_file): """ Save a batch of images and masks as a single multi-page 5-channel TIFF file. Parameters: images (np.ndarray): Array of images with shape (n, h, w, 4). masks (np.ndarray): Array of masks with shape (n, h, w, 1). out_file (str): Path to the output TIFF file. """ data_5ch = np.concatenate([images, masks], axis=-1) # Combine images and masks data_5ch = np.transpose(data_5ch, (0, 3, 1, 2)) # Convert to (n, c, h, w) # Save as a multi-page TIFF file tifffile.imwrite( out_file, data_5ch, imagej=True, metadata={"axes": "TCYX"} ) print(f"Saved multi-page TIFF: {out_file}")
Save a batch of images and masks as a single multi-page 5-channel TIFF file.
Parameters
images (np.ndarray): Array of images with shape (n, h, w, 4).
masks (np.ndarray): Array of masks with shape (n, h, w, 1).
out_file (str): Path to the output TIFF file.
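Example
A sketch for the single-file variant; the synthetic batch and the output file name are assumptions:

from src.utils.utils import save_5channel_tiffs_single_file
import numpy as np

images = np.random.randint(0, 255, size=(8, 64, 64, 4), dtype=np.uint8)  # hypothetical batch
masks = np.random.randint(0, 2, size=(8, 64, 64, 1), dtype=np.uint8)
save_5channel_tiffs_single_file(images, masks, "all_samples.tiff")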
def shuffle_index(label, p=0.5)
Expand source code
def shuffle_index(label, p=0.5): """ Class-informed shuffling of indices to introduce additional positive pairs from instances of the same class. This is useful for contrastive learning. Parameters: label (np.ndarray or torch.Tensor): Array of class labels. p (float): Proportion of positive pairs to generate within each class (default: 0.5). Returns: np.ndarray: Array of shuffled indices. """ index = np.arange(len(label)) # Initialize index array for l in np.unique(label): size = int(p * torch.sum(label == l)) # Calculate number of positive pairs if size >= 2: # Randomly select indices within the class t1 = np.random.choice(a=index[label == l], size=size, replace=False) t2 = t1.copy() np.random.shuffle(t2) # Shuffle the selected indices index[t1] = index[t2] # Swap positions return index
Class-informed shuffling of indices to introduce additional positive pairs from instances of the same class. This is useful for contrastive learning.
Parameters
label (np.ndarray or torch.Tensor): Array of class labels.
p (float): Proportion of positive pairs to generate within each class (default: 0.5).
Returns
np.ndarray
- Array of shuffled indices.
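Example
A sketch with a toy label tensor (the labels below are made up); with p=0.5 and four samples per class, two positions per class are shuffled among themselves, so any reassigned index still points at a sample of the same class:

from src.utils.utils import shuffle_index
import torch

labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
idx = shuffle_index(labels, p=0.5)
print(idx)  # e.g. [0 2 1 3 4 5 7 6] -- one possible outcome; reassignments stay within a class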