edo1z blog

プログラミングなどに関するブログです

Python - MNISTを使う

機械学習で使えるサンプル画像の有名なのがMNISTだそうです。0-9までの手書き文字画像と、正解ラベルデータが、トレーニング用とテスト用で分けられています。 http://yann.lecun.com/exdb/mnist/

バイナリデータになっていて、画像等は全部データとしてつながっているらしい。機械学習は画像を数値として扱う必要があります。

参考:https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/dataset/mnist.py

Pythonでダウンロード・読み込み

ダウンロード

import urllib.request
import os.path

url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}
dataset_dir = os.path.dirname(os.path.abspath(__file__))

def _download(filename):
    file_path = dataset_dir + '/' + filename
    if os.path.exists(file_path):
        return print('already exist')
    print('Downloading ' + filename + ' ...')
    urllib.request.urlretrieve(url_base + filename, file_path)
    print('Done')

def download_mnist():
    for v in key_file.values():
       _download(v)

download_mnist()

読み込み

gzipを使います。

def load_mnist(filename):
    file_path = dataset_dir + '/' + filename
    with gzip.open(file_path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    return data.reshape(-1, img_size)

試しに画像データを読み込んで1つ表示させてみます。

data = load_mnist(key_file['train_img'])
img1 = data[0].reshape(28, 28)
pil_img = Image.fromarray(np.uint8(img1))
pil_img.show()

手書きの5っぽい画像が表示されました。 load_mnist関数は、gzipで読み込んで、np.frombufferというので、Numpyの配列に格納してます。data.reshapeで、画像サイズの784ずつに配列を分割してます。 784ずつに分割された配列で実際には機械学習すればいいと思います。実際の画像にするために、さらに1つの画像の縦横に合わせて配列を変換してます。