使用 Python 实现一个简单的图像分类器
在本文中,我们将使用 Python 和深度学习框架 PyTorch 来实现一个简单的图像分类器。这个分类器将能够识别 CIFAR-10 数据集中的 10 种常见物体类别,如飞机、汽车、鸟类等。我们将从数据准备、模型构建、训练到评估的全过程进行详细讲解,并提供完整的代码示例。
环境准备
在开始之前,请确保你已经安装了以下库:
torch
:PyTorch 深度学习框架torchvision
:包含常用数据集和预训练模型matplotlib
:用于可视化图像你可以通过以下命令安装这些库(如果尚未安装):
pip install torch torchvision matplotlib
加载和预处理数据
我们使用 CIFAR-10 数据集作为训练和测试数据。该数据集包含 60,000 张 32x32 彩色图像,分为 10 个类别。
2.1 导入必要的模块
import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport numpy as np
2.2 定义数据变换并加载数据集
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
2.3 可视化部分训练图像
def imshow(img): img = img / 2 + 0.5 # unnormalize npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show()# 获取一批训练图像dataiter = iter(trainloader)images, labels = next(dataiter)# 显示图像imshow(torchvision.utils.make_grid(images))# 打印标签print(' '.join(f'{classes[labels[j]]}' for j in range(4)))
定义神经网络模型
我们将使用一个简单的卷积神经网络(CNN)来处理图像数据。
import torch.nn as nnimport torch.nn.functional as Fclass Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return xnet = Net()
定义损失函数和优化器
我们使用交叉熵损失函数和随机梯度下降(SGD)优化器。
import torch.optim as optimcriterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
训练模型
我们将在训练数据上训练模型,并每隔一定步数输出损失值。
for epoch in range(2): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0): # 获取输入数据 inputs, labels = data # 清空梯度 optimizer.zero_grad() # 前向传播 + 反向传播 + 优化 outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 打印统计信息 running_loss += loss.item() if i % 2000 == 1999: # 每2000个小批量打印一次 print(f'Epoch {epoch + 1}, Batch {i + 1} loss: {running_loss / 2000:.3f}') running_loss = 0.0print('Finished Training')
保存模型
训练完成后,我们可以将模型保存下来,以便后续使用。
PATH = './cifar_net.pth'torch.save(net.state_dict(), PATH)
测试模型性能
接下来,我们在测试集上评估模型的表现。
correct = 0total = 0with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')
按类别查看准确率
我们还可以进一步分析每个类别的分类准确率。
class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs, 1) c = (predicted == labels).squeeze() for i in range(4): label = labels[i] class_correct[label] += c[i].item() class_total[label] += 1for i in range(10): print(f'Accuracy of {classes[i]} : {100 * class_correct[i] / class_total[i]:.2f}%')
总结
通过本文,我们完成了以下工作:
使用 PyTorch 加载并预处理 CIFAR-10 图像数据;构建了一个简单的 CNN 网络模型;使用交叉熵损失函数和 SGD 优化器训练模型;在测试集上评估模型性能,并输出各类别准确率;将模型保存为文件以备后续使用。虽然我们使用的模型比较简单,但在 CIFAR-10 上已经可以达到约 60% 左右的准确率。读者可以尝试使用更复杂的模型结构(如 ResNet、VGG)、调整超参数或使用数据增强技术来进一步提高准确率。
十、扩展建议
使用 GPU 加速训练过程(需要 CUDA 支持)添加数据增强(如旋转、翻转)使用预训练模型(如 resnet18)进行迁移学习使用 TensorBoard 进行训练可视化如果你对图像分类感兴趣,欢迎继续深入研究计算机视觉与深度学习领域,这是一个充满挑战和机遇的方向!
完整代码地址:你可以将上述所有代码片段依次复制到一个 .py
文件中运行,或者在 Jupyter Notebook 中分块执行。
如有任何问题,欢迎留言讨论!
免责声明:本文来自网站作者,不代表CIUIC的观点和立场,本站所发布的一切资源仅限用于学习和研究目的;不得将上述内容用于商业或者非法用途,否则,一切后果请用户自负。本站信息来自网络,版权争议与本站无关。您必须在下载后的24个小时之内,从您的电脑中彻底删除上述内容。如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。客服邮箱:ciuic@ciuic.com