Train PyTorch with Checkpoints
Training a PyTorch model with checkpoints means saving the model's state to disk periodically during training. If training is interrupted by a power failure, a hardware failure, or any other unexpected problem, you can pick up from the last saved checkpoint instead of losing all the progress made so far.
To train a PyTorch model with checkpoints, you can follow these steps:
- Import the necessary libraries and packages.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
```
- Define your model architecture.
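```python
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Three fully connected layers: 784 inputs -> 128 -> 64 -> 10 class scores
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten each sample to a 784-element vector
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        x = nn.functional.relu(x)
        x = self.fc3(x)
        return x  # raw logits; CrossEntropyLoss applies log-softmax itself
```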
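As a quick sanity check, the sketch below passes a random batch through MyModel and confirms it returns one score per class. The input shape of 1x28x28 is an assumption: the article never names its dataset, but 784 = 28 x 28 input features suggests MNIST-style grayscale images.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten (N, 1, 28, 28) -> (N, 784)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)

model = MyModel()
dummy = torch.randn(4, 1, 28, 28)  # fake batch of 4 assumed 28x28 grayscale images
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 10])

# Weights plus biases of the three layers
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 109386
```

The parameter count (100,480 + 8,256 + 650) also gives a rough idea of how large each saved checkpoint of this model will be.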
- Define your loss function and optimizer.
```python
model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
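nn.CrossEntropyLoss combines log-softmax and negative log-likelihood, which is why MyModel returns raw logits and the labels are plain integer class indices. The small worked example below (made-up numbers) checks the built-in loss against the same quantity computed by hand.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Raw, unnormalized scores for 2 samples over 3 classes (no softmax applied).
logits = torch.tensor([[2.0, 0.0, 0.0],
                       [0.0, 0.0, 0.0]])
labels = torch.tensor([0, 2])  # integer class indices (int64)

loss = criterion(logits, labels)

# Same value by hand: mean of -log_softmax taken at each sample's target class.
manual = -(torch.log_softmax(logits, dim=1)[torch.arange(2), labels]).mean()
print(loss.item(), manual.item())
```

Applying a softmax inside the model as well would squash the logits twice and slow down learning, so the final layer is left linear.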
- Define your training loop.
```python
def train(model, criterion, optimizer, train_loader, val_loader, epochs, device, save_path):
    best_acc = 0.0
    for epoch in range(epochs):
        model.train()  # evaluate() switches to eval mode, so switch back each epoch
        running_loss = 0.0
        running_corrects = 0
        total = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Track accuracy: the predicted class is the index of the largest logit
            _, preds = torch.max(outputs, 1)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        train_acc = running_corrects / total
        val_loss, val_acc = evaluate(model, criterion, val_loader, device)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Checkpoint: save the weights whenever validation accuracy improves
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
    return best_acc
```
The train function takes the model, the loss function (criterion), the optimizer, the training and validation data loaders, the number of epochs, the device, and the checkpoint path. For each epoch, it iterates over train_loader, computes the loss and accuracy of each batch, and updates the model parameters. At the end of each epoch, it evaluates the model on the validation data and saves the model's weights (its state_dict) to save_path whenever the validation accuracy beats the best accuracy seen so far.
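Saving only model.state_dict() is enough to reuse the best weights, but not to continue training: Adam keeps per-parameter moment estimates, and the loop also needs to know which epoch it reached. A minimal sketch of a fuller checkpoint follows. The helper name save_checkpoint and the dictionary keys are hypothetical choices for illustration, and a tiny nn.Linear stands in for MyModel.

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim

def save_checkpoint(path, epoch, model, optimizer, best_acc):
    """Hypothetical helper: bundle everything needed to resume into one file."""
    torch.save({
        "epoch": epoch,                                  # last completed epoch
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),  # includes Adam's moment estimates
        "best_acc": best_acc,                            # plain float
    }, path)

# Stand-in model and optimizer; one step populates Adam's per-parameter state.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
save_checkpoint(path, epoch=3, model=model, optimizer=optimizer, best_acc=0.91)

ckpt = torch.load(path)
print(sorted(ckpt.keys()))
# ['best_acc', 'epoch', 'model_state_dict', 'optimizer_state_dict']
```

Inside train, such a helper could be called at the end of every epoch (for resuming) in addition to the best-weights save (for deployment), typically writing to two different files.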
- Define your evaluation loop.
```python
def evaluate(model, criterion, dataloader, device):
    model.eval()  # disables training-only behavior (e.g. dropout); harmless for MyModel
    running_loss = 0.0
    running_corrects = 0
    total = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)
            running_loss += loss.item()
    loss = running_loss / len(dataloader)
    acc = running_corrects / total
    return loss, acc
```
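Both loops compute accuracy the same way: take the index of the largest logit in each row as the prediction, count matches, and divide by the number of samples. The worked example below uses made-up logits for 4 samples and 3 classes.

```python
import torch

outputs = torch.tensor([[0.1, 2.0, 0.3],
                        [1.5, 0.2, 0.1],
                        [0.0, 0.1, 0.9],
                        [0.4, 0.3, 0.2]])
labels = torch.tensor([1, 0, 1, 0])

_, preds = torch.max(outputs, 1)           # index of the largest logit per row
correct = (preds == labels).sum().item()   # number of matches as a Python int
acc = correct / labels.size(0)
print(preds.tolist(), correct, acc)  # [1, 0, 2, 0] 3 0.75
```

Only the third sample is misclassified (class 2 predicted, class 1 expected), giving 3 of 4 correct.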
- Finally, call the train function with your desired parameters.
```python
# train_dataset and val_dataset are not defined in this article; they are assumed to be
# Dataset objects yielding (image, label) pairs of 28x28 grayscale images with 10 classes.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion.to(device)

epochs = 10
save_path = "my_model.pth"
best_acc = train(model, criterion, optimizer, train_loader, val_loader, epochs, device, save_path)
print(f"Best validation accuracy: {best_acc:.4f}")
```
In this example, we use a batch size of 32, train for 10 epochs, and save the weights of the best epoch to a file called “my_model.pth”. The best validation accuracy reached during training is printed at the end.
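Since the article leaves train_dataset undefined, the sketch below builds a hypothetical stand-in from random tensors so the loaders can be tried without downloading anything. In practice you would substitute a real dataset (for example torchvision's MNIST).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 100 random 1x28x28 "images" with labels in 0..9.
train_dataset = TensorDataset(torch.randn(100, 1, 28, 28),
                              torch.randint(0, 10, (100,)))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

inputs, labels = next(iter(train_loader))
print(inputs.shape, labels.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
print(len(train_loader))           # 4 batches: 32 + 32 + 32 + 4
```

Because drop_last defaults to False, the final batch holds the 4 leftover samples. This is also why averaging per-batch losses, as train and evaluate do, weights that small batch slightly more than a per-sample average would.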
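Because weights are written to disk during training, an interruption does not wipe out all of your progress. You can load the saved weights with model.load_state_dict(torch.load(save_path)), passing map_location=device to torch.load if the file was saved on a different device. Note that this file holds only the model weights from the best epoch. To resume training exactly where it stopped, you also need to save the optimizer's state_dict and the current epoch number.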
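A minimal sketch of resuming, assuming a checkpoint dictionary with the hypothetical keys "epoch", "model_state_dict", "optimizer_state_dict", and "best_acc". The helper name load_checkpoint is also hypothetical, and a tiny nn.Linear stands in for MyModel. The model and optimizer must be rebuilt with the same architecture and settings before their states are loaded.

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim

def load_checkpoint(path, model, optimizer, device="cpu"):
    """Hypothetical helper: restore model/optimizer state, return where to resume."""
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"] + 1, ckpt["best_acc"]  # continue with the next epoch

# Simulate an earlier run that saved a checkpoint after epoch index 4.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
old_model = nn.Linear(4, 2)
old_opt = optim.Adam(old_model.parameters(), lr=0.001)
torch.save({"epoch": 4,
            "model_state_dict": old_model.state_dict(),
            "optimizer_state_dict": old_opt.state_dict(),
            "best_acc": 0.88}, path)

# Later: rebuild the same model and optimizer, then restore their state.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
start_epoch, best_acc = load_checkpoint(path, model, optimizer)

epochs = 10
for epoch in range(start_epoch, epochs):
    model.train()
    # ... run one epoch of training and validation here, as in train() ...
    pass
```

Here training continues at epoch index 5 with the restored best accuracy, so the best-weights comparison keeps working across the restart.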