Module src.representation_learning.train_cl
Classes
class Trainer (config)
-
class Trainer(object):
    """
    Trainer class to manage the training, validation, and model saving process.
    Uses WandB for experiment tracking and logging.

    Attributes:
        main_config (dict): Configuration dictionary containing training settings.
        best_pred (float): Best validation loss observed during training.
        best_epoch (int): Epoch at which the best validation loss occurred.
        model (torch.nn.Module): The neural network model to be trained.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler.
        train_loader (DataLoader): DataLoader for the training set.
        val_loader (DataLoader): DataLoader for the validation set.

    Methods:
        run_training(): Executes the entire training pipeline.
        build_dataset(config): Loads the training and validation data.
        build_optimizer(config): Sets up the optimizer based on configuration.
        build_scheduler(config): Configures the learning rate scheduler.
        build_model(config): Instantiates the model to be trained.
        train(config): Trains the model for the specified number of epochs.
        validate(temp): Evaluates the model on the validation dataset.
        save_checkpoint(model, epoch, dir): Saves the model checkpoint.
    """

    def __init__(self, config):
        """
        Initializes the Trainer object.

        Parameters:
            config (dict): Configuration dictionary with model and training parameters.
        """
        self.main_config = config

    def run_training(self):
        """
        Runs the complete training process including dataset loading, model
        initialization, optimizer configuration, scheduler setup, and model training.
        """
        with wandb.init():
            # Set random seed for reproducibility
            if self.main_config['random_seed'] is not None:
                np.random.seed(self.main_config['random_seed'])
                torch.manual_seed(self.main_config['random_seed'])
                torch.cuda.manual_seed(self.main_config['random_seed'])

            self.best_pred = np.inf  # Initialize best loss to infinity
            self.best_epoch = 0      # Track the best epoch

            # Build the dataset, model, optimizer, and scheduler
            self.build_dataset(config=wandb.config)
            self.build_model(config=wandb.config)
            self.build_optimizer(config=wandb.config)
            self.build_scheduler(config=wandb.config)

            # Start training
            self.train(config=wandb.config)

    def build_dataset(self, config):
        """
        Builds the training and validation datasets.

        Parameters:
            config (dict): Configuration dictionary from WandB.
        """
        print(config)
        self.train_loader, self.val_loader = get_data_loaders(
            data_path=self.main_config['data_path'],
            batch_size=config.batch_size
        )

    def build_optimizer(self, config):
        """
        Initializes the optimizer based on the configuration.

        Parameters:
            config (dict): Configuration dictionary specifying optimizer type and parameters.
        """
        if config.optimizer == "sgd":
            self.optimizer = torch.optim.SGD(
                params=self.model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
                momentum=config.momentum
            )
        elif config.optimizer == "adam":
            self.optimizer = torch.optim.Adam(
                params=self.model.parameters(),
                lr=config.max_lr,
                weight_decay=config.weight_decay
            )

    def build_scheduler(self, config):
        """
        Sets up the learning rate scheduler based on the specified type.

        Parameters:
            config (dict): Configuration dictionary with scheduler settings.
        """
        # Scheduler with linear warmup and exponential decay
        if config.scheduler == 'LambdaLR':
            lr_multiplier = config.max_lr / config.base_lr
            self.scheduler = LambdaLR(
                optimizer=self.optimizer,
                lr_lambda=lambda epoch: (
                    ((lr_multiplier - 1) * epoch / config.l_e + 1)
                    if epoch < config.l_e
                    else lr_multiplier * (config.l_b ** (epoch - config.l_e))
                )
            )
        elif config.scheduler == "Cyclic":
            self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                optimizer=self.optimizer,
                base_lr=config.base_lr,
                max_lr=config.max_lr,
                step_size_up=72,
                mode='exp_range',
                gamma=0.96,
                scale_mode='cycle',
                cycle_momentum=False
            )

    def build_model(self, config):
        """
        Instantiates the model based on the configuration.

        Parameters:
            config (dict): Configuration dictionary with model parameters.
        """
        self.model = Model(
            in_channels=config.in_channels,
            h_dim=config.h_dim,
            projection_dim=config.projection_dim
        ).to(self.main_config['device'])

    def train(self, config):
        """
        Trains the model for a specified number of epochs.

        Parameters:
            config (dict): Configuration dictionary specifying training parameters.
        """
        for epoch in range(config.epochs):
            self.model.train()
            train_loss = 0
            for batch_idx, (data, _) in enumerate(self.train_loader):
                data = data.to(self.main_config['device'])
                self.optimizer.zero_grad()

                # Forward pass and loss calculation
                z_i, z_j, _, _ = self.model(data)
                loss = self.model.loss(z_i, z_j, config.temperature)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                data = data.detach().cpu()  # Free memory

            # Update scheduler and log training loss
            self.scheduler.step()
            train_loss /= len(self.train_loader)
            print(f"Epoch {epoch} Loss: {train_loss}", end="\t")
            wandb.log({"train_loss": train_loss})

            # Early stop if loss becomes NaN
            if np.isnan(train_loss):
                print("Training loss is NaN, stopping.")
                break

            # Validate the model
            val_loss, _ = self.validate(config.temperature)
            wandb.log({"val_loss": val_loss})

            # Save the best model whenever the validation loss improves
            if val_loss < self.best_pred:
                self.best_pred = val_loss
                self.best_epoch = epoch
                self.save_checkpoint(self.model, epoch, self.main_config['model_path'])

    def validate(self, temp):
        """
        Validates the model on the validation dataset.

        Parameters:
            temp (float): Temperature parameter for the contrastive loss.

        Returns:
            Tuple[float, list]: Validation loss and list of latent representations.
        """
        self.model.eval()
        val_loss = 0
        h_i_list = []
        for batch_idx, (data, label) in enumerate(self.val_loader):
            with torch.no_grad():
                data = data.to(self.main_config['device'])
                z_i, z_j, _, _ = self.model(data)
                loss = self.model.loss(z_i, z_j, temp)
                val_loss += loss.item()

                # Extract latent representations
                h_i = self.model.encoder(data).detach().cpu().numpy()
                label = label.cpu().numpy()
                for i in range(h_i.shape[0]):
                    h_data = {str(j): h_i[i][j] for j in range(self.model.h_dim)}
                    h_data["label"] = str(int(label[i]))
                    h_i_list.append(h_data)
                data = data.detach().cpu()

        val_loss /= len(self.val_loader)
        print(f"Validation Loss: {val_loss}")
        return val_loss, h_i_list

    def save_checkpoint(self, model, epoch, dir):
        """
        Saves the model checkpoint.

        Parameters:
            model (torch.nn.Module): The model to save.
            epoch (int): The current epoch number.
            dir (str): Directory to save the checkpoint file.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }
        fname = os.path.join(dir, f"model_ep_{epoch}_loss_{self.best_pred:.4f}.pth")
        torch.save(checkpoint, fname)
Trainer class to manage the training, validation, and model saving process. Uses WandB for experiment tracking and logging.
Attributes
main_config (dict): Configuration dictionary containing training settings.
best_pred (float): Best validation loss observed during training.
best_epoch (int): Epoch at which the best validation loss occurred.
model (torch.nn.Module): The neural network model to be trained.
optimizer (torch.optim.Optimizer): Optimizer for training.
scheduler (torch.optim.lr_scheduler): Learning rate scheduler.
train_loader (DataLoader): DataLoader for the training set.
val_loader (DataLoader): DataLoader for the validation set.
Methods
run_training(): Executes the entire training pipeline.
build_dataset(config): Loads the training and validation data.
build_optimizer(config): Sets up the optimizer based on configuration.
build_scheduler(config): Configures the learning rate scheduler.
build_model(config): Instantiates the model to be trained.
train(config): Trains the model for the specified number of epochs.
validate(temp): Evaluates the model on the validation dataset.
save_checkpoint(model, epoch, dir): Saves the model checkpoint.
Initializes the Trainer object.
Parameters
config (dict): Configuration dictionary with model and training parameters.
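A minimal usage sketch (the main_config keys shown are the ones the class reads; the path values are illustrative):

import torch
from src.representation_learning.train_cl import Trainer

main_config = {
    "random_seed": 42,
    "data_path": "data/",          # illustrative path
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "model_path": "checkpoints/",  # illustrative path
}

trainer = Trainer(main_config)
# run_training() reads hyperparameters from wandb.config, so it is normally
# launched through a WandB sweep agent (see run_training below).
trainer.run_training()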
Methods
def build_dataset(self, config)
-
def build_dataset(self, config):
    """
    Builds the training and validation datasets.

    Parameters:
        config (dict): Configuration dictionary from WandB.
    """
    print(config)
    self.train_loader, self.val_loader = get_data_loaders(
        data_path=self.main_config['data_path'],
        batch_size=config.batch_size
    )
Builds the training and validation datasets.
Parameters
config (dict): Configuration dictionary from WandB.
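get_data_loaders is defined elsewhere in the repository; from this call site it is only assumed to take data_path and batch_size and return a (train, validation) pair of PyTorch DataLoaders yielding (data, label) batches. A hypothetical stand-in with that contract, for illustration only:

import torch
from torch.utils.data import DataLoader, TensorDataset

def get_data_loaders(data_path, batch_size):
    # Stand-in: random tensors mimic the (data, label) batches consumed
    # by train() and validate(); the real function loads from data_path.
    x = torch.randn(256, 1, 64, 64)
    y = torch.zeros(256, dtype=torch.long)
    dataset = TensorDataset(x, y)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader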
def build_model(self, config)
-
def build_model(self, config):
    """
    Instantiates the model based on the configuration.

    Parameters:
        config (dict): Configuration dictionary with model parameters.
    """
    self.model = Model(
        in_channels=config.in_channels,
        h_dim=config.h_dim,
        projection_dim=config.projection_dim
    ).to(self.main_config['device'])
Instantiates the model based on the configuration.
Parameters
config (dict): Configuration dictionary with model parameters.
def build_optimizer(self, config)
-
def build_optimizer(self, config):
    """
    Initializes the optimizer based on the configuration.

    Parameters:
        config (dict): Configuration dictionary specifying optimizer type and parameters.
    """
    if config.optimizer == "sgd":
        self.optimizer = torch.optim.SGD(
            params=self.model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
            momentum=config.momentum
        )
    elif config.optimizer == "adam":
        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=config.max_lr,
            weight_decay=config.weight_decay
        )
Initializes the optimizer based on the configuration.
Parameters
config (dict): Configuration dictionary specifying optimizer type and parameters.
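The keys read here depend on the branch taken; note that the Adam branch uses max_lr as its initial learning rate while the SGD branch uses lr. Illustrative config values for each branch:

# SGD branch
sgd_config = {"optimizer": "sgd", "lr": 1e-2, "weight_decay": 1e-4, "momentum": 0.9}

# Adam branch (initial learning rate is taken from max_lr)
adam_config = {"optimizer": "adam", "max_lr": 1e-3, "weight_decay": 1e-4}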
def build_scheduler(self, config)
-
def build_scheduler(self, config):
    """
    Sets up the learning rate scheduler based on the specified type.

    Parameters:
        config (dict): Configuration dictionary with scheduler settings.
    """
    # Scheduler with linear warmup and exponential decay
    if config.scheduler == 'LambdaLR':
        lr_multiplier = config.max_lr / config.base_lr
        self.scheduler = LambdaLR(
            optimizer=self.optimizer,
            lr_lambda=lambda epoch: (
                ((lr_multiplier - 1) * epoch / config.l_e + 1)
                if epoch < config.l_e
                else lr_multiplier * (config.l_b ** (epoch - config.l_e))
            )
        )
    elif config.scheduler == "Cyclic":
        self.scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer=self.optimizer,
            base_lr=config.base_lr,
            max_lr=config.max_lr,
            step_size_up=72,
            mode='exp_range',
            gamma=0.96,
            scale_mode='cycle',
            cycle_momentum=False
        )
Sets up the learning rate scheduler based on the specified type.
Parameters
config (dict): Configuration dictionary with scheduler settings.
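In the 'LambdaLR' branch the multiplier applied to the optimizer's initial learning rate warms up linearly from 1 to max_lr / base_lr over the first l_e epochs and then decays geometrically with factor l_b per epoch. A small sketch of that factor, with illustrative values:

def lr_factor(epoch, base_lr=1e-4, max_lr=1e-3, l_e=10, l_b=0.95):
    mult = max_lr / base_lr
    if epoch < l_e:
        return (mult - 1) * epoch / l_e + 1      # linear warmup from 1 to mult
    return mult * l_b ** (epoch - l_e)           # exponential decay afterwards

print([round(lr_factor(e), 2) for e in (0, 5, 10, 30)])  # [1.0, 5.5, 10.0, 3.58]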
def run_training(self)
-
def run_training(self):
    """
    Runs the complete training process including dataset loading, model
    initialization, optimizer configuration, scheduler setup, and model training.
    """
    with wandb.init():
        # Set random seed for reproducibility
        if self.main_config['random_seed'] is not None:
            np.random.seed(self.main_config['random_seed'])
            torch.manual_seed(self.main_config['random_seed'])
            torch.cuda.manual_seed(self.main_config['random_seed'])

        self.best_pred = np.inf  # Initialize best loss to infinity
        self.best_epoch = 0      # Track the best epoch

        # Build the dataset, model, optimizer, and scheduler
        self.build_dataset(config=wandb.config)
        self.build_model(config=wandb.config)
        self.build_optimizer(config=wandb.config)
        self.build_scheduler(config=wandb.config)

        # Start training
        self.train(config=wandb.config)
Runs the complete training process including dataset loading, model initialization, optimizer configuration, scheduler setup, and model training.
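Because every hyperparameter is read from wandb.config, run_training is normally launched through a WandB sweep agent. A hedged sketch (the parameter names match what the builder methods read; the values and project name are assumptions):

import wandb

sweep_config = {
    "method": "random",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "batch_size": {"values": [64, 128]},
        "epochs": {"value": 100},
        "temperature": {"values": [0.1, 0.5]},
        "optimizer": {"values": ["adam", "sgd"]},
        "lr": {"value": 1e-2},
        "base_lr": {"value": 1e-4},
        "max_lr": {"value": 1e-3},
        "weight_decay": {"value": 1e-4},
        "momentum": {"value": 0.9},
        "scheduler": {"values": ["LambdaLR", "Cyclic"]},
        "l_e": {"value": 10},
        "l_b": {"value": 0.95},
        "in_channels": {"value": 1},
        "h_dim": {"value": 128},
        "projection_dim": {"value": 64},
    },
}

sweep_id = wandb.sweep(sweep_config, project="train_cl")  # project name is illustrative
wandb.agent(sweep_id, function=trainer.run_training, count=10)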
def save_checkpoint(self, model, epoch, dir)
-
def save_checkpoint(self, model, epoch, dir):
    """
    Saves the model checkpoint.

    Parameters:
        model (torch.nn.Module): The model to save.
        epoch (int): The current epoch number.
        dir (str): Directory to save the checkpoint file.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict()
    }
    fname = os.path.join(dir, f"model_ep_{epoch}_loss_{self.best_pred:.4f}.pth")
    torch.save(checkpoint, fname)
Saves the model checkpoint.
Parameters
model (torch.nn.Module): The model to save.
epoch (int): The current epoch number.
dir (str): Directory to save the checkpoint file.
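Checkpoints contain the epoch plus the model and optimizer state dicts, so they can be restored with standard PyTorch calls (the file name below is illustrative; model and optimizer must already be constructed):

import torch

checkpoint = torch.load("checkpoints/model_ep_49_loss_0.1234.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1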
def train(self, config)
-
def train(self, config):
    """
    Trains the model for a specified number of epochs.

    Parameters:
        config (dict): Configuration dictionary specifying training parameters.
    """
    for epoch in range(config.epochs):
        self.model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(self.train_loader):
            data = data.to(self.main_config['device'])
            self.optimizer.zero_grad()

            # Forward pass and loss calculation
            z_i, z_j, _, _ = self.model(data)
            loss = self.model.loss(z_i, z_j, config.temperature)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            data = data.detach().cpu()  # Free memory

        # Update scheduler and log training loss
        self.scheduler.step()
        train_loss /= len(self.train_loader)
        print(f"Epoch {epoch} Loss: {train_loss}", end="\t")
        wandb.log({"train_loss": train_loss})

        # Early stop if loss becomes NaN
        if np.isnan(train_loss):
            print("Training loss is NaN, stopping.")
            break

        # Validate the model
        val_loss, _ = self.validate(config.temperature)
        wandb.log({"val_loss": val_loss})

        # Save the best model whenever the validation loss improves
        if val_loss < self.best_pred:
            self.best_pred = val_loss
            self.best_epoch = epoch
            self.save_checkpoint(self.model, epoch, self.main_config['model_path'])
Trains the model for a specified number of epochs.
Parameters
config (dict): Configuration dictionary specifying training parameters.
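train() relies on Model.loss(z_i, z_j, temperature), whose implementation lives elsewhere in the repository. For orientation, a standard SimCLR-style NT-Xent loss with that call signature looks roughly like the sketch below; this is an assumption about the interface, not the project's actual implementation:

import torch
import torch.nn.functional as F

def nt_xent_loss(z_i, z_j, temperature):
    # z_i, z_j: (N, D) projections of two augmented views of the same batch.
    n = z_i.size(0)
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)   # (2N, D)
    sim = z @ z.t() / temperature                           # scaled cosine similarities
    sim.masked_fill_(torch.eye(2 * n, dtype=torch.bool, device=z.device), float("-inf"))
    # Positive pairs sit N rows apart: i <-> i + N.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)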
def validate(self, temp)
-
def validate(self, temp):
    """
    Validates the model on the validation dataset.

    Parameters:
        temp (float): Temperature parameter for the contrastive loss.

    Returns:
        Tuple[float, list]: Validation loss and list of latent representations.
    """
    self.model.eval()
    val_loss = 0
    h_i_list = []
    for batch_idx, (data, label) in enumerate(self.val_loader):
        with torch.no_grad():
            data = data.to(self.main_config['device'])
            z_i, z_j, _, _ = self.model(data)
            loss = self.model.loss(z_i, z_j, temp)
            val_loss += loss.item()

            # Extract latent representations
            h_i = self.model.encoder(data).detach().cpu().numpy()
            label = label.cpu().numpy()
            for i in range(h_i.shape[0]):
                h_data = {str(j): h_i[i][j] for j in range(self.model.h_dim)}
                h_data["label"] = str(int(label[i]))
                h_i_list.append(h_data)
            data = data.detach().cpu()

    val_loss /= len(self.val_loader)
    print(f"Validation Loss: {val_loss}")
    return val_loss, h_i_list
Validates the model on the validation dataset.
Parameters
temp (float): Temperature parameter for the contrastive loss.
Returns
Tuple[float, list]
- Validation loss and list of latent representations.
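Each entry of the returned list maps latent dimensions ('0', '1', ...) plus a 'label' key to values, which makes the embeddings easy to inspect, e.g. as a DataFrame or a WandB table (a sketch; pandas is an extra dependency here, and an active WandB run is required for logging):

import pandas as pd
import wandb

val_loss, h_i_list = trainer.validate(temp=0.5)
df = pd.DataFrame(h_i_list)                        # one row per validation sample
wandb.log({"latents": wandb.Table(dataframe=df)})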