Module src.leukocyte_classifier.train_wbc

Functions

def main()
Expand source code
def main(
    data_path="/mnt/deepstore/LBxPheno/train_data/wbc_classifier/processed",
    output_dir="/mnt/deepstore/LBxPheno/output/wbc_classifier",
    num_epochs=25,
    learning_rate=1e-6,
):
    """
    Train the WBC classifier end to end and save the resulting artifacts.

    Builds train/val DataLoaders with ``get_data_loaders``, then trains a
    ``CNNModel`` via ``train_model`` using cross-entropy loss and Adam. It
    writes the returned model weights and both DataLoaders into ``output_dir``.

    Parameters:
        data_path (str): Directory passed to ``get_data_loaders``.
        output_dir (str): Directory that receives ``wbc_model.pth``,
            ``train_loader.pth`` and ``val_loader.pth``; created if missing.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for the Adam optimizer.
    """
    import os  # local import: this module's top-level import block is not visible here

    # Prefer the second GPU (the original configuration), but fall back to the
    # first one on single-GPU hosts, where 'cuda:1' is an invalid device ordinal.
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f'cuda:{gpu_index}')
    else:
        device = torch.device('cpu')

    # NOTE(review): no batch size is passed here, so the effective batch size is
    # whatever get_data_loaders uses by default — confirm there.
    train_loader, val_loader = get_data_loaders(data_path=data_path)
    dataloaders = {'train': train_loader, 'val': val_loader}

    # Initialize model, criterion, and optimizer
    model = CNNModel().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train; train_model returns the model loaded with its best validation weights
    model = train_model(model, dataloaders, criterion, optimizer, num_epochs, device)

    # torch.save does not create missing parent directories
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, 'wbc_model.pth'))

    # NOTE(review): this pickles the DataLoaders (and their datasets) whole. With
    # recent PyTorch versions, reloading them requires torch.load(..., weights_only=False),
    # so only load these files from a trusted source.
    torch.save(train_loader, os.path.join(output_dir, 'train_loader.pth'))
    torch.save(val_loader, os.path.join(output_dir, 'val_loader.pth'))
def train_model(model, dataloaders, criterion, optimizer, num_epochs, device)
Expand source code
def train_model(model, dataloaders, criterion, optimizer, num_epochs, device):
    """
    Trains a neural network model for a specified number of epochs and evaluates it on the validation set.

    After every training epoch, the model is evaluated on the validation set. A
    snapshot of the weights with the highest validation accuracy seen so far is
    kept, and it is loaded back into the model before the model is returned.

    Parameters:
        model (torch.nn.Module): The neural network model to train.
        dataloaders (dict): Dictionary containing 'train' and 'val' DataLoaders.
        criterion (torch.nn.Module): Loss function to minimize.
        optimizer (torch.optim.Optimizer): Optimizer for updating model weights.
        num_epochs (int): Number of epochs to train the model.
        device (str or torch.device): Device on which to perform training ('cpu' or 'cuda').

    Returns:
        model (torch.nn.Module): The trained model with the best weights. If no
            epoch's validation accuracy exceeds 0.0 (or num_epochs is 0), the
            initial weights are restored.

    Raises:
        ValueError: If the dataset behind the 'train' or 'val' DataLoader is empty.
    """
    import copy  # local import: this module's top-level import block is not visible here

    # state_dict() holds references to the live parameter tensors, which the
    # optimizer updates in place. Deep-copy it so the snapshot is not silently
    # overwritten by later training steps.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 30)

        # Each epoch has a training and a validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            dataloader = dataloaders[phase]
            dataset_size = len(dataloader.dataset)
            if dataset_size == 0:
                raise ValueError(f"The '{phase}' dataset is empty; cannot compute epoch statistics.")

            if is_train:
                model.train()  # enable training-mode behavior (e.g. dropout, batch-norm updates)
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Track gradients only while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)  # index of the highest score per sample
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # The criterion's default reduction averages over the batch, so
                # re-weight by batch size to get a per-sample mean over the epoch.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size

            print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Snapshot (by value) the weights whenever validation accuracy improves
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best Validation Accuracy: {best_acc:.4f}')

    # Restore the best-performing weights before returning
    model.load_state_dict(best_model_wts)
    return model

Trains a neural network model for a specified number of epochs, evaluating it on the validation set after each epoch.

Parameters

- model (torch.nn.Module): The neural network model to train.
- dataloaders (dict): Dictionary containing 'train' and 'val' DataLoaders.
- criterion (torch.nn.Module): Loss function to minimize.
- optimizer (torch.optim.Optimizer): Optimizer for updating model weights.
- num_epochs (int): Number of epochs to train the model.
- device (str or torch.device): Device on which to perform training ('cpu' or 'cuda').

Returns

model (torch.nn.Module): The trained model, loaded with the weights that achieved the highest validation accuracy.