Image Recognition Intro

Wednesday, November 16, 2022

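⚠️ This article is an original work by P3troL1er, first published at https://peterliuzhi.top/principle/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/softmax/%E5%9B%BE%E5%83%8F%E8%AF%86%E5%88%ABintro/. For commercial reprints, please contact the author for permission; for non-commercial reprints, please credit the source!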

It always seems impossible until it’s done. — Nelson Mandela

Using the torchvision library

import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms

Getting the data from Fashion-MNIST

# Load the labeled training split of Fashion-MNIST (download=True fetches it on the first run)
trained_set = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True,
                                                transform=transforms.ToTensor())
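Passing `transform=transforms.ToTensor()` is why each sample comes back as a float tensor of shape `(1, 28, 28)` with values in `[0, 1]`, and why `draw()` later calls `img.view((28, 28))` to drop the channel axis. The NumPy sketch below imitates that conversion for a single grayscale image. `to_tensor_like` is a hypothetical helper, not a torchvision function, and it only covers the 2-D uint8 case (the real transform also accepts PIL images and multi-channel arrays).

```python
import numpy as np

def to_tensor_like(img_hw: np.ndarray) -> np.ndarray:
    """NumPy sketch of what transforms.ToTensor() does to one grayscale uint8 image."""
    if img_hw.dtype != np.uint8 or img_hw.ndim != 2:
        raise ValueError("expected a 2-D uint8 image of shape (H, W)")
    # Rescale integer pixel values from [0, 255] to floats in [0.0, 1.0]
    scaled = img_hw.astype(np.float32) / 255.0
    # Add a leading channel axis: (H, W) -> (1, H, W), i.e. channels-first
    return scaled[np.newaxis, :, :]

# A random 28x28 "image" standing in for one Fashion-MNIST sample
fake = np.random.default_rng(0).integers(0, 256, size=(28, 28), dtype=np.uint8)
t = to_tensor_like(fake)
print(t.shape, t.dtype)  # (1, 28, 28) float32
print(0.0 <= t.min() and t.max() <= 1.0)  # True
```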

Defining the plotting function

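# Define the plotting function
def draw(features, labels):
    # Set the figure size (must happen before the figure is created)
    plt.rcParams['figure.figsize'] = (8.0, 6.0)
    # Create a grid with 10 rows; the column count is len(features) // 10
    _, axs = plt.subplots(10, len(features) // 10)
    # Add vertical space between rows so the titles do not overlap the images
    plt.subplots_adjust(hspace=2)
    assert isinstance(axs, np.ndarray)  # type assertion: a grid of Axes, not a single Axes
    # Flatten the 2-D array of Axes into a 1-D array (row-major order)
    axs = axs.ravel()
    # Fill each subplot with one image and its label
    for ax, img, lbl in zip(axs, features, labels):
        # Drop the channel axis, (1, 28, 28) -> (28, 28), then convert to NumPy for imshow
        ax.imshow(img.view((28, 28)).numpy())
        # Use the label as the subplot title
        ax.set_title(lbl)
        # Hide the x and y axes
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()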
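`draw()` fixes 10 rows, derives the column count as `len(features) // 10`, and flattens the Axes grid before zipping it with the images. The sketch below uses a plain integer array as a stand-in for the Axes grid (an assumption made so it runs without opening a plot window) to show where each image lands and what the integer division does to leftover images. `grid_layout` is a hypothetical helper for illustration only.

```python
import numpy as np

def grid_layout(n_images: int, n_rows: int = 10):
    """Reproduce draw()'s layout rule with integers in place of Axes objects."""
    n_cols = n_images // n_rows  # same rule as draw(); the remainder is discarded
    # Cell (r, c) holds the index of the image that will be drawn there
    grid = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)
    flat = grid.reshape((1, -1))[0]
    # reshape((1, -1))[0] and ravel() both flatten in row-major order
    assert (flat == grid.ravel()).all()
    return n_cols, flat

n_cols, flat = grid_layout(30)
print(n_cols)  # 3
# Image i lands at row i // n_cols, column i % n_cols
print([divmod(i, n_cols) for i in (0, 1, 3, 29)])  # [(0, 0), (0, 1), (1, 0), (9, 2)]

# zip() stops at the shorter input, so only n_rows * n_cols images get a subplot
_, flat35 = grid_layout(35)
print(len(flat35))  # 30: the last 5 images are silently skipped
```

The same rule means `draw()` needs at least 10 images: with fewer, the column count becomes 0.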

Mapping the dataset's labels to strings

# Map each numeric class index (0 to 9) to its text name
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
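A quick usage check of the mapping, plus a reverse lookup from name to index. The list order is copied from `get_fashion_mnist_labels`, so position `i` in the list is class `i` in the dataset; the reverse dictionary `CLASS_TO_INDEX` is a hypothetical addition for illustration. torchvision's `FashionMNIST` class also exposes a `classes` attribute with the official names (for example `'T-shirt/top'`), which could replace the hand-written list.

```python
TEXT_LABELS = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

def get_fashion_mnist_labels(labels):
    # Same mapping as in the article: numeric class index -> text name
    return [TEXT_LABELS[int(i)] for i in labels]

# Hypothetical reverse lookup: text name -> numeric class index
CLASS_TO_INDEX = {name: idx for idx, name in enumerate(TEXT_LABELS)}

print(get_fashion_mnist_labels([0, 9, 5]))  # ['t-shirt', 'ankle boot', 'sandal']
print(CLASS_TO_INDEX['bag'])  # 8
# Round trip: index -> name -> index
print([CLASS_TO_INDEX[n] for n in get_fashion_mnist_labels(range(10))] == list(range(10)))  # True
```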

Plotting the features and labels

# Plot the first 30 training images together with their ground-truth labels
draw([trained_set[i][0] for i in range(30)], get_fashion_mnist_labels([trained_set[i][1] for i in range(30)]))
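The call above indexes `trained_set[i]` twice per sample (once for the image, once for the label), so each image is converted and passed through `ToTensor()` twice. Below is a sketch of a single-pass alternative. `take_first` is a hypothetical helper, and `FakeDataset` is a stand-in that returns `(image, label)` pairs the way the torchvision dataset does, so the example runs without downloading anything.

```python
import numpy as np

class FakeDataset:
    """Hypothetical stand-in: indexing returns an (image, label) pair like FashionMNIST."""
    def __init__(self, n: int = 100, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.images = rng.random((n, 1, 28, 28), dtype=np.float32)
        self.labels = rng.integers(0, 10, size=n)
        self.calls = 0  # counts __getitem__ calls

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        self.calls += 1
        return self.images[i], int(self.labels[i])

def take_first(dataset, n):
    """Fetch the first n samples with one indexing call each, then split the pairs."""
    pairs = [dataset[i] for i in range(n)]
    features, labels = zip(*pairs)  # [(img, lbl), ...] -> (imgs...), (lbls...)
    return list(features), list(labels)

ds = FakeDataset()
features, labels = take_first(ds, 30)
print(len(features), features[0].shape, len(labels))  # 30 (1, 28, 28) 30
print(ds.calls)  # 30, versus 60 when indexing once for images and again for labels
```

With the real dataset this would be `features, labels = take_first(trained_set, 30)` followed by `draw(features, get_fashion_mnist_labels(labels))`.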