- Building a simple neural network model,
- Training the model and saving it (`.pth`),
- Loading the saved model and using it for prediction.
We'll use a classic, small dataset for demonstration: MNIST, which consists of 28×28 grayscale images of handwritten digits.
Step 1: Import Libraries and Define the Model
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Define a simple neural network
class SimpleNN(nn.Module):
    """A two-layer fully connected classifier for flattened images.

    Each input sample is flattened to a vector of ``input_size`` values,
    passed through a hidden ``nn.Linear`` layer with a ReLU activation, and
    projected to ``num_classes`` raw scores (logits; no softmax is applied,
    which is what ``nn.CrossEntropyLoss`` expects). The defaults
    (28*28 = 784 inputs, 128 hidden units, 10 classes) reproduce the original
    MNIST architecture, so ``SimpleNN()`` and its ``state_dict`` keys are
    unchanged.

    Args:
        input_size: Number of values per flattened input sample.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits per sample.
    """

    def __init__(self, input_size=28 * 28, hidden_size=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input ``x``.

        ``x`` may have any shape whose element count is a multiple of
        ``input_size`` (e.g. ``(batch, 1, 28, 28)`` from the MNIST loader);
        it is flattened to ``(-1, input_size)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted instead of raising a RuntimeError.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Build the network, the training objective, and the optimizer.
model = SimpleNN()
# Cross-entropy over the raw logits returned by SimpleNN and integer labels.
criterion = nn.CrossEntropyLoss()
# Adam updates every parameter of the model with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
Step 2: Load the Dataset and Train the Model
# Load the MNIST training split. ToTensor scales pixels to [0, 1]; the
# Normalize step with mean 0.5 / std 0.5 then maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
# Mini-batches of 64 samples, reshuffled on every pass over the data.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=64,
    shuffle=True,
)
# Train the model: a single epoch over the training set keeps the demo quick.
for _epoch in range(1):
    for batch_images, batch_labels in trainloader:
        # Drop gradients accumulated by the previous step.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()  # gradients of the loss w.r.t. every parameter
        optimizer.step()  # apply one Adam update
print('Training complete!')
Step 3: Save the Model
# Persist only the learned parameters (the state dict), not the module object;
# the SimpleNN class must be available again when loading them.
torch.save(obj=model.state_dict(), f='simple_nn.pth')
print('Model saved!')
Step 4: Load the Model and Make Predictions
# Load the saved state dictionary into a fresh SimpleNN instance.
loaded_model = SimpleNN()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict contains, so loading cannot execute arbitrary
# pickled code (it is also the default from PyTorch 2.6 on).
loaded_model.load_state_dict(torch.load('simple_nn.pth', weights_only=True))
loaded_model.eval()  # Set the model to evaluation mode
# Make a prediction on a single image.
# NOTE: index 20 is the 21st sample of the *training* split, which the model
# has already seen; load MNIST with train=False for an unbiased check.
test_image, label = trainset[20]
test_image = test_image.unsqueeze(0)  # Add a batch dimension
# Display the image
plt.imshow(test_image.squeeze(), cmap='gray')
plt.title(f'Actual Label: {label}')
plt.axis('off')
plt.show()
# Inference only: disable gradient tracking so no autograd graph is built.
with torch.no_grad():
    output = loaded_model(test_image)
# The predicted class is the index of the largest logit.
_, predicted = torch.max(output, 1)
print('Predicted label:', predicted.item())
I hope you found this post helpful and enjoyable.
Thank you!