Building a Basic Model for Understanding ML

Santosh Premi Adhikari - Aug 31 - Dev Community
In this post we'll cover:
  1. Building a simple neural network model
  2. Training the model and saving it (.pth)
  3. Loading the model and using it for prediction

We'll use the classic MNIST dataset for demonstration. It is small by modern standards and consists of 28×28 grayscale images of handwritten digits (0-9): 60,000 for training and 10,000 for testing.

Step 1: Import Libraries and Define the Model

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

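    def forward(self, x):
        x = x.view(-1, 28*28)        # Flatten each 1x28x28 image into a 784-dim vector
        x = torch.relu(self.fc1(x))  # Hidden layer with ReLU activation
        x = self.fc2(x)              # Raw class scores (logits); CrossEntropyLoss applies log-softmax itself
        return x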

# Instantiate the model, define loss function and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
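The model returns raw logits rather than probabilities because nn.CrossEntropyLoss combines log-softmax and negative log-likelihood in one step. As a rough illustration of that per-sample math, here is a minimal pure-Python sketch (no PyTorch needed). The helper names `cross_entropy` and `mean_cross_entropy` and the example logits are hypothetical, chosen only for demonstration:

```python
import math

def cross_entropy(logits, target):
    # Per-sample cross-entropy: -log(softmax(logits)[target]),
    # computed as logsumexp(logits) - logits[target] (the same quantity).
    m = max(logits)  # subtract the max before exp() to avoid overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    # Average over the batch, like the loss's default reduction='mean'
    losses = [cross_entropy(row, t) for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

# Hypothetical 3-class logits, not real model output
logits = [2.0, 0.5, -1.0]
loss_right = cross_entropy(logits, 0)  # true class has the largest logit
loss_wrong = cross_entropy(logits, 2)  # true class has the smallest logit
print(f"loss if label is 0: {loss_right:.4f}")
print(f"loss if label is 2: {loss_wrong:.4f}")

# Equal logits over 10 classes give exactly log(10)
uniform_loss = cross_entropy([0.0] * 10, 7)
print(f"uniform 10-class loss: {uniform_loss:.4f}")

batch_loss = mean_cross_entropy([logits, logits], [0, 2])
print(f"mean loss over the two-sample batch: {batch_loss:.4f}")
```

Because a freshly initialized 10-class model tends to produce logits that are close to one another, its first training losses on MNIST usually sit roughly around ln(10) ≈ 2.30 and then fall as training progresses.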


Step 2: Load the Dataset and Train the Model

# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)


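# Train the model
model.train()  # Enable training mode (no effect on this model's layers, but good practice)
for epoch in range(1):  # Train for 1 epoch for simplicity
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()              # Clear gradients from the previous step
        outputs = model(images)            # Forward pass: logits of shape (batch, 10)
        loss = criterion(outputs, labels)  # Cross-entropy between logits and true labels
        loss.backward()                    # Backpropagate
        optimizer.step()                   # Update the weights with Adam
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, average loss: {running_loss / len(trainloader):.4f}')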

print('Training complete!')
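The transform used in Step 2 runs in two stages: ToTensor scales 8-bit pixel values from [0, 255] to [0.0, 1.0], and Normalize((0.5,), (0.5,)) then computes (x - mean) / std, mapping that range to [-1.0, 1.0]. The sketch below reproduces that arithmetic in plain Python for a few sample pixel values. The helper names `to_tensor_scale` and `normalize` are hypothetical stand-ins, not torchvision functions:

```python
def to_tensor_scale(pixel):
    # Mimics ToTensor for one uint8 pixel: [0, 255] -> [0.0, 1.0]
    return pixel / 255.0

def normalize(x, mean=0.5, std=0.5):
    # Mimics Normalize((0.5,), (0.5,)) for one channel value: (x - mean) / std
    return (x - mean) / std

results = {}
for pixel in (0, 128, 255):
    scaled = to_tensor_scale(pixel)
    results[pixel] = normalize(scaled)
    print(f"pixel {pixel:3d} -> scaled {scaled:.4f} -> normalized {results[pixel]:.4f}")
```

Black pixels (0) become -1.0 and white pixels (255) become 1.0. Many MNIST examples instead normalize with the dataset's own statistics (commonly quoted as mean 0.1307, std 0.3081); for a demo like this, either choice works.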

Step 3: Save the Model

# Save the model state dictionary
torch.save(model.state_dict(), 'simple_nn.pth')
print('Model saved!')
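The saved state dictionary holds only the learned tensors, keyed by layer attribute name. For SimpleNN that should be four entries. nn.Linear(in_features, out_features) stores its weight with shape (out_features, in_features) and its bias with shape (out_features,). The sketch below derives the expected keys, shapes, and parameter count by hand. It is a hypothetical helper for illustration and does not read the .pth file:

```python
# Layer name -> (in_features, out_features), copied from SimpleNN
layers = {"fc1": (28 * 28, 128), "fc2": (128, 10)}

def expected_state_dict_shapes(layers):
    # Hypothetical helper: builds the key -> shape layout nn.Linear layers produce
    shapes = {}
    for name, (in_f, out_f) in layers.items():
        shapes[f"{name}.weight"] = (out_f, in_f)
        shapes[f"{name}.bias"] = (out_f,)
    return shapes

def num_elements(shape):
    # Product of all dimensions
    n = 1
    for d in shape:
        n *= d
    return n

shapes = expected_state_dict_shapes(layers)
total = sum(num_elements(s) for s in shapes.values())
for key, shape in shapes.items():
    print(key, shape)
print("total parameters:", total)
print("raw float32 payload (bytes):", total * 4)
```

You can compare this against the real model with `for k, v in model.state_dict().items(): print(k, tuple(v.shape))` and `sum(p.numel() for p in model.parameters())`. The .pth file on disk will be somewhat larger than the raw payload because of serialization overhead.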

Step 4: Load the Model and Make Predictions

# Load the model state dictionary
loaded_model = SimpleNN()
loaded_model.load_state_dict(torch.load('simple_nn.pth'))
loaded_model.eval()         # Set the model to evaluation mode

# Make a prediction on a single image
test_image, label = trainset[20]   # Use the image at index 20 (the 21st) of the training set as an example
test_image = test_image.unsqueeze(0)  # Add a batch dimension


# Display the image
plt.imshow(test_image.squeeze(), cmap='gray')
plt.title(f'Actual Label: {label}')
plt.axis('off')
plt.show()

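# Run inference without tracking gradients
with torch.no_grad():
    output = loaded_model(test_image)   # Logits of shape (1, 10)
_, predicted = torch.max(output, 1)     # Index of the highest logit = predicted digit

print('Predicted label:', predicted.item())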
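torch.max(output, 1) returns both the largest value and its index along dimension 1. The index is the predicted digit. Softmax does not change which entry is largest, so the same index also has the highest probability. This plain-Python sketch shows both ideas on one row of made-up logits. The logits and the helper names `softmax` and `predict` are hypothetical:

```python
import math

def softmax(logits):
    # Convert logits to probabilities that sum to 1 (max subtracted for stability)
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(logits):
    # Return (max_logit, index) for one row, analogous to torch.max(output, 1)
    index = max(range(len(logits)), key=lambda i: logits[i])
    return logits[index], index

# Hypothetical logits for one image over the 10 digit classes
logits = [-1.2, 0.3, 0.1, 4.0, -0.5, 1.1, -2.0, 0.8, 0.0, -0.7]
value, predicted = predict(logits)
probs = softmax(logits)
print("predicted label:", predicted)
print(f"probability of predicted label: {probs[predicted]:.3f}")
```

In PyTorch itself, `torch.softmax(output, dim=1)` gives the per-class probabilities, and `output.argmax(dim=1)` is a shorter way to get just the predicted index.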


Result

I hope you found this post helpful and enjoyable.
Thank you!
