Image Recognition Intro
Wednesday, November 16, 2022
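⚠️ This article is an original work by P3troL1er, first published at https://peterliuzhi.top/principle/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/softmax/%E5%9B%BE%E5%83%8F%E8%AF%86%E5%88%ABintro/. For commercial reproduction, please contact the author for authorization; for non-commercial reproduction, please credit the source.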
It always seems impossible until it’s done.
— Nelson Mandela
Using the torchvision library
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms
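Getting the data from Fashion-MNIST
# Load the labeled training split of Fashion-MNIST.
# download=False assumes the files already exist under root; set download=True on the first run.
trained_set = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=False,
                                                transform=transforms.ToTensor())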
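The `transform=transforms.ToTensor()` argument above determines what each sample looks like: a grayscale image is converted from 8-bit integers in 0..255 with shape (H, W) into 32-bit floats in 0.0..1.0 with shape (1, H, W). The sketch below only imitates that conversion with NumPy so the effect can be seen without downloading the dataset; `to_tensor_like` and `fake_image` are hypothetical names for illustration and are not part of torchvision.

```python
import numpy as np


def to_tensor_like(image_u8):
    """Hypothetical NumPy stand-in for ToTensor on a grayscale image."""
    arr = np.asarray(image_u8, dtype=np.uint8)
    if arr.ndim != 2:
        raise ValueError("expected a 2-D grayscale image")
    # Add a leading channel axis: (H, W) -> (1, H, W)
    chw = arr[np.newaxis, :, :]
    # Scale integer intensities 0..255 to floats 0.0..1.0
    return chw.astype(np.float32) / 255.0


# A made-up 28x28 image: all black except one white pixel
fake_image = np.zeros((28, 28), dtype=np.uint8)
fake_image[0, 0] = 255

converted = to_tensor_like(fake_image)
print(converted.shape, converted.dtype, converted.min(), converted.max())
```

This is why the plotting code later reshapes each sample to (28, 28) before display: the leading channel axis has to be dropped for a 2-D image plot.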
Defining the plotting function
# Plot a batch of images together with their labels
def draw(features, labels):
    # Set the figure size
    plt.rcParams['figure.figsize'] = (8.0, 6.0)
    # Create the subplot grid: 10 rows, len(features) // 10 columns
    _, axs = plt.subplots(10, len(features) // 10)
    # Increase the vertical spacing so the titles do not overlap the images
    plt.subplots_adjust(hspace=2)
    assert isinstance(axs, np.ndarray)  # type assertion
    # Flatten the multi-dimensional array of axes into one dimension
    axs = axs.flatten()
    # Set the pixel values and the label of each subplot individually
    for ax, img, lbl in zip(axs, features, labels):
        # Drop the channel axis and show the 28x28 image
        ax.imshow(img.view((28, 28)).numpy())
        # Use the label as the subplot title
        ax.set_title(lbl)
        # Hide both axes
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
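The `draw` function above fixes the grid at 10 rows and `len(features) // 10` columns. That fits the 30-image case used in this post, but it yields zero columns for fewer than 10 images and leaves images out when the count is not a multiple of 10. A minimal sketch of one way to derive a grid for any count follows; `grid_shape` and `n_cols` are hypothetical names, not part of the original code.

```python
import math


def grid_shape(n_images, n_cols=3):
    """Return (rows, cols) with enough cells for n_images (hypothetical helper)."""
    if n_images <= 0 or n_cols <= 0:
        raise ValueError("n_images and n_cols must be positive")
    # Round up so a partially filled last row still gets its own row
    n_rows = math.ceil(n_images / n_cols)
    return n_rows, n_cols


for n in (30, 25, 7):
    rows, cols = grid_shape(n)
    # rows * cols - n is the number of cells left empty in the last row
    print(n, rows, cols, rows * cols - n)
```

The resulting pair could be passed to `plt.subplots(rows, cols)`; any unused cells in the last row would then need their axes hidden separately.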
Mapping the dataset labels to strings
# Map each numeric class index to its text label
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
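As a quick check of the mapping, the function can be called on a few plain integers without loading any data. The sketch below repeats the function so that it runs on its own; the module-level `TEXT_LABELS` list and the reverse lookup `LABEL_TO_INDEX` are additions for illustration only and do not appear in the original code.

```python
# Class names in index order, as listed in the function above
TEXT_LABELS = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']


def get_fashion_mnist_labels(labels):
    # int(i) lets the function accept plain ints as well as scalar tensors
    return [TEXT_LABELS[int(i)] for i in labels]


# Hypothetical reverse lookup: text label -> numeric index
LABEL_TO_INDEX = {name: idx for idx, name in enumerate(TEXT_LABELS)}

names = get_fashion_mnist_labels([9, 0, 3])
print(names)
print([LABEL_TO_INDEX[name] for name in names])
```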
Displaying the features and labels
# Show the first 30 training images together with their ground-truth labels
draw([trained_set[i][0] for i in range(30)], get_fashion_mnist_labels([trained_set[i][1] for i in range(30)]))