Module src.utils.generate_masks
Functions
def generate_mask(args)
def generate_mask(args):
    """
    Generate a binary mask for an image using a pre-trained Cellpose model.

    Parameters:
        args (tuple): A tuple containing:
            - image (np.ndarray): Input image.
            - cp_model (CellposeModel): Pre-trained Cellpose model for segmentation.

    Returns:
        np.ndarray: Binary mask of the input image.
    """
    image, cp_model = args

    # Convert channels to BGR format for processing
    rgb = channels_to_bgr(image, [0, 3], [2, 3], [1, 3])

    # Run the Cellpose model to generate a mask
    mask, _, _ = cp_model.eval(
        rgb, diameter=20, channels=[0, 0], batch_size=8
    )

    # Extract the center mask from the generated mask
    mask = get_center_mask(mask)
    return mask
Generate a binary mask for an image using a pre-trained Cellpose model.
Parameters
args (tuple): A tuple containing:
- image (np.ndarray): Input image.
- cp_model (CellposeModel): Pre-trained Cellpose model for segmentation.
Returns
np.ndarray
- Binary mask of the input image.
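generate_mask only requires an object whose eval(...) returns a three-element tuple. That means its segmentation step can be exercised without trained weights or a GPU. The sketch below is hypothetical. StubCellpose and segment_labels are illustrative names. The project's channels_to_bgr conversion, which is defined elsewhere, is omitted. The stub mirrors only the three-tuple that generate_mask unpacks, not the full Cellpose API.

```python
import numpy as np


class StubCellpose:
    """Hypothetical test double for CellposeModel.

    Mimics only what generate_mask relies on: eval(...) returning a
    3-tuple whose first element is an integer label image (0 = background).
    """

    def __init__(self, labels):
        self.labels = labels
        self.calls = []  # keyword arguments received by each eval() call

    def eval(self, x, **kwargs):
        self.calls.append(kwargs)
        return self.labels.copy(), None, None


def segment_labels(args):
    """Hypothetical segmentation step of generate_mask (channel conversion omitted)."""
    image, cp_model = args  # a single tuple argument, as in generate_mask
    labels, _, _ = cp_model.eval(image, diameter=20, channels=[0, 0], batch_size=8)
    return labels


labels = np.zeros((7, 7), dtype=np.uint16)
labels[2:5, 2:5] = 3  # object covering the center pixel
labels[0:2, 5:7] = 8  # unrelated object in a corner
image = np.zeros((7, 7, 3), dtype=np.float32)  # placeholder image

stub = StubCellpose(labels)
out = segment_labels((image, stub))
print(np.unique(out))  # [0 3 8]
print(stub.calls[0])   # {'diameter': 20, 'channels': [0, 0], 'batch_size': 8}
```

The same stub could be handed to generate_mask itself in a test, provided channels_to_bgr is importable.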
def get_center_mask(arr)
def get_center_mask(arr):
    """
    Generate a binary mask where pixels with the same value as the center
    pixel are set to 1 and all other pixels are set to 0.

    Parameters:
        arr (np.ndarray): 2-D label image from which to generate the mask.

    Returns:
        np.ndarray: Binary mask (same shape and dtype as arr) where the
        center region is 1. If the center pixel is background (0), the
        mask is all zeros.
    """
    # Step 1: Get the center pixel's coordinates (row, column) and value
    center_row = arr.shape[0] // 2
    center_col = arr.shape[1] // 2
    center_value = arr[center_row, center_col]

    # Step 2: If the center falls on background, no object is selected
    # (otherwise the whole background would be returned as the mask)
    if center_value == 0:
        return np.zeros_like(arr)

    # Step 3: Pixels equal to the center value become 1, all others 0.
    # A new array is returned, so the caller's label image is not modified.
    return (arr == center_value).astype(arr.dtype)
Generate a binary mask where pixels with the same value as the center pixel are set to 1 and all other pixels are set to 0.
Parameters
arr (np.ndarray): 2-D label image from which to generate the mask.
Returns
np.ndarray
- Binary mask (same shape and dtype as arr) where the center region is 1. If the center pixel is background (0), the mask is all zeros.
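The center-selection rule can be tried on a small label image. The sketch below re-implements it standalone so it runs without Cellpose. center_label_mask is an illustrative name. It returns a new array and treats a background (0) center as "no object" instead of selecting the background.

```python
import numpy as np


def center_label_mask(labels: np.ndarray) -> np.ndarray:
    """Return 1 where labels equals the label under the center pixel, else 0.

    Returns all zeros when the center pixel is background (label 0).
    The input array is left unchanged.
    """
    row, col = labels.shape[0] // 2, labels.shape[1] // 2
    center = labels[row, col]
    if center == 0:
        return np.zeros_like(labels)
    return (labels == center).astype(labels.dtype)


labels = np.array(
    [
        [0, 0, 0, 0, 2],
        [0, 1, 1, 0, 2],
        [0, 1, 1, 1, 0],
        [0, 0, 1, 0, 0],
        [3, 3, 0, 0, 0],
    ],
    dtype=np.uint16,
)

mask = center_label_mask(labels)
print(int(mask.sum()))  # 6 pixels carry label 1, the label at the center (2, 2)
```

For even-sized images, shape // 2 picks the lower-right of the four central pixels.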
def load_h5py(file_path)
def load_h5py(file_path):
    """
    Load every image stored in the 'images' dataset of an HDF5 file.

    Parameters:
        file_path (str): Path to the HDF5 file.

    Returns:
        np.ndarray: Array containing the whole 'images' dataset, read fully into memory.
    """
    with h5py.File(file_path, 'r') as f:
        images = f['images'][:]  # Read the entire dataset into a NumPy array
    return images
Load every image stored in the 'images' dataset of an HDF5 file.
Parameters
file_path (str): Path to the HDF5 file.
Returns
np.ndarray
- Array containing the whole 'images' dataset, read fully into memory.
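load_h5py reads the whole dataset with f['images'][:], so memory use grows with the file size. If that becomes a problem, an h5py Dataset can instead be sliced in ranges, and each slice reads only the requested rows from disk. A minimal sketch, assuming the same 'images' key. iter_image_batches and batch_size are hypothetical names that are not part of this module.

```python
import os
import tempfile

import h5py
import numpy as np


def iter_image_batches(file_path, batch_size=256, key="images"):
    """Hypothetical helper: yield consecutive slices of an HDF5 dataset."""
    with h5py.File(file_path, "r") as f:
        dset = f[key]
        for start in range(0, dset.shape[0], batch_size):
            # Slicing a Dataset reads only this range into a NumPy array
            yield dset[start:start + batch_size]


# Demo on a throwaway file with 10 images of 4x4 pixels
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.hdf5")
    with h5py.File(path, "w") as f:
        f.create_dataset("images", data=np.arange(10 * 4 * 4).reshape(10, 4, 4))
    sizes = [batch.shape[0] for batch in iter_image_batches(path, batch_size=4)]

print(sizes)  # [4, 4, 2]
```

The file remains open while the generator is being consumed. It is closed once the generator is exhausted or garbage-collected.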
def load_model(model_path, device)
def load_model(model_path, device):
    """
    Load a pre-trained Cellpose model for segmentation.

    Parameters:
        model_path (str): Path to the pre-trained Cellpose model file.
        device (str or torch.device): Device to load the model on ('cuda' or 'cpu').

    Returns:
        CellposeModel: Loaded Cellpose model.
    """
    device = torch.device(device)
    # Request the GPU only when a CUDA device is specified, so that
    # passing 'cpu' actually loads the model on the CPU
    cellpose_model = models.CellposeModel(
        gpu=(device.type == 'cuda'),
        pretrained_model=model_path,
        device=device,
    )
    return cellpose_model
Load a pre-trained Cellpose model for segmentation.
Parameters
model_path (str): Path to the pre-trained Cellpose model file.
device (str or torch.device): Device to load the model on ('cuda' or 'cpu').
Returns
CellposeModel
- Loaded Cellpose model.
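The gpu flag and the device argument passed to CellposeModel should agree with each other. The sketch below derives both values from a single device string. As an assumed design choice not present in this module, it falls back to the CPU when CUDA is requested but unavailable. cellpose_device_kwargs is a hypothetical helper.

```python
import torch


def cellpose_device_kwargs(device):
    """Hypothetical helper: build consistent gpu/device keyword arguments.

    Falls back to the CPU when CUDA is requested but not available.
    """
    dev = torch.device(device)
    if dev.type == "cuda" and not torch.cuda.is_available():
        dev = torch.device("cpu")
    return {"gpu": dev.type == "cuda", "device": dev}


print(cellpose_device_kwargs("cpu"))  # {'gpu': False, 'device': device(type='cpu')}
```

The result could then be passed as models.CellposeModel(pretrained_model=model_path, **cellpose_device_kwargs(device)). This assumes a Cellpose version whose CellposeModel accepts the gpu and device keywords used by load_model above.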
def main()
def main():
    h5_file = '/mnt/deepstore/LBxPheno/train_data/wbc_classifier/all_rare_cells.hdf5'
    model_path = '/mnt/deepstore/LBxPheno/pipeline/model_weights/cellpose_model'
    device = 'cuda'

    cp_model = load_model(model_path, device)
    images = load_h5py(h5_file)

    # Run inference in the main process. Sending a CUDA-resident model to
    # ProcessPoolExecutor workers pickles it for every task and can fail
    # because CUDA cannot be re-initialized in forked child processes;
    # GPU inference is serialized on the device anyway.
    masks = [generate_mask((im, cp_model)) for im in tqdm.tqdm(images)]
    masks = np.array(masks)

    # Expand dimensions of masks: (N, H, W) -> (N, H, W, 1)
    masks = np.expand_dims(masks, axis=-1)

    # Convert masks to uint16
    masks = masks.astype(np.uint16)

    # Save masks to the HDF5 file, replacing any existing 'masks' dataset
    with h5py.File(h5_file, 'a') as f:
        if 'masks' in f:
            del f['masks']
        f.create_dataset('masks', data=masks)
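The post-processing in main turns a list of per-image 2-D masks into one (N, H, W, 1) uint16 array before writing it to the 'masks' dataset. Below is a minimal sketch of that step as a standalone function. stack_masks is a hypothetical name. It uses np.stack, which raises ValueError when the masks differ in shape, rather than depending on how the installed NumPy version treats ragged input to np.array.

```python
import numpy as np


def stack_masks(masks):
    """Hypothetical helper: stack 2-D masks into an (N, H, W, 1) uint16 array."""
    stacked = np.stack(masks)           # (N, H, W); ValueError if shapes differ
    stacked = stacked[..., np.newaxis]  # (N, H, W, 1), like np.expand_dims(..., axis=-1)
    return stacked.astype(np.uint16)    # binary 0/1 values fit in uint16


masks = [np.eye(3, dtype=np.int32), np.zeros((3, 3), dtype=np.int32)]
out = stack_masks(masks)
print(out.shape, out.dtype)  # (2, 3, 3, 1) uint16
```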