Module pipeline.src.pipeline

Functions

def extract_features(frame, params)
Expand source code
def extract_features(frame, params):
    """
    Compute per-event features for a segmented frame, optionally with crops.

    Parameters:
        frame (Frame object): Segmented frame; must provide
            ``calc_basic_features()`` and ``extract_crops()``.
        params (dict): Feature extraction settings. Keys read here:
            - filters (list): Event filters; filtering is skipped when empty.
            - verbose: Verbosity value forwarded to ``utils.filter_events``
              (only read when filters are given).
            - extract_img (bool): Whether to crop event images.
            - width (int): Crop width passed to ``frame.extract_crops``.
            - mask_flag (bool): Forwarded to ``frame.extract_crops``.

    Returns:
        dict: A dictionary with the keys:
            - "features": The (possibly filtered) feature table.
            - "images": Cropped event images, or None if not extracted.
            - "masks": Cropped event masks, or None if not extracted.
    """
    table = frame.calc_basic_features()

    # Filtering is optional; an empty filter list leaves the table untouched
    active_filters = params["filters"]
    if len(active_filters) > 0:
        table = utils.filter_events(table, active_filters, params["verbose"])

    # Crops are only produced when there is a feature table and they are requested
    crops = (None, None)
    if table is not None and params["extract_img"]:
        crops = frame.extract_crops(
            table, params["width"], mask_flag=params["mask_flag"]
        )

    images, masks = crops
    return {"features": table, "images": images, "masks": masks}

Extract features from a segmented frame.

Parameters

frame (Frame object): The input frame containing image data. params (dict): Dictionary containing feature extraction parameters, including: - filters (list): Filters to apply during feature extraction. - extract_img (bool): Flag to indicate whether to extract cropped images. - width (int): Width of the extracted image crop. - mask_flag (bool): Whether to include masks during crop extraction.

Returns

dict
A dictionary containing: - "features": DataFrame of calculated features. - "images": Array of cropped event images (if extracted). - "masks": Array of cropped event masks (if extracted).
def main()
Expand source code
def main():
    """
    Command-line entry point: parse arguments, validate them, run the pipeline.

    Builds the argument parser, checks that one start index was supplied per
    channel, and hands the parsed namespace to ``process_frames``.

    Exits with status -1 (via ``sys.exit``) when the number of channel names
    and the number of channel start indices differ.
    """

    # main inputs
    parser = argparse.ArgumentParser(
        description="Process slide images to identify cells.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="input path to slide images",
    )

    parser.add_argument(
        "-o", "--output", type=str, required=True, help="output path"
    )

    parser.add_argument(
        "-m",
        "--mask_path",
        type=str,
        default=None,
        help="mask path to save frame masks if needed",
    )

    parser.add_argument(
        "-f", "--offset", type=int, default=0, help="start frame offset"
    )

    parser.add_argument(
        "-n", "--nframes", type=int, default=2304, help="number of frames"
    )

    parser.add_argument(
        "-c",
        "--channels",
        type=str,
        nargs="+",
        default=["DAPI", "TRITC", "CY5", "FITC"],
        help="channel names",
    )

    parser.add_argument(
        "-s",
        "--starts",
        type=int,
        nargs="+",
        default=[1, 2305, 4609, 9217],
        help="channel start indices",
    )

    parser.add_argument(
        "-F",
        "--format",
        type=str,
        nargs="+",
        default=["Tile%06d.tif"],
        help="image name format",
    )

    parser.add_argument(
        "--encoder_model",
        required=True,
        type=str,
        help="path to model for inference",
    )

    parser.add_argument(
        "--mask_model",
        required=True,
        type=str,
        help="path to model for segmentation model",
    )

    parser.add_argument(
        "-v", "--verbose", action="count", default=0, help="verbosity level"
    )

    parser.add_argument(
        "-t",
        "--threads",
        type=int,
        default=0,
        help="number of threads for parallel processing",
    )

    # Segmentation parameters
    parser.add_argument(
        "--tophat_size", type=int, default=45, help="TopHat filter kernel size"
    )

    parser.add_argument(
        "--open_size",
        type=int,
        default=5,
        help="Open morphological filter kernel size",
    )

    parser.add_argument(
        "--mask_channels",
        type=str,
        nargs="+",
        default=["DAPI", "TRITC", "CY5", "FITC"],
        help="channels to segment",
    )

    parser.add_argument(
        "--exclude_border",
        default=False,
        action="store_true",
        help="exclude events that are on image borders",
    )

    parser.add_argument(
        "--include_edge_frames",
        default=False,
        action="store_true",
        help="include frames that are on the edge of slide",
    )

    parser.add_argument(
        "--selected_frames",
        type=int,
        nargs="*",
        default=[],
        help="list of selected frames to be processed",
    )

    # NOTE: default=True combined with store_true means this flag can never
    # be switched off from the command line (the help text says so as well).
    parser.add_argument(
        "--extract_images",
        default=True,
        action="store_true",
        help="extract images of detected events for inference [always true]",
    )

    parser.add_argument(
        "-w",
        "--width",
        type=int,
        default=75,
        help="""
        size of the event images to be cropped from slide images [always 75]
        """,
    )

    parser.add_argument(
        "--inference_batch",
        type=int,
        default=10000,
        help="inference batch size",
    )

    parser.add_argument(
        "--frame_batch",
        type=int,
        default=200,
        help="frame processing batch size",
    )

    # NOTE: same always-on pattern as --extract_images above.
    parser.add_argument(
        "--mask_flag",
        default=True,
        action="store_true",
        help="store event masks when extracting images [always True]",
    )

    parser.add_argument(
        "--sort",
        type=str,
        nargs=2,
        action="append",
        default=[],
        help="""
        sort events based on feature values.

        Usage:   <command> --sort <feature> <order>
        Example: <command> --sort TRITC_mean I
        order:    I: Increasing / D: Decreasing
        """,
    )

    # The usage text previously advertised a non-existent --feature_range flag.
    parser.add_argument(
        "--filter",
        type=str,
        nargs=3,
        action="append",
        default=[],
        help="""
        feature range for filtering detected events.

        Usage:    <command> --filter <feature> <min> <max>
        Example:  <command> --filter DAPI_mean 0 10000
        """,
    )

    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="activate debug mode to save hdf5 files",
    )

    parser.add_argument(
        "--normalize",
        type=str,
        default=None,
        help="input path to h5 reference file for normalization",
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cuda:1",
        help="device to use for segmentation and inference. set to cpu if not GPU available. set to cuda:0, cuda:1, etc. for specific GPU. set to mps for mac.",
    )

    parser.add_argument(
        "--workers",
        type=int,
        default=16,
        help="number of workers for cellpose segmentation. Limited with memory, so monitor your usage.",
    )

    parser.add_argument(
        "--classifier_model",
        type=str,
        default=None,
        help="path to model for classification",
    )

    parser.add_argument(
        "--n_classes",
        type=int,
        default=0,
        help="number of classes in the model",
    )

    parser.add_argument(
        "--scaler",
        type=str,
        default=None,
        help="path to scaler for classification",
    )

    args = parser.parse_args()

    # Every channel needs its own start index. The previous check combined
    # the starts and format comparisons with `and`, so a starts mismatch went
    # undetected whenever --format had one entry per channel.
    # --format is not length-checked here: the default supplies a single
    # format for four channels, so a differing length is evidently accepted.
    if len(args.channels) != len(args.starts):
        print(
            "number of channels do not match with number of starts",
            file=sys.stderr,
        )
        sys.exit(-1)

    process_frames(args)
def post_process_frame(frame, params)
Expand source code
def post_process_frame(frame, params):
    """
    Normalize a segmented frame against a reference via histogram matching.

    Thin wrapper around ``hist_utils.match_image_to_reference``.

    Parameters:
        frame (Frame object): The frame to normalize.
        params (dict): Pipeline settings; ``params['normalize']`` is passed
            on as the reference argument, and the whole dict is forwarded too.

    Returns:
        np.ndarray: The normalized image, as returned by
        ``hist_utils.match_image_to_reference``.
    """
    # All of the work is delegated to the histogram-matching helper
    return hist_utils.match_image_to_reference(
        frame,
        params,
        params['normalize'],
    )

Post-process a frame after segmentation, primarily for normalization.

Parameters

frame (Frame object): The frame to be post-processed. params (dict): Dictionary containing normalization settings.

Returns

np.ndarray
The normalized image after post-processing.
def preprocess_frame(frame, params)
Expand source code
def preprocess_frame(frame, params):
    """
    Run the in-place tophat preprocessing step on a frame's image.

    Parameters:
        frame (Frame object): Frame whose ``image`` array is modified.
        params (dict): Preprocessing settings. Keys read here:
            - tophat_size (int): Side length of the elliptical structuring
              element; 0 disables the step entirely.
            - mask_ch (list): Channels to transform, resolved to indices
              through ``frame.get_ch``.

    Modifies:
        frame.image (np.ndarray): Each selected channel is replaced by its
        tophat-transformed version.
    """
    size = params["tophat_size"]

    # A size of zero means "no tophat filtering": leave the frame untouched
    if size == 0:
        return

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))

    for channel in params["mask_ch"]:
        idx = frame.get_ch(channel)  # channel name -> index on the last axis
        frame.image[..., idx] = cv2.morphologyEx(
            frame.image[..., idx], cv2.MORPH_TOPHAT, kernel
        )

Preprocess frame images for segmentation and feature extraction.

Parameters

frame (Frame object): The input frame containing image data. params (dict): Dictionary containing preprocessing parameters, including: - tophat_size (int): Size of the structuring element for the tophat operation. - mask_ch (list): List of channels to apply the tophat transformation.

Modifies

frame.image (np.ndarray): Applies tophat transformation to specified channels.

def process_frames(args)
Expand source code
def process_frames(args):
    """
    Process multiple frames for segmentation, feature extraction, and analysis.

    This is the core function that orchestrates the entire pipeline. In order,
    it:
        1. validates the requested device and loads the Cellpose and encoder
           models,
        2. reads and preprocesses the frames in a multiprocessing pool,
        3. segments every frame and zeroes out labels touching the frame
           border,
        4. optionally normalizes the frame images (when ``args.normalize`` is
           set),
        5. extracts features and event crops in batches of
           ``args.frame_batch`` frames and infers latent embeddings for them,
        6. writes all features to ``args.output`` as parquet and, in debug
           mode, writes images/masks/features to a sibling ``.hdf5`` file.

    Parameters:
        args (Namespace): Argument parser namespace containing the following:
            - input (str): Path to the input data directory.
            - output (str): Path to save the processed results.
            - nframes (int): Number of frames to process.
            - threads (int): Number of parallel processing threads
              (0 means use all CPUs).
            - workers (int): Number of worker processes for segmentation.
            - encoder_model (str): Path to the encoder model for feature
              extraction.
            - mask_model (str): Path to the Cellpose model for segmentation.
            - device (str): Computation device ('cpu', 'cuda', 'cuda:N' or
              'mps').
            - other processing and model parameters.

    Exits the process with status -1 when the device is unusable or when a
    batch of frames yields no events.
    """

    mp.set_start_method('spawn', force=True)

    # Create logger object
    logger = utils.get_logger(__name__, args.verbose)

    # Input variables
    in_path = args.input
    output = args.output
    n_frames = args.nframes
    channels = args.channels
    starts = args.starts
    offset = args.offset
    name_format = args.format
    n_threads = args.threads
    include_edge = args.include_edge_frames

    # Segmentation parameters
    params = {
        "tophat_size": args.tophat_size,
        "opening_size": args.open_size,
        #"blur_size": args.blur_size,
        #"blur_sigma": args.blur_sigma,
        #"thresh_size": args.thresh_size,
        #"thresh_offset": args.thresh_offsets,
        #"min_dist": args.min_seed_dist,
        #"seed_ch": args.seed_channel,
        "mask_ch": args.mask_channels,
        "mask_path": args.mask_path,
        "name_format": args.format,
        "exclude_border": args.exclude_border,
        "filters": args.filter,
        "extract_img": args.extract_images,
        "width": args.width,
        "mask_flag": args.mask_flag,
        "verbose": args.verbose,
        "normalize": args.normalize,
        "channel_names": args.channels,
        "debug": args.debug,
    }

    logger.info("Detecting available GPUs...")
    n_gpus = torch.cuda.device_count()

    # 'mps' is not a CUDA device, so a CUDA device count of zero must not
    # reject it (the device check below explicitly allows 'mps').
    if n_gpus == 0 and args.device not in ('cpu', 'mps'):
        logger.error("No GPUs detected. Exiting.")
        sys.exit(-1)

    # Check that the requested device is one of the available GPUs
    if args.device != 'cpu':
        allowed_devices = [f'cuda:{i}' for i in range(n_gpus)] + ['mps', 'cuda']
        if args.device not in allowed_devices:
            logger.error("The device specified is not within the range of available GPUs. Exiting.")
            sys.exit(-1)

    logger.info("Loading Cellpose model to GPU...")
    cellpose_model = models.CellposeModel(
        gpu=True,
        pretrained_model=args.mask_model,
        device=torch.device(args.device),
    )
    logger.info("Finished loading Cellpose model.")

    logger.info("Loading encoder model to GPU...")
    model = load_model(args.encoder_model, device=args.device).to(args.device).eval()
    logger.info("Finished loading encoder model.")

    logger.info("Loading frames...")
    # Check if there is a selection of frames to process
    if args.selected_frames:
        frame_ids = args.selected_frames
    else:
        frame_ids = [i + offset + 1 for i in range(n_frames)]

    # Read and preprocess frames in parallel
    n_proc = n_threads if n_threads > 0 else mp.cpu_count()

    read_preprocess_partial = partial(
        read_and_preprocess_frame,
        in_path=in_path,
        channels=channels,
        starts=starts,
        name_format=name_format,
        include_edge=include_edge,
        params=params
    )

    with mp.Pool(n_proc) as pool:
        frames = pool.map(read_preprocess_partial, frame_ids)

    # Filter out None frames (edge frames)
    frames = [frame for frame in frames if frame is not None]

    logger.info("Finished loading and preprocessing frames.")

    logger.info("Segmenting frames...")

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        masks = list(tqdm.tqdm(executor.map(segment_frame, [(frame, cellpose_model, params) for frame in frames])))

    # Removing edge events: every label present on the outer rows/columns is
    # cleared in one vectorized pass instead of one full-array scan per label.
    for mask in masks:
        border_labels = np.unique(
            np.concatenate([mask[0, :], mask[-1, :], mask[:, 0], mask[:, -1]])
        )
        mask[np.isin(mask, border_labels)] = 0

    for frame, mask in zip(frames, masks):
        frame.mask = mask.astype("uint16")
        # Saving the mask
        if params["mask_path"] is not None:
            frame.writeMask(params["mask_path"])

    # The segmentation model is no longer needed; release it before the
    # memory-heavy feature extraction and inference stages.
    del masks, cellpose_model
    torch.cuda.empty_cache()
    gc.collect()

    logger.info("Finished segmenting frames.")

    if args.normalize is not None:
        # Post-process frames using multiprocessing pool
        logger.info("Post-processing frames...")
        post_process_partial = partial(post_process_frame, params=params)

        with mp.Pool(n_proc) as pool:
            norm_images = list(pool.map(post_process_partial, frames))

        for frame, n_image in zip(frames, norm_images):
            frame.image = n_image

        del norm_images

        logger.info("Finished post-processing frames.")

    all_features = []
    if args.debug:
        all_images = []
        all_masks = []

    # Frames are consumed in batches so that each processed batch can be freed
    while frames:
        print(f"{len(frames)} frames remaining...")

        chunk, frames = frames[:args.frame_batch], frames[args.frame_batch:]

        logger.info("Processing the frames...")
        # NOTE(review): fixed at 8 here, independent of --threads — confirm
        # this is intentional (e.g. a memory limit).
        n_proc = 8

        # Use a context manager to ensure the pool is closed
        with mp.Pool(n_proc) as pool:
            data = pool.map(partial(extract_features, params=params), chunk)

        logger.info("Finished processing the frames.")

        del chunk
        gc.collect()

        logger.info("Collecting features...")
        features = [
            out["features"] for out in data if out["features"] is not None
        ]
        if not features:
            logger.error("No events to report in this set of frames!")
            sys.exit(-1)
        features = pd.concat(features, ignore_index=True)

        images = None
        masks = None

        if args.extract_images:
            logger.info("Collecting event images...")
            images = np.concatenate(
                [out["images"] for out in data if out["images"] is not None],
                axis=0
            )

            if args.mask_flag:
                logger.info("Collecting event masks...")
                masks = np.concatenate(
                    [out["masks"] for out in data if out["masks"] is not None],
                    axis=0
                )

        del data
        gc.collect()

        # Applying the input sortings
        if len(args.sort) != 0:
            logger.info("Sorting events...")
            features = utils.sort_events(features, args.sort, args.verbose)
            logger.info("Finished sorting events.")
            # Reorder the crops with the sorted index so rows stay aligned
            images = images[list(features.index)]
            masks = masks[list(features.index)]
            features.reset_index(drop=True, inplace=True)

        # Infer latent embeddings
        logger.info("Inferring latent features of events...")
        dataset = CustomImageDataset(images, masks, labels=np.zeros(images.shape[0]), tran=False)

        dataloader = DataLoader(dataset, batch_size=args.inference_batch, shuffle=False)
        embeddings = get_embeddings(model, dataloader, args.device)
        # convert embeddings to numpy array
        embeddings = embeddings.numpy()

        embeddings_df = pd.DataFrame(
            embeddings.astype('float16'),
            columns=[f'z{i}' for i in range(embeddings.shape[1])])
        features = pd.concat([features, embeddings_df], axis=1)
        logger.info("Finished inferring latent features.")

        if args.debug:
            all_images.append(images)
            all_masks.append(masks)

        all_features.append(features)

        del dataset, dataloader, embeddings, embeddings_df, images, masks, features
        gc.collect()

    all_features = pd.concat(all_features, axis=0)
    # DataFrame.astype returns a new frame; the result was previously
    # discarded, so the dtype conversion never took effect.
    all_features = all_features.astype(basic_features_dtypes)

    logger.info("Finished processing all frames.")
    logger.info("Saving data...")
    all_features.to_parquet(output, compression='gzip')

    if args.debug:
        # os.path.join keeps the path relative when `output` has no directory
        # part; the previous f-string produced "/<name>.hdf5" in that case.
        debug_filename = os.path.join(
            os.path.dirname(output),
            f"{os.path.basename(output).split('.')[0]}.hdf5",
        )
        with h5py.File(debug_filename, mode='w') as hf:
            hf.create_dataset("images", data=np.concatenate(all_images, axis=0))
            hf.create_dataset("channels", data=args.channels)
            if args.mask_flag:
                hf.create_dataset("masks", data=np.concatenate(all_masks, axis=0))
        all_features.to_hdf(debug_filename, mode='a', key='features')

    logger.info("Finished saving features.")

Process multiple frames for segmentation, feature extraction, and analysis. This is the core function that orchestrates the entire pipeline. It handles loading models, reading images, segmenting cells, extracting features, inferring latent embeddings, and saving the results.

Parameters

args (Namespace): Argument parser namespace containing the following: - input (str): Path to the input data directory. - output (str): Path to save the processed results. - nframes (int): Number of frames to process. - threads (int): Number of parallel processing threads. - workers (int): Number of worker processes for segmentation. - encoder_model (str): Path to the encoder model for feature extraction. - mask_model (str): Path to the Cellpose model for segmentation. - device (str): Computation device ('cpu', 'cuda', 'cuda:N', or 'mps'). - other processing and model parameters.

def read_and_preprocess_frame(frame_id, in_path, channels, starts, name_format, include_edge, params)
Expand source code
def read_and_preprocess_frame(frame_id, in_path, channels, starts, name_format, include_edge, params):
    """
    Load one frame from disk and run the preprocessing step on it.

    Parameters:
        frame_id (int): Identifier of the frame to load.
        in_path (str): Input directory containing image files.
        channels (list): Channel names handed to the ``Frame`` constructor.
        starts (list): Channel start indices, forwarded to
            ``utils.generate_tile_paths``.
        name_format: Image name format(s), forwarded to
            ``utils.generate_tile_paths``.
        include_edge (bool): When False, frames for which ``frame.is_edge()``
            is true are discarded before any image data is read.
        params (dict): Preprocessing parameters for ``preprocess_frame``.

    Returns:
        Frame object or None: The preprocessed frame, or None if discarded.
    """
    tile_paths = utils.generate_tile_paths(
        path=in_path, frame_id=frame_id, starts=starts, name_format=name_format
    )
    frame = Frame(frame_id=frame_id, channels=channels, paths=tile_paths)

    # Edge frames are dropped before the (expensive) image read
    keep = include_edge or not frame.is_edge()
    if not keep:
        return None

    frame.readImage()
    preprocess_frame(frame, params)
    return frame

Read and preprocess a single frame for segmentation and analysis.

Parameters

frame_id (int): Unique identifier for the frame. in_path (str): Input directory containing image files. channels (list): List of channel names or indices. starts (list): Starting points for reading the image. name_format (list of str): Naming format(s) of the image files. include_edge (bool): Whether to include edge frames. params (dict): Dictionary containing preprocessing parameters.

Returns

Frame object or None
The preprocessed frame, or None if discarded.
def segment_frame(args)
Expand source code
def segment_frame(args):
    """
    Run Cellpose segmentation on one frame.

    The single tuple argument lets this function be used directly with
    ``executor.map``.

    Parameters:
        args (tuple): A 3-tuple of:
            - frame (Frame object): The frame whose ``image`` is segmented.
            - cp_model (CellposeModel): Pre-trained Cellpose model.
            - params (dict): Segmentation parameters (not read here).

    Returns:
        np.ndarray: The mask returned by ``cp_model.eval`` (the first of its
        three return values).
    """
    frame, cp_model, _params = args

    # Compose the image passed to the model from the frame's channels
    model_input = utils.channels_to_bgr(frame.image, [0, 3], [2, 3], [1, 3])

    # Only the mask is needed; the other two outputs are discarded
    mask, _unused_a, _unused_b = cp_model.eval(
        model_input,
        diameter=15,
        channels=[0, 0],
        batch_size=8,
    )
    return mask

Segment a single frame using a pre-trained Cellpose model.

Parameters

args (tuple): A tuple containing: - frame (Frame object): The input frame for segmentation. - cp_model (CellposeModel): Pre-trained Cellpose model for mask generation. - params (dict): Dictionary containing segmentation parameters.

Returns

np.ndarray
The label mask generated by the model (the pipeline treats its values as per-object labels).