How to save and load models with torch.save and torch.load in PyTorch

When diving into PyTorch model persistence, torch.save and torch.load are your bread and butter. At their core, torch.save leverages Python’s pickle module to serialize objects (usually tensors, model state dictionaries, or entire models) into a binary file. Since PyTorch 1.6, that file is a zip archive holding the pickled object graph alongside the raw tensor data, ready to be resurrected later with torch.load.

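One key detail is that torch.save itself doesn’t care what you’re dumping. It’s a thin wrapper around pickle, with PyTorch-specific hooks that write tensor storages (including CUDA tensors) separately from the pickled objects and record which device each one came from. When you call torch.load, it unpickles the data back into memory, reconstructing the objects in their original Python form. Because unpickling can execute arbitrary code, PyTorch 2.6 and later default to torch.load(..., weights_only=True), which only reconstructs tensors, primitive types, and plain containers such as dicts and lists. To load arbitrary Python objects you must opt out explicitly, and you should only do that for files you trust.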

Because of this, much of the behavior is governed by Python’s serialization ecosystem rather than by PyTorch itself. Your saved file is not a standalone, fully self-describing artifact. It’s a snapshot of Python objects, so you need the same (or a compatible) code environment to load it back successfully.
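Since PyTorch 1.6, the file torch.save writes is an ordinary zip archive. The following minimal sketch makes that visible. It saves a small state dict into an in-memory buffer and lists the archive entries with Python’s standard zipfile module. Exact entry names differ between PyTorch releases, so the sketch relies only on the pickled object graph living in an entry whose name ends in data.pkl.

```python
import io
import zipfile

import torch
import torch.nn as nn

# Serialize a tiny model's state dict into memory instead of a file on disk
buffer = io.BytesIO()
torch.save(nn.Linear(10, 1).state_dict(), buffer)

# The result is a zip archive: a pickle of the object graph plus raw tensor storages
buffer.seek(0)
with zipfile.ZipFile(buffer) as archive:
    entry_names = archive.namelist()

for name in entry_names:
    print(name)
```

Opening the archive this way is only useful for inspection. To read the contents back as Python objects, always use torch.load.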

Here’s a minimal example of saving and loading a model’s state dictionary:

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()

# Save the state dict
torch.save(model.state_dict(), 'model_state.pth')

# Load the state dict into a new model instance
model2 = SimpleModel()
model2.load_state_dict(torch.load('model_state.pth'))

Notice how the architecture needs to be defined in code before loading the state dictionary. A saved state_dict contains only the parameter and buffer tensors (such as batch-norm running statistics), keyed by name. It holds no model topology or code. This is a critical distinction and the source of many newbie headaches.
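To see what a state_dict actually holds, print its keys and shapes. The sketch below repeats the SimpleModel class from above and adds a hypothetical RenamedModel whose layer has the same shape but a different attribute name. Loading the saved weights into RenamedModel fails, because parameters are matched by name, not by position.

```python
import io

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

class RenamedModel(nn.Module):
    """Hypothetical variant: same layer shape, different attribute name."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# The state dict maps dotted parameter names to tensors; no class or forward() code is stored
state = SimpleModel().state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)
loaded = torch.load(buffer, weights_only=True)

# Loading works only into a module whose parameter names and shapes match
SimpleModel().load_state_dict(loaded)
try:
    RenamedModel().load_state_dict(loaded)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
print("RenamedModel failed to load:", mismatch_error is not None)
```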

Contrast this with saving the entire model object directly:

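torch.save(model, 'entire_model.pth')
# PyTorch 2.6+ defaults to weights_only=True, which refuses arbitrary pickled classes;
# loading a full model object requires opting out, so only do this for files you trust
model3 = torch.load('entire_model.pth', weights_only=False)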

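In this case, PyTorch pickles the entire model object, including its parameters and attributes. The class itself, however, is stored only as a reference (its module path and class name), not as source code, so the class must still be importable under that same name when you load. That makes this approach brittle. If you refactor your model class or move it elsewhere in your source tree, loading the pickled model can break spectacularly because of Python’s strict pickle import semantics.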

The saved file is also slightly larger, because it includes the module’s internal attributes (the submodule registry, hook dictionaries, and the training flag) rather than just the weights. This method trades portability and longevity for convenience, and it’s generally discouraged for production workflows.
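The brittleness is easy to reproduce. The sketch below registers a throwaway module under the made-up name hypothetical_models and saves a full model whose class lives there. It then removes the module to simulate a refactor. Because pickle recorded only the reference hypothetical_models.TinyModel, loading fails with an import error even though the weights themselves are intact. The sketch assumes a PyTorch version whose torch.load accepts weights_only (1.13 or newer).

```python
import io
import sys
import types

import torch
import torch.nn as nn

# Stand-in for a module in your source tree; "hypothetical_models" is a made-up name
models_module = types.ModuleType("hypothetical_models")
sys.modules["hypothetical_models"] = models_module

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)

# Pretend the class was defined in hypothetical_models.py
TinyModel.__module__ = "hypothetical_models"
models_module.TinyModel = TinyModel

buffer = io.BytesIO()
torch.save(TinyModel(), buffer)  # pickles a reference to hypothetical_models.TinyModel

# Simulate a refactor that moves or renames the module
del sys.modules["hypothetical_models"]

buffer.seek(0)
try:
    torch.load(buffer, weights_only=False)  # full-object load: trusted files only
    load_error = None
except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
    load_error = err
print(type(load_error).__name__, load_error)
```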

Under the hood, torch.save records each tensor’s device along with its data, so you can save GPU tensors directly. By default, torch.load restores tensors to the device they were saved from, and that fails on a machine without CUDA. The map_location argument lets you redirect them:

# Load GPU-saved model on CPU
device = torch.device('cpu')
checkpoint = torch.load('model_gpu.pth', map_location=device)

The map_location argument is a lifesaver when you want to migrate models between different hardware setups without rewriting your code.
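A common, hardware-agnostic pattern is to load every tensor onto the CPU first and only then move the rebuilt model to whatever device is available. The following is a minimal sketch of that pattern. The in-memory buffer stands in for a real checkpoint file.

```python
import io

import torch
import torch.nn as nn

# Use whatever hardware is present on the loading machine
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 1).to(device)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Load every tensor onto the CPU first, regardless of where it was saved from
buffer.seek(0)
state = torch.load(buffer, map_location="cpu", weights_only=True)
cpu_devices = {tensor.device.type for tensor in state.values()}

# Rebuild the model, restore the weights, then move it to the target device
restored = nn.Linear(10, 1)
restored.load_state_dict(state)
restored.to(device)
```

map_location also accepts a torch.device, as in the snippet above. It can also take a dictionary that remaps device tags, for example map_location={'cuda:1': 'cuda:0'}, which loads tensors saved on a second GPU onto the first.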

In practical terms, always remember that torch.save and torch.load are not magical black boxes. They’re Python pickle under the hood, tailored for PyTorch objects. Understanding this helps you avoid pitfalls such as load failures after code or version changes and device-related errors, both of which can cost hours of debugging time.

Best practices for preserving model state and architecture

The best practice for preserving your model’s state and architecture is to separate concerns: save the model’s parameters independently from the model’s code. This means you commit your model class definitions to source control and deployment pipelines, and use state_dict for checkpointing. This approach is robust, transparent, and compatible across PyTorch versions and code changes, as long as parameter names and shapes stay consistent.

When saving checkpoints during training, it’s common to bundle metadata alongside the model state. This includes the optimizer state, epoch counters, learning rate schedulers, or any other information necessary to resume training seamlessly. Use a dictionary to aggregate these components before saving:

# Typically called at the end of an epoch inside the training loop
checkpoint = {
    'epoch': epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'scheduler_state': scheduler.state_dict() if scheduler else None,
    'loss': loss_value,  # e.g. loss.item(), a plain Python float
}
torch.save(checkpoint, 'checkpoint.pth')

Loading then becomes a matter of reconstructing the model and optimizer states accordingly:

checkpoint = torch.load('checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
if checkpoint['scheduler_state'] is not None:
    scheduler.load_state_dict(checkpoint['scheduler_state'])
start_epoch = checkpoint['epoch'] + 1

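This pattern lets training resume from the last saved epoch, which is crucial for long-running experiments or interrupted training sessions. It doesn’t capture everything, though. Random number generator states and the data loader’s position are not included, so a bit-for-bit identical continuation requires saving those as well.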
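Here is a self-contained sketch of a resumable run that puts the save and load halves together. The toy linear model, SGD optimizer, StepLR scheduler, and random data are placeholders for your own. The extra rng_state entry, saved with torch.get_rng_state, is an illustrative addition beyond the checkpoint keys shown earlier. weights_only=True is passed explicitly because every entry is a tensor, number, or plain container.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Placeholder setup: substitute your own model, optimizer, scheduler, and data
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
loss_fn = nn.MSELoss()
x, y = torch.randn(32, 10), torch.randn(32, 1)

def train_one_epoch():
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()
    return loss.item()  # plain Python float, safe to store in the checkpoint

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
for epoch in range(3):
    loss_value = train_one_epoch()
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "loss": loss_value,
        "rng_state": torch.get_rng_state(),  # illustrative extra key
    }, path)

# Later, possibly in a fresh process: rebuild the objects, then restore their state
new_model = nn.Linear(10, 1)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=2, gamma=0.5)

checkpoint = torch.load(path, map_location="cpu", weights_only=True)
new_model.load_state_dict(checkpoint["model_state"])
new_optimizer.load_state_dict(checkpoint["optimizer_state"])
new_scheduler.load_state_dict(checkpoint["scheduler_state"])
torch.set_rng_state(checkpoint["rng_state"])
start_epoch = checkpoint["epoch"] + 1
```

The optimizer and scheduler must be constructed the same way before their state is loaded. Their state dicts hold values such as momentum buffers and step counts, not the construction logic.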

Another subtle but important point is to avoid saving the entire model object (torch.save(model)) in production or shared environments. It tightly couples the saved artifact to the exact Python environment and source code layout. Instead, keep the model class definitions versioned and deploy the code alongside the weights. This also facilitates model auditing and debugging, since the code is explicit and visible.

For models with complex architectures or those using third-party modules, consider implementing a from_config class method or similar factory pattern. This method should reconstruct the model architecture from a saved configuration dictionary, which you store alongside the weights. For example:

class ComplexModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        return self.layer2(x)

    @classmethod
    def from_config(cls, config):
        return cls(config['input_dim'], config['hidden_dim'], config['output_dim'])

# Build the model from its config, then save config and state dict together
config = {'input_dim': 10, 'hidden_dim': 20, 'output_dim': 1}
model = ComplexModel.from_config(config)
torch.save({'config': config, 'state_dict': model.state_dict()}, 'complex_model.pth')

# Load
checkpoint = torch.load('complex_model.pth')
model = ComplexModel.from_config(checkpoint['config'])
model.load_state_dict(checkpoint['state_dict'])

Using this method decouples your saved weights from hard-coded architectural parameters, making your checkpoints more portable and self-describing.
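In a variation on the same idea, the model remembers its own constructor arguments, so callers never assemble the config by hand. The config attribute and the save_checkpoint and load_checkpoint helpers are hypothetical names. This from_config also uses keyword unpacking, so the config keys must match the constructor’s parameter names.

```python
import io

import torch
import torch.nn as nn

class ComplexModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Hypothetical addition: remember constructor arguments for checkpointing
        self.config = {"input_dim": input_dim, "hidden_dim": hidden_dim, "output_dim": output_dim}
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        return self.layer2(x)

    @classmethod
    def from_config(cls, config):
        return cls(**config)

def save_checkpoint(model, f):
    """Hypothetical helper: store the architecture config and weights in one file."""
    torch.save({"config": model.config, "state_dict": model.state_dict()}, f)

def load_checkpoint(model_cls, f):
    """Hypothetical helper: rebuild the architecture, then restore the weights."""
    checkpoint = torch.load(f, map_location="cpu", weights_only=True)
    model = model_cls.from_config(checkpoint["config"])
    model.load_state_dict(checkpoint["state_dict"])
    return model

buffer = io.BytesIO()  # stands in for a checkpoint file path
original = ComplexModel(10, 20, 1)
save_checkpoint(original, buffer)
buffer.seek(0)
restored = load_checkpoint(ComplexModel, buffer)
print(restored.config)
```

Because the config is a plain dict of numbers, the checkpoint still loads under weights_only=True.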

Finally, always validate your loaded model immediately after restoration. Run a forward pass with test data or unit tests to confirm that the weights were loaded correctly and no subtle mismatches or data corruption occurred. Silent errors in loading can manifest as degraded model performance or outright runtime errors much later in the pipeline.
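One way to automate this check is to run the original and restored models on the same input and compare the outputs. The sketch below uses a hypothetical reload_and_validate helper, with an in-memory buffer in place of a file. In a deployment setting where the original model is no longer available, you would instead compare against reference outputs recorded at save time.

```python
import io

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

def reload_and_validate(model, make_model, sample, atol=1e-6):
    """Hypothetical helper: round-trip a state dict and check the restored model's outputs."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)

    restored = make_model()
    # strict=True (the default) raises RuntimeError on missing or unexpected keys
    restored.load_state_dict(torch.load(buffer, weights_only=True), strict=True)

    # eval() disables dropout and makes batch norm use its running statistics
    model.eval()
    restored.eval()
    with torch.no_grad():
        expected = model(sample)
        actual = restored(sample)
    if not torch.allclose(expected, actual, atol=atol):
        raise RuntimeError("restored model produces different outputs")
    return restored

original = SimpleModel()
sample = torch.randn(4, 10)  # stand-in for real test data
restored = reload_and_validate(original, SimpleModel, sample)
```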
