请稍等 ...
×

采纳答案成功!

向帮助你的同学说点啥吧!感谢那些助人为乐的人

分享一下,如何获取mnist数据集的问题

查了一晚,终于搞定。分享如下:

一、4个文件,百度网盘自取,路径自己处理一下

二、将 idx1-ubyte、idx3-ubyte 格式的文件转换为我们熟悉的数据集(NumPy 数组)

从网上搬来的,稍微改动了一下,直接调用最下方 load()方法 即可

import numpy as np
import struct


def _decode_idx3_ubyte(idx3_ubyte_file):
    # 读取二进制数据
    bin_data = open(idx3_ubyte_file, "rb").read()

    # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
    offset = 0
    fmt_header = ">iiii"
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)

    # 解析数据集
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)
    fmt_image = ">" + str(image_size) + "B"
    images = np.empty((num_images, num_rows, num_cols))
    for i in range(num_images):
        images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols))
        offset += struct.calcsize(fmt_image)
    return images


def _decode_idx1_ubyte(idx1_ubyte_file):
    # 读取二进制数据
    bin_data = open(idx1_ubyte_file, "rb").read()

    # 解析文件头信息,依次为魔数和标签数
    offset = 0
    fmt_header = ">ii"
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)

    # 解析数据集
    offset += struct.calcsize(fmt_header)
    fmt_image = ">B"
    labels = np.empty(num_images)
    for i in range(num_images):
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels


def _load_X_train(X_train_path):
    """Decode the training-image file at ``X_train_path`` via the IDX3 decoder."""
    train_images = _decode_idx3_ubyte(X_train_path)
    return train_images


def _load_y_train(y_train_path):
    """Decode the training-label file at ``y_train_path`` via the IDX1 decoder."""
    train_labels = _decode_idx1_ubyte(y_train_path)
    return train_labels


def _load_X_test(X_test_path):
    """Decode the test-image file at ``X_test_path`` via the IDX3 decoder."""
    test_images = _decode_idx3_ubyte(X_test_path)
    return test_images


def _load_y_test(y_test_path):
    """Decode the test-label file at ``y_test_path`` via the IDX1 decoder."""
    test_labels = _decode_idx1_ubyte(y_test_path)
    return test_labels


def load(X_train_path, X_test_path, y_train_path, y_test_path):
    """Load the four dataset files and return them as one tuple.

    Args:
        X_train_path: Path of the training-image file.
        X_test_path: Path of the test-image file.
        y_train_path: Path of the training-label file.
        y_test_path: Path of the test-label file.

    Returns:
        A tuple ``(X_train, X_test, y_train, y_test)``. Note that both image
        arrays come first, followed by both label arrays.
    """
    # Files are read in the same order as the returned tuple.
    X_train = _load_X_train(X_train_path)
    X_test = _load_X_test(X_test_path)
    y_train = _load_y_train(y_train_path)
    y_test = _load_y_test(y_test_path)
    return X_train, X_test, y_train, y_test

三、传入路径作为参数,调用load方法

import numpy as np
from utils.data_init import mnist
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsRegressor

# Print floats in plain decimal notation rather than scientific notation.
np.set_printoptions(suppress=True)

if __name__ == "__main__":
    # The same shape summary is printed before and after flattening.
    shape_report = "X_train={}, X_test={}, y_train={}, y_test={}"

    X_train, X_test, y_train, y_test = mnist.load(
        "../datasets/mnist/X_train.idx3-ubyte",
        "../datasets/mnist/X_test.idx3-ubyte",
        "../datasets/mnist/y_train.idx1-ubyte",
        "../datasets/mnist/y_test.idx1-ubyte"
    )
    print(shape_report.format(X_train.shape, X_test.shape, y_train.shape, y_test.shape))

    # Collapse everything after the first axis: one row per sample.
    X_train = X_train.reshape(len(X_train), -1)
    X_test = X_test.reshape(len(X_test), -1)
    y_train = y_train.reshape(len(y_train), -1)
    y_test = y_test.reshape(len(y_test), -1)

    print(shape_report.format(X_train.shape, X_test.shape, y_train.shape, y_test.shape))

    # Baseline: kNN on the raw features (left disabled by the author).
    # knn_clf = KNeighborsRegressor()
    # knn_clf.fit(X_train, y_train)  # about 1 min
    # knn_score = knn_clf.score(X_test, y_test)  # about 15 mins
    # print(knn_score)

    # PCA + kNN: n_components=0.6 is a fraction, so PCA picks the number of
    # components itself; the chosen count is reported below.
    reducer = PCA(n_components=0.6)
    reducer.fit(X_train)
    print("使用PCA将数据降维至dim={},即可满足需要".format(reducer.n_components_))
    print("explained_variance_ratio_: ", np.sum(reducer.explained_variance_ratio_))

    # Project both splits with the PCA fitted on the training data only.
    X_train_reduced = reducer.transform(X_train)
    X_test_reduced = reducer.transform(X_test)

    # NOTE(review): a regressor is fitted on the labels, so score() is an
    # R^2 value rather than a classification accuracy -- confirm whether
    # KNeighborsClassifier was intended.
    knn_model = KNeighborsRegressor()
    knn_model.fit(X_train_reduced, y_train)
    pca_knn_score = knn_model.score(X_test_reduced, y_test)
    print(pca_knn_score)

    # Author's observation: kNN after PCA scores better than kNN alone.
    # Author's explanation: sklearn's PCA denoises the data internally.

附打印结果:

X_train=(60000, 28, 28), X_test=(10000, 28, 28), y_train=(60000,), y_test=(10000,)
X_train=(60000, 784), X_test=(10000, 784), y_train=(60000, 1), y_test=(10000, 1)
使用PCA将数据降维至dim=17,即可满足需要
explained_variance_ratio_:  0.6074124552802459
0.938264058128613

PS:老师给的网盘数据好像不对…测试有一些问题,每次测试结果都是负数…我也不知道为啥(不管了 碎觉 爱咋咋地)

正在回答

1回答

感谢分享。


现在最新版本的 sklearn,使用 fetch_openml 可以直接获取 mnist 数据集,可以参考这里:http://coding.imooc.com/learn/questiondetail/139012.html


继续加油!:)

0 回复 有任何疑惑可以回复我~
  • 提问者 刘刘刘刘刘英迪 #1
    非常感谢!
    回复 有任何疑惑可以回复我~ 2020-04-12 18:02:40
  • 提问者 刘刘刘刘刘英迪 #2
    老师,你分享的那个我看过了,每次提问一个问题我都会在问答区查看,查不到的时候才会发起提问,这也是对老师的尊敬。
    
    from sklearn.datasets import fetch_openml
    这句我已经尝试过了,编译时会标红报错。
    回复 有任何疑惑可以回复我~ 2020-04-12 18:04:23
  • liuyubobobo 回复 提问者 刘刘刘刘刘英迪 #3
    额?你的 sklearn 是什么版本?
    回复 有任何疑惑可以回复我~ 2020-04-13 02:47:59
问题已解决,确定采纳
还有疑问,暂不采纳
意见反馈 帮助中心 APP下载
官方微信