使用 Python 实现一个简单的图像分类器

今天 5阅读

在本文中,我们将使用 Python 和深度学习框架 PyTorch 来实现一个简单的图像分类器。这个分类器将能够识别 CIFAR-10 数据集中的 10 种常见物体类别,如飞机、汽车、鸟类等。我们将从数据准备、模型构建、训练到评估的全过程进行详细讲解,并提供完整的代码示例。


环境准备

在开始之前,请确保你已经安装了以下库:

torch:PyTorch 深度学习框架torchvision:包含常用数据集和预训练模型matplotlib:用于可视化图像

你可以通过以下命令安装这些库(如果尚未安装):

pip install torch torchvision matplotlib

加载和预处理数据

我们使用 CIFAR-10 数据集作为训练和测试数据。该数据集包含 60,000 张 32x32 彩色图像,分为 10 个类别。

2.1 导入必要的模块

import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport numpy as np

2.2 定义数据变换并加载数据集

transform = transforms.Compose([    transforms.ToTensor(),    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,                                        download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,                                          shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,                                       download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4,                                         shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat',           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

2.3 可视化部分训练图像

def imshow(img):    img = img / 2 + 0.5     # unnormalize    npimg = img.numpy()    plt.imshow(np.transpose(npimg, (1, 2, 0)))    plt.show()# 获取一批训练图像dataiter = iter(trainloader)images, labels = next(dataiter)# 显示图像imshow(torchvision.utils.make_grid(images))# 打印标签print(' '.join(f'{classes[labels[j]]}' for j in range(4)))

定义神经网络模型

我们将使用一个简单的卷积神经网络(CNN)来处理图像数据。

import torch.nn as nnimport torch.nn.functional as Fclass Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.conv1 = nn.Conv2d(3, 6, 5)        self.pool = nn.MaxPool2d(2, 2)        self.conv2 = nn.Conv2d(6, 16, 5)        self.fc1 = nn.Linear(16 * 5 * 5, 120)        self.fc2 = nn.Linear(120, 84)        self.fc3 = nn.Linear(84, 10)    def forward(self, x):        x = self.pool(F.relu(self.conv1(x)))        x = self.pool(F.relu(self.conv2(x)))        x = x.view(-1, 16 * 5 * 5)        x = F.relu(self.fc1(x))        x = F.relu(self.fc2(x))        x = self.fc3(x)        return xnet = Net()

定义损失函数和优化器

我们使用交叉熵损失函数和随机梯度下降(SGD)优化器。

import torch.optim as optimcriterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

训练模型

我们将在训练数据上训练模型,并每隔一定步数输出损失值。

for epoch in range(2):  # loop over the dataset multiple times    running_loss = 0.0    for i, data in enumerate(trainloader, 0):        # 获取输入数据        inputs, labels = data        # 清空梯度        optimizer.zero_grad()        # 前向传播 + 反向传播 + 优化        outputs = net(inputs)        loss = criterion(outputs, labels)        loss.backward()        optimizer.step()        # 打印统计信息        running_loss += loss.item()        if i % 2000 == 1999:    # 每2000个小批量打印一次            print(f'Epoch {epoch + 1}, Batch {i + 1} loss: {running_loss / 2000:.3f}')            running_loss = 0.0print('Finished Training')

保存模型

训练完成后,我们可以将模型保存下来,以便后续使用。

PATH = './cifar_net.pth'torch.save(net.state_dict(), PATH)

测试模型性能

接下来,我们在测试集上评估模型的表现。

correct = 0total = 0with torch.no_grad():    for data in testloader:        images, labels = data        outputs = net(images)        _, predicted = torch.max(outputs.data, 1)        total += labels.size(0)        correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')

按类别查看准确率

我们还可以进一步分析每个类别的分类准确率。

class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))with torch.no_grad():    for data in testloader:        images, labels = data        outputs = net(images)        _, predicted = torch.max(outputs, 1)        c = (predicted == labels).squeeze()        for i in range(4):            label = labels[i]            class_correct[label] += c[i].item()            class_total[label] += 1for i in range(10):    print(f'Accuracy of {classes[i]} : {100 * class_correct[i] / class_total[i]:.2f}%')

总结

通过本文,我们完成了以下工作:

使用 PyTorch 加载并预处理 CIFAR-10 图像数据;构建了一个简单的 CNN 网络模型;使用交叉熵损失函数和 SGD 优化器训练模型;在测试集上评估模型性能,并输出各类别准确率;将模型保存为文件以备后续使用。

虽然我们使用的模型比较简单,但在 CIFAR-10 上已经可以达到约 60% 左右的准确率。读者可以尝试使用更复杂的模型结构(如 ResNet、VGG)、调整超参数或使用数据增强技术来进一步提高准确率。


十、扩展建议

使用 GPU 加速训练过程(需要 CUDA 支持)添加数据增强(如旋转、翻转)使用预训练模型(如 resnet18)进行迁移学习使用 TensorBoard 进行训练可视化

如果你对图像分类感兴趣,欢迎继续深入研究计算机视觉与深度学习领域,这是一个充满挑战和机遇的方向!


完整代码地址:你可以将上述所有代码片段依次复制到一个 .py 文件中运行,或者在 Jupyter Notebook 中分块执行。

如有任何问题,欢迎留言讨论!

免责声明:本文来自网站作者,不代表CIUIC的观点和立场,本站所发布的一切资源仅限用于学习和研究目的;不得将上述内容用于商业或者非法用途,否则,一切后果请用户自负。本站信息来自网络,版权争议与本站无关。您必须在下载后的24个小时之内,从您的电脑中彻底删除上述内容。如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。客服邮箱:ciuic@ciuic.com

目录[+]

您是本站第32476名访客 今日有30篇新文章

微信号复制成功

打开微信,点击右上角"+"号,添加朋友,粘贴微信号,搜索即可!