avatar

FashionMnist数据集卷积神经网络实战

数据集介绍

 Fashion-MNIST是Zalando的商品图片数据集,其中包含60,000个示例的训练集和10,000个示例的测试集。 每个示例都是一个28×28灰度图像,与来自10个类别的标签关联。 Fashion-MNIST旨在直接替代原始MNIST数据集,以对机器学习算法进行基准测试。

样本示例


网络结构


Pytorch实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import os

batch_size = 60

def get_mean_std(dataset, ratio=0.01):
"""计算数据集的均值和方差
"""
dataloader = torch.utils.data.DataLoader(dataset, batch_size=int(len(dataset)*ratio),
shuffle=True, num_workers=4)
train = iter(dataloader).next()[0] # 一个batch的数据
#print(train.numpy().shape())
mean = np.mean(train.numpy(), axis=(0,2,3))
std = np.std(train.numpy(), axis=(0,2,3))
return mean, std


# 对数据的处理:神经网络希望输入的数据最好比较小,最好处于(-1,1)内,最好符合正态分布。
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.2888097,), (0.3549146,))
])

# 数据集准备
train_dataset = datasets.FashionMNIST(root='./dataset/fmnist/', train=True, transform=transform, download=True)
test_dataset = datasets.FashionMNIST(root='./dataset/fmnist/', train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 测试不需要shuffle打乱顺序,保证结果的顺序


# 模型类
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=1) # 卷积层1
self.conv2 = torch.nn.Conv2d(64, 64, kernel_size=3) # 卷积层2
self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3)
self.conv4 = torch.nn.Conv2d(128,128, kernel_size=3)
self.linear5 = torch.nn.Linear(2048, 512) # 全连接层3
self.linear6 = torch.nn.Linear(512, 10)
self.bn1 = torch.nn.BatchNorm2d(64)
self.bn2 = torch.nn.BatchNorm2d(128)
self.pooling = torch.nn.MaxPool2d(2) # 池化层
self.drop1 = torch.nn.Dropout2d()

def forward(self, x):
# x (batch_size,channel,width,height) 28
batch_size = x.size(0) # minibatch样本数量
x = torch.nn.functional.relu(self.bn1(self.pooling(self.conv2(self.conv1(x))))) # 12
x = torch.nn.functional.relu(self.bn2(self.pooling(self.conv4(self.conv3(x))))) # 4
x = x.view(batch_size, -1) # 转为全连接层允许的输入格式
x = torch.nn.functional.relu(self.linear5(x))
x = self.drop1(x)
return self.linear6(x) # 后面要用softmax,故最后一层不做relu激活

model = Net()

# 损失函数和优化器
criterion = torch.nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) # 随机梯度下降,带冲量

e_list = []
l_list = []
running_loss = 0.0


# 单轮训练的函数
def train(epoch):
running_loss = 0.0
Loss = 0.0
for batch_idx, data in enumerate(train_loader, 0):
inputs, target = data
optimizer.zero_grad()

# 前馈计算
outputs = model(inputs)
# 损失计算
loss = criterion(outputs, target)
# 反馈计算
loss.backward()
optimizer.step()

running_loss += loss.item() # 累加损失
Loss += loss.item()
# 每300次迭代(minibatch)训练更新,计算一次平均损失
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
running_loss = 0.0
e_list.append(epoch)
l_list.append(running_loss / 300)


def test():
correct = 0 # 预测正确数
total = 0 # 总样本数
with torch.no_grad(): # 声明不计算梯度
for data in test_loader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, dim=1) # 按列找最大值的下标,返回两个:最大值,下标
total += labels.size(0) # labels矩阵的行数
correct += (predicted == labels).sum().item() # 相等为1,否则为0
print('Accuracy on test set: %d %%' % (100 * correct / total))


# 训练
if __name__ == '__main__':
# 加载模型
if os.path.exists("./model/model_params.pkl"):
model.load_state_dict(torch.load("./model/model_params.pkl"))
# 训练+测试
for epoch in range(10):
train(epoch)
test()
# 保存模型的数据
torch.save(model.state_dict(), './model/model_params.pkl')

 该实现最好成绩为91%, 不甚理想, 待优化。

文章作者: Liam
文章链接: https://www.ccyh.xyz/p/1f07.html
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Liam's Blog
ღ喜欢记得五星好评哦~
打赏
  • 微信
    微信
  • 支付寶
    支付寶

评论