Module src.leukocyte_classifier.wbc_dataloader

Functions

def get_data_loaders(data_path, batch_size=64)
def get_data_loaders(data_path, batch_size=64):
    """
    Creates training and validation data loaders from HDF5 files.
    The function reads data from subdirectories, handles class labeling,
    downsamples if necessary, splits into training and validation sets,
    and returns data loaders for both sets.

    Parameters:
        data_path (str): Path to the directory containing subfolders with HDF5 files.
        batch_size (int): Number of samples per batch for the training data loader (default: 64).

    Returns:
        Tuple[DataLoader, DataLoader]: 
            - train_loader: DataLoader object for the training set.
            - val_loader: DataLoader object for the validation set.
    """
    # Set seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    # Get list of subdirectories representing data types (classes)
    types = [d for d in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, d))]

    # Initialize lists to store training and validation images, masks, and labels
    train_images_list, val_images_list = [], []
    train_masks_list, val_masks_list = [], []
    train_labels_list, val_labels_list = [], []

    print(types)  # Display the available types (classes)

    # Iterate over each subdirectory (class) and process the data
    for t in types:
        print(t)

        # Assign binary labels based on the directory name:
        # white blood cells ("wbcs") are 1, all other types are 0
        label = 1 if t == "wbcs" else 0

        # Construct the path to the current data type directory
        current_type_path = os.path.join(data_path, t)

        # Get all HDF5 files in the current directory
        current_type_files = glob.glob(os.path.join(current_type_path, "*.hdf5"))

        # Initialize lists to store images and masks from the current class
        class_images, class_masks = [], []

        # Load images and masks from each HDF5 file
        for file_path in current_type_files:
            with h5py.File(file_path, 'r') as f:
                # Load images and masks, ensuring the correct data types
                imgs = np.array(f['images'][:], dtype=np.float32)
                msks = np.array(f['masks'][:])
                class_images.append(imgs)
                class_masks.append(msks)

        # Concatenate all images and masks from the current class
        class_images = np.concatenate(class_images, axis=0)
        class_masks = np.concatenate(class_masks, axis=0)

        # Downsample if the number of images exceeds 9149
        if len(class_images) > 9149:
            indices = np.random.choice(len(class_images), 9149, replace=False)
            class_images = class_images[indices]
            class_masks = class_masks[indices]

        # Split data into training (80%) and validation (20%) sets
        num_train = int(len(class_images) * 0.8)
        train_imgs, val_imgs = class_images[:num_train], class_images[num_train:]
        train_masks, val_masks = class_masks[:num_train], class_masks[num_train:]

        # Append training data to the respective lists
        train_images_list.append(train_imgs)
        train_masks_list.append(train_masks)
        train_labels_list.append(np.full(len(train_imgs), label, dtype=np.int64))

        print(len(train_imgs))  # Print the number of training images for the current class

        # Append validation data to the respective lists
        val_images_list.append(val_imgs)
        val_masks_list.append(val_masks)
        val_labels_list.append(np.full(len(val_imgs), label, dtype=np.int64))

    # Concatenate all training and validation data from all classes
    train_images = np.concatenate(train_images_list, axis=0)
    val_images = np.concatenate(val_images_list, axis=0)
    train_labels = np.concatenate(train_labels_list, axis=0)
    val_labels = np.concatenate(val_labels_list, axis=0)
    train_masks = np.concatenate(train_masks_list, axis=0)
    val_masks = np.concatenate(val_masks_list, axis=0)

    print(len(train_images), len(val_images))  # Print the total number of training and validation images

    # Create PyTorch datasets using the custom dataset class
    train_dataset = CustomImageDataset(train_images, train_masks, train_labels, tran=True)
    val_dataset = CustomImageDataset(val_images, val_masks, val_labels, tran=False)

    # Shuffle the training loader; the validation loader serves the whole
    # validation set as a single, unshuffled batch
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=len(val_images), shuffle=False)

    return train_loader, val_loader

Creates training and validation data loaders from HDF5 files. The function reads data from subdirectories, handles class labeling, downsamples if necessary, splits into training and validation sets, and returns data loaders for both sets.
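The directory layout expected by this function, inferred from the loading loop above: data_path contains one subfolder per class, each holding one or more HDF5 files with 'images' and 'masks' datasets; the folder named "wbcs" is labeled 1 and every other folder 0. A minimal sketch of producing a compatible file with h5py (file name, shapes, and dtypes are illustrative assumptions, not requirements stated by the module):

import os
import h5py
import numpy as np

data_path = "data"                                 # placeholder root directory
class_dir = os.path.join(data_path, "wbcs")        # "wbcs" becomes the positive class
os.makedirs(class_dir, exist_ok=True)

# Assumed shapes: 4-channel, 75x75, 16-bit images with matching binary masks
images = np.random.randint(0, 65536, size=(10, 75, 75, 4), dtype=np.uint16)
masks = (np.random.rand(10, 75, 75) > 0.5).astype(np.uint8)

with h5py.File(os.path.join(class_dir, "example.hdf5"), "w") as f:
    f.create_dataset("images", data=images)        # read back as float32 by the loader
    f.create_dataset("masks", data=masks)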

Parameters

  • data_path (str): Path to the directory containing subfolders with HDF5 files.
  • batch_size (int): Number of samples per batch for the training data loader (default: 64).

Returns

Tuple[DataLoader, DataLoader]
  • train_loader: DataLoader object for the training set.
  • val_loader: DataLoader object for the validation set.
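A brief usage sketch; the data directory and the shape comment are assumptions based on the code above, not guarantees from the module:

from src.leukocyte_classifier.wbc_dataloader import get_data_loaders

# "data" is a placeholder path containing one subfolder of HDF5 files per class
train_loader, val_loader = get_data_loaders("data", batch_size=64)

for batch_images, batch_labels in train_loader:
    # batch_images: (batch_size, C + 1, 75, 75), the masked image with the binary
    # mask appended as the last channel; batch_labels: (batch_size,) of 0/1 labels
    print(batch_images.shape, batch_labels.shape)
    break

Note that the validation loader returns the entire validation set in a single batch, so batch_size only affects the training loader.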

Classes

class CustomImageDataset (images, masks, labels, tran=False)
class CustomImageDataset(Dataset):
    def __init__(self, images, masks, labels, tran=False):
        """
        Custom dataset for loading 4-channel, 75x75, 16-bit TIFF images.
        :param images: Numpy array of images.
        :param masks: Numpy array of binary masks.
        :param labels: Numpy array of labels.
        :param tran: Whether this dataset is used as the training split (default: False).
        """
        self.images = images
        self.masks = masks
        self.labels = labels
        self.tran = tran

        self.t = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Extract a single image and its label
        #image = np.log1p(self.images[idx].astype(np.float32)) / np.log(65535.0)
        image = self.images[idx].astype(np.float32) / 65535.0
        label = self.labels[idx]
        mask = self.masks[idx].astype(np.int16)
        
        image = self.t(image)
        
        mask = self.t(mask)
        hard_masked_image = image * mask
        hard_masked_image = torch.cat((hard_masked_image, mask), dim=0)

        return hard_masked_image, torch.tensor(label, dtype=torch.long)

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__, which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader. Subclasses could also optionally implement __getitems__, to speed up batched sample loading; this method accepts a list of sample indices for a batch and returns the list of samples.

Note

torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

Custom dataset for loading 4-channel, 75x75, 16-bit TIFF images.

Parameters

  • images: Numpy array of images.
  • masks: Numpy array of binary masks.
  • labels: Numpy array of labels.
  • tran: Whether this dataset is used as the training split (default: False).

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic
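A standalone sketch of the dataset class with synthetic arrays; the shapes and dtypes are assumptions taken from the 4-channel, 75x75, 16-bit description in the docstring:

import numpy as np
from src.leukocyte_classifier.wbc_dataloader import CustomImageDataset

# Synthetic 16-bit, 4-channel, 75x75 images with binary masks and binary labels
images = np.random.randint(0, 65536, size=(8, 75, 75, 4), dtype=np.uint16)
masks = (np.random.rand(8, 75, 75) > 0.5).astype(np.uint8)
labels = np.array([0, 1, 0, 1, 0, 1, 0, 1], dtype=np.int64)

dataset = CustomImageDataset(images, masks, labels, tran=False)
sample, label = dataset[0]

# __getitem__ multiplies the mask into the image and also concatenates it as an
# extra channel, so a 4-channel input yields a 5-channel tensor
print(sample.shape, label)   # expected: torch.Size([5, 75, 75]) tensor(0)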