博客
关于我
深度学习入门基础教程(三) CNN做MNIST数据集图像分类 pytorch版代码
阅读量:208 次
发布时间:2019-02-28

本文共 6211 字,大约阅读时间需要 20 分钟。

一. 原理解读

这篇文章以MNIST数据集为基础,模仿经典的CNN图像分类教程,旨在帮助读者理解卷积神经网络(CNN)的核心原理。MNIST数据集包含60000张训练图片和10000张测试图片,每张图片大小为28x28,单通道,数字0-9对应标签。

二. PyTorch版完整代码复现

以下是基于PyTorch实现的完整CNN代码,能够实现接近99%的准确率:
import torchimport torch.nn as nnfrom torchvision.datasets import MNISTfrom torchvision import transformsfrom torch.utils.data import DataLoader, Datasetfrom torch.optim import Adamimport osimport shutilclass Unit(nn.Module):    def __init__(self, inc, ouc):        super(Unit, self).__init__()        self.unit_net = nn.Sequential(            nn.Conv2d(inc, ouc, kernel_size=3, padding=1),            nn.BatchNorm2d(ouc),            nn.ReLU()        )    def forward(self, x):        return self.unit_net(x)class Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.net = nn.Sequential(            Unit(1, 32),            Unit(32, 32),            Unit(32, 32),            nn.MaxPool2d(2),            Unit(32, 64),            Unit(64, 64),            Unit(64, 64),            Unit(64, 64),            nn.MaxPool2d(2),            Unit(64, 128),            Unit(128, 128),            Unit(128, 128),            Unit(128, 128),            nn.MaxPool2d(2),            Unit(128, 128),            Unit(128, 128),            Unit(128, 128),            nn.AvgPool2d(4)        )        self.fc = nn.Linear(128, 10)        def forward(self, x):        y = self.net(x)        y = y.view(-1, 128)        return self.fc(y)# 数据增强与加载train_transforms = transforms.Compose([    transforms.RandomHorizontalFlip(),    transforms.ToTensor(),    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_set = MNIST('./data/', train=True, transform=train_transforms, download=True)train_dataloader = DataLoader(train_set, batch_size=512, shuffle=True)test_transforms = transforms.Compose([    transforms.ToTensor(),    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])test_set = MNIST('./data/', train=False, transform=test_transforms, download=True)test_dataloader = DataLoader(test_set, batch_size=512, shuffle=False)param_path = r'./param/mnist_cnn.pkl'tmp_param_path = r'./param/mnist_cnn_temp.pkl'# 超参数batch_size = 512learning_rate = 0.001weight_decay = 0.0001num_epoch = 1000# 模型定义与训练module = Net()if torch.cuda.is_available():    module.cuda()optimizer = Adam(module.parameters(), lr=learning_rate, weight_decay=weight_decay)loss_f = nn.CrossEntropyLoss()def adjust_lr_rate(epoch):    lr = learning_rate    if epoch > 180:        lr = lr / 1000000    elif epoch > 150:        lr = lr / 100000    elif epoch > 120:        lr = lr / 10000    elif epoch > 90:        lr = lr / 1000    elif epoch > 60:        lr = lr / 100    elif epoch > 30:        lr = lr / 10    for param_group in optimizer.param_groups:        param_group['lr'] = lrdef train(num_epoch):    global module, optimizer, loss_f, param_path, tmp_param_path    if os.path.exists(param_path):        module.load_state_dict(torch.load(param_path))        best_acc = 0    for epoch in range(num_epoch):        train_loss = 0        train_acc = 0        for images, labels in train_dataloader:            if torch.cuda.is_available():                images = images.cuda()                labels = labels.cuda()            outputs = module(images)            loss = loss_f(outputs, labels)            optimizer.zero_grad()            loss.backward()            optimizer.step()                        train_loss += loss.item() * images.size(0)            _, prediction = torch.max(outputs, 1)            train_acc += torch.sum(prediction == labels)                adjust_lr_rate(epoch)        train_loss = train_loss / 60000        train_acc = train_acc.item() / 60000                test_acc = 0        module.eval()        for images, labels in test_dataloader:            if torch.cuda.is_available():                images = images.cuda()                labels = labels.cuda()            outputs = module(images)            _, prediction = torch.max(outputs, 1)            test_acc += torch.sum(prediction == labels)        test_acc = test_acc.item() / 10000                if test_acc > best_acc:            best_acc = test_acc            if os.path.exists(tmp_param_path):                shutil.copyfile(tmp_param_path, param_path)            torch.save(module.state_dict(), tmp_param_path)                print(f'Epoch: {epoch}, Train Loss: {train_loss}, Train Acc: {train_acc}, Test Acc: {test_acc}')# 训练与测试train(1000)

三. 扩展

通过上述训练好的模型可以进行图像分类任务。以下是推断代码:
import torchimport torch.nn as nnfrom torchvision import transformsfrom torchvision.models import squeezenet1_1from PIL import Imageclass Unit(nn.Module):    def __init__(self, inc, ouc):        super(Unit, self).__init__()        self.unit_net = nn.Sequential(            nn.Conv2d(inc, ouc, kernel_size=3, padding=1),            nn.BatchNorm2d(ouc),            nn.ReLU()        )    def forward(self, x):        return self.unit_net(x)class Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.net = nn.Sequential(            Unit(1, 32),            Unit(32, 32),            Unit(32, 32),            nn.MaxPool2d(2),            Unit(32, 64),            Unit(64, 64),            Unit(64, 64),            Unit(64, 64),            nn.MaxPool2d(2),            Unit(64, 128),            Unit(128, 128),            Unit(128, 128),            Unit(128, 128),            nn.MaxPool2d(2),            Unit(128, 128),            Unit(128, 128),            Unit(128, 128),            nn.AvgPool2d(4)        )        self.fc = nn.Linear(128, 10)        def forward(self, x):        y = self.net(x)        y = y.view(-1, 128)        return self.fc(y)# 加载预训练模型module = squeezenet1_1(pretrained=True)# 定义预测函数def predict_img(img_path):    img = Image.open(img_path).convert('L')    img = transforms.Resize(28)(img)    img = transforms.ToTensor()(img)    img = img.unsqueeze(0)        if torch.cuda.is_available():        img = img.cuda()        outputs = module(img)    _, index = torch.max(outputs, 1)    return index.item()# 示例预测index = predict_img(r'C:\Users\87419\Desktop\00.jpg')print(index)

四. 灰度图转换示例

将RGB图像转换为灰度图像的代码:
from PIL import ImageI = Image.open(r'C:\Users\87419\Desktop\0.jpg').convert('L')I.save(r'C:\Users\87419\Desktop\00.jpg')

注:测试图片需为灰度图像。

转载地址:http://xgpi.baihongyu.com/

你可能感兴趣的文章
Objective-C实现iterative merge sort迭代归并排序算法(附完整源码)
查看>>
Objective-C实现jaccard similarity相似度无平方因子数算法(附完整源码)
查看>>
Objective-C实现Julia集算法(附完整源码)
查看>>
Objective-C实现jump search跳转搜索算法(附完整源码)
查看>>
Objective-C实现jumpSearch跳转搜索算法(附完整源码)
查看>>
Objective-C实现k nearest neighbours k最近邻分类算法(附完整源码)
查看>>
Objective-C实现k-means clustering均值聚类算法(附完整源码)
查看>>
Objective-C实现k-Means算法(附完整源码)
查看>>
Objective-C实现k-nearest算法(附完整源码)
查看>>
Objective-C实现KadaneAlgo计算给定数组的最大连续子数组和算法(附完整源码)
查看>>
Objective-C实现kadanes卡达内斯算法(附完整源码)
查看>>
Objective-C实现kahns algorithm卡恩算法(附完整源码)
查看>>
Objective-C实现karatsuba大数相乘算法(附完整源码)
查看>>
Objective-C实现karger算法(附完整源码)
查看>>
Objective-C实现KMP搜索算法(附完整源码)
查看>>
Objective-C实现Knapsack problem背包问题算法(附完整源码)
查看>>
Objective-C实现knapsack背包问题算法(附完整源码)
查看>>
Objective-C实现knapsack背包问题算法(附完整源码)
查看>>
Objective-C实现knight tour骑士之旅算法(附完整源码)
查看>>
Objective-C实现knight Tour骑士之旅算法(附完整源码)
查看>>