Train PyTorch with Checkpoints

Training a PyTorch model with checkpoints means saving the model's state to disk at points during training, for example at the end of an epoch. If training is interrupted by a power failure, a hardware fault, or any other unexpected problem, you can reload the saved state instead of losing all of the progress made so far.

To train a PyTorch model with checkpoints, you can follow these steps:

  1. Import the necessary libraries and packages.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
```
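  1. Define your model architecture.
```python
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Three fully connected layers: 784 input features -> 10 class scores
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten each sample into a 784-element vector
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        x = nn.functional.relu(x)
        x = self.fc3(x)
        return x  # raw logits; nn.CrossEntropyLoss applies log-softmax internally
```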
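A quick shape check can confirm the architecture before training. The 784 input features suggest flattened 28x28 single-channel images (MNIST-style), but the post does not name a dataset, so that input shape is an assumption here; the class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Same architecture as above: 784 -> 128 -> 64 -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)  # raw logits, one score per class


model = MyModel()
# Assumption: a batch of four 28x28 single-channel images
dummy = torch.randn(4, 1, 28, 28)
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 10])

# Count trainable parameters (weights + biases of the three layers)
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 109386
```

Any input whose non-batch dimensions multiply to 784 will pass through `view`; other shapes raise an error at `fc1`.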
  1. Define your loss function and optimizer.
```python
model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
  1. Define your training loop.
```python
def train(model, criterion, optimizer, train_loader, val_loader, epochs, device, save_path):
    best_acc = 0.0

    for epoch in range(epochs):
        model.train()  # enable training behavior (e.g. dropout, batch norm), if any
        running_loss = 0.0
        running_corrects = 0
        total = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Track loss and accuracy for this epoch as plain Python numbers
            preds = outputs.argmax(dim=1)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)
            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        train_acc = running_corrects / total

        val_loss, val_acc = evaluate(model, criterion, val_loader, device)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Save the weights whenever validation accuracy improves
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)

    return best_acc
```

The train function takes the model, the loss function (criterion), the optimizer, the training and validation data loaders, the number of epochs, the device, and the save path. For each epoch, it iterates over train_loader, computes the loss, backpropagates, and updates the model parameters while tracking loss and accuracy. At the end of each epoch, it evaluates the model on the validation data and saves the model's state_dict to save_path whenever the validation accuracy beats the best accuracy seen so far.
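Saving only the best weights is enough to recover the best model, but not to resume a run mid-way. A common pattern, sketched here under stated assumptions, is to also write a full checkpoint (weights, optimizer state, epoch, best accuracy) at the end of every epoch. The helper name `save_checkpoint` and the dictionary keys are hypothetical choices, not a PyTorch API; `torch.save` and `os.replace` are real calls.

```python
import os

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, model, optimizer, epoch, best_acc):
    """Hypothetical helper: write everything needed to resume training."""
    state = {
        "epoch": epoch,  # last completed epoch (0-based)
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),  # e.g. Adam's running averages
        "best_acc": best_acc,
    }
    # Write to a temporary file first, then rename it over the old checkpoint,
    # so an interruption mid-write does not leave a corrupted file behind.
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


# Usage with a stand-in model (assumption: any nn.Module/optimizer pair works the same way)
model = nn.Linear(784, 10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
save_checkpoint("checkpoint.pth", model, optimizer, epoch=0, best_acc=0.0)

checkpoint = torch.load("checkpoint.pth")
print(sorted(checkpoint.keys()))
```

Inside train, a call such as `save_checkpoint(checkpoint_path, model, optimizer, epoch, best_acc)` after the validation step would keep one up-to-date checkpoint per epoch; `checkpoint_path` would be a hypothetical extra parameter, kept separate from the best-model file.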

  1. Define your evaluation loop.
```python
def evaluate(model, criterion, dataloader, device):
    was_training = model.training
    model.eval()  # disable training-only behavior such as dropout
    running_loss = 0.0
    running_corrects = 0
    total = 0

    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            preds = outputs.argmax(dim=1)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)

            running_loss += loss.item()

    model.train(was_training)  # restore the mode the model was in before evaluation

    loss = running_loss / len(dataloader)
    acc = running_corrects / total

    return loss, acc
```
  1. Finally, call the train function with your desired parameters.
```python
# train_dataset and val_dataset are assumed to be defined elsewhere: any
# torch.utils.data.Dataset yielding (image, integer label) pairs that the model accepts.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion.to(device)

epochs = 10
save_path = "my_model.pth"

best_acc = train(model, criterion, optimizer, train_loader, val_loader, epochs, device, save_path)

print(f"Best validation accuracy: {best_acc:.4f}")
```

This example uses a batch size of 32, trains the model for 10 epochs, and saves the best model's weights to a file called "my_model.pth". The best validation accuracy reached during training is printed at the end.

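Because this example saves only model.state_dict(), and only when validation accuracy improves, the file lets you restore the best weights with model.load_state_dict(torch.load(save_path, map_location=device)). To resume training from where you left off after an interruption, you also need the optimizer's state (Adam keeps per-parameter running averages) and the number of the last completed epoch, so save those alongside the model weights.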
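Resuming then comes down to rebuilding the model and optimizer, restoring both state dicts, and continuing from the next epoch. This sketch assumes a checkpoint dictionary with the hypothetical keys `epoch`, `model_state_dict`, `optimizer_state_dict` and `best_acc`, and uses a single `nn.Linear` layer as a stand-in for MyModel so it runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Stand-in model and random data (assumptions for illustration only)
model = nn.Linear(784, 10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
inputs, labels = torch.randn(8, 784), torch.randint(0, 10, (8,))

# One training step so Adam has per-parameter state worth saving
optimizer.zero_grad()
criterion(model(inputs), labels).backward()
optimizer.step()

torch.save(
    {
        "epoch": 4,  # pretend epochs 0..4 finished before the interruption
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_acc": 0.0,
    },
    "checkpoint.pth",
)

# Later, e.g. in a fresh process: rebuild the objects, then restore their state
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resumed_model = nn.Linear(784, 10).to(device)
resumed_optimizer = optim.Adam(resumed_model.parameters(), lr=0.001)

checkpoint = torch.load("checkpoint.pth", map_location=device)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # continue after the last completed epoch
best_acc = checkpoint["best_acc"]

resumed_model.train()  # back to training mode before continuing
print(start_epoch)  # 5
```

To continue with the post's train function, its loop would need to start at start_epoch (for example, `for epoch in range(start_epoch, epochs):`) and take the restored best_acc as its starting value; both are hypothetical additions, since the original function always starts at epoch 0 with best_acc = 0.0.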
