使用 Python 实现一个简单的图像分类器
在深度学习领域,图像分类是一个非常基础且重要的任务。它旨在根据输入的图像内容将其归类到预定义的类别中。本文将介绍如何使用 Python 和流行的深度学习框架 PyTorch 来构建一个简单的图像分类器。
我们将使用经典的 CIFAR-10 数据集进行训练和测试。该数据集包含 60,000 张 32x32 彩色图像,分为 10 个类别(如飞机、汽车、鸟等)。
环境准备
首先确保你已经安装了以下库:
pip install torch torchvision matplotlib
步骤一:导入必要的库
import torchimport torch.nn as nnimport torch.optim as optimimport torchvisionimport torchvision.transforms as transformsfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as plt
步骤二:加载并预处理 CIFAR-10 数据集
我们使用 torchvision.datasets.CIFAR10
来下载并加载数据集,并对图像进行标准化处理。
# 图像预处理:将像素值从 [0, 1] 映射到 [-1, 1]transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载训练集和测试集train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
我们可以先可视化一些训练图像来确认是否正确加载。
def imshow(img): img = img / 2 + 0.5 # 反标准化 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show()# 获取一批训练图像dataiter = iter(train_loader)images, labels = next(dataiter)# 展示图像imshow(torchvision.utils.make_grid(images))print(' '.join(f'{classes[labels[j]]}' for j in range(4)))
步骤三:构建神经网络模型
我们使用一个简单的卷积神经网络(CNN)来进行图像分类。
class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(torch.relu(self.conv1(x))) x = self.pool(torch.relu(self.conv2(x))) x = torch.flatten(x, 1) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return xnet = SimpleCNN()
步骤四:定义损失函数和优化器
我们使用交叉熵损失函数(CrossEntropyLoss)和随机梯度下降(SGD)作为优化器。
criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
步骤五:训练模型
我们只训练几个 epoch(轮次),以快速验证流程。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")net.to(device)for epoch in range(2): # 多次遍历数据集 running_loss = 0.0 for i, data in enumerate(train_loader, 0): inputs, labels = data[0].to(device), data[1].to(device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 200 == 199: # 每200个小批量打印一次 print(f'[Epoch {epoch + 1}, Batch {i + 1}] Loss: {running_loss / 200:.3f}') running_loss = 0.0print('Finished Training')
步骤六:测试模型性能
我们计算模型在测试集上的准确率。
correct = 0total = 0with torch.no_grad(): for data in test_loader: images, labels = data[0].to(device), data[1].to(device) outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')
步骤七:查看每类的分类准确率
class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))with torch.no_grad(): for data in test_loader: images, labels = data[0].to(device), data[1].to(device) outputs = net(images) _, predicted = torch.max(outputs, 1) c = (predicted == labels).squeeze() for i in range(4): label = labels[i] class_correct[label] += c[i].item() class_total[label] += 1for i in range(10): print(f'Accuracy of {classes[i]} : {100 * class_correct[i] / class_total[i]:.2f}%')
总结
在本篇文章中,我们使用 PyTorch 构建了一个简单的 CNN 模型,并在 CIFAR-10 数据集上进行了训练与评估。虽然这个模型较为简单,但它展示了图像分类的基本流程:数据加载与预处理、模型构建、训练、评估。
你可以尝试改进模型结构(例如使用 ResNet、VGG 等经典网络)、增加训练轮数、调整超参数等方式进一步提高准确率。
附录:完整代码
import torchimport torch.nn as nnimport torch.optim as optimimport torchvisionimport torchvision.transforms as transformsfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as plt# 数据预处理transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载数据集train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 展示图像def imshow(img): img = img / 2 + 0.5 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show()dataiter = iter(train_loader)images, labels = next(dataiter)imshow(torchvision.utils.make_grid(images))print(' '.join(f'{classes[labels[j]]}' for j in range(4)))# 定义网络class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(torch.relu(self.conv1(x))) x = self.pool(torch.relu(self.conv2(x))) x = torch.flatten(x, 1) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return xnet = SimpleCNN()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")net.to(device)# 损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)# 训练模型for epoch in range(2): running_loss = 0.0 for i, data in enumerate(train_loader, 0): inputs, labels = data[0].to(device), data[1].to(device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 200 == 199: print(f'[Epoch {epoch + 1}, Batch {i + 1}] Loss: {running_loss / 200:.3f}') running_loss = 0.0print('Finished Training')# 测试模型correct = 0total = 0with torch.no_grad(): for data in test_loader: images, labels = data[0].to(device), data[1].to(device) outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')# 查看各类别准确率class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))with torch.no_grad(): for data in test_loader: images, labels = data[0].to(device), data[1].to(device) outputs = net(images) _, predicted = torch.max(outputs, 1) c = (predicted == labels).squeeze() for i in range(4): label = labels[i] class_correct[label] += c[i].item() class_total[label] += 1for i in range(10): print(f'Accuracy of {classes[i]} : {100 * class_correct[i] / class_total[i]:.2f}%')
如果你有兴趣深入学习计算机视觉或深度学习,欢迎继续关注我的后续文章!
免责声明:本文来自网站作者,不代表CIUIC的观点和立场,本站所发布的一切资源仅限用于学习和研究目的;不得将上述内容用于商业或者非法用途,否则,一切后果请用户自负。本站信息来自网络,版权争议与本站无关。您必须在下载后的24个小时之内,从您的电脑中彻底删除上述内容。如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。客服邮箱:ciuic@ciuic.com