PyTorchによるCNNを用いた画像分類【CIFAR-10】

AI

PyTorchによるCNNを用いた画像分類【CIFAR-10】

今回はPyTorchを用いて画像分類を行う方法をご紹介します。本記事では、CIFAR-10データセットを用いて画像分類を行うモデルを実装して行きます。これからPyTorchを使っていみたいという方やCNNがどのようなものか理解したいという方は是非、本記事をご一読ください。

本記事を読んでいただくこと以下のことを学ぶことができます。

  1. CNN(畳み込みニューラルネットワーク)とは何か?
  2. PyTorchでのCNNを用いた画像分類モデルを実装する方法
  3. PyTorchにおけるデータ拡張(Data Augmentation)の実現方法
オススメの学習サービス

機械学習で用いる数学の知識を基礎から学びたい初心者の方にはUdemyの以下の動画がオススメです。

興味がある方は是非こちらもチェックしてください。

Udemy – 【キカガク流】人工知能・機械学習 脱ブラックボックス講座 – 初級編 –

Udemy – 【キカガク流】人工知能・機械学習 脱ブラックボックス講座 – 中級編 –

CNN(畳み込みニューラルネットワーク)とは何か?

まずプログラミングを行う前にCNNとはどのようなものなのか簡単にご説明します。詳細な説明は専門書や論文に任せるとして、ここでは、短時間でCNNの概要を理解していただくことを目的に解説させていただきます。

まず初めにCNNとは何かという点ですが、CNNはConvolutional Neural Networkのの略で、直訳すると畳み込みニューラルネットワークと呼ばれます。

上の図では代表的なCNNのモデルであるVGG16を例として示しています。上の図を見ていたくと分かるようにCNNでは畳み込み層とMax Pooling層を複数層重ね学習を行います。このことから畳み込みニューラルネットワークと呼ばれます。

CNNは単純な全結合層のみのMLPよりも精度が高く、画像の学習ではほぼ必須と言って良いほど頻繁に利用されます。

畳み込みとは?

畳み込みを簡単に説明すると、一点の情報だけではなく、周りの情報も考慮して処理を行う手法と言えます。一点だけではなく、全体を見ることでより正しく画像を認識できるようになるという利点があります。

畳み込みでは、以下の図のようにカーネル(フィルタ)と呼ばれるものを左上から右下に掛け合わせていき、最終的に新しい画像を得ます。

カーネルには様々な種類がありますが、以下の例は平滑化フィルタというものを例として示しています。青の部分が着目画素を示しており、周りを含めた3 x 3の領域に同じく3 x 3のカーネルを適用しています。結果的に領域のすべの画素の総和を平均しています。これにより細かいノイズなどが除去されることになります。

PyTorchではtorch.nn.Conv2dを利用することで簡単に畳みを行うことができます。

パディングについて

以下の図で示しているように畳み込みを行う際にカーネルが画像からはみ出すことができないため、元の画像より縮小してしまうという問題が発生します。そこで、それを回避するためにパディングという工夫が行われます。

以下の図の左側では4 x 4の画像に3 x 3のカーネルを適用しようとしています。しかし、1ピクセル分はみ出してしまうため、一つ内側から畳み込みを行う必要があります。しかし、そうすると画像が2 x 2縮小してしまいます。そこで図の右側にあるように周りを一般的には0で埋めることで画像の縮小を防ぐことができます。これを0パディングと呼びます。

Max Poolingとは?

CNNでは畳み込みを行った後にプーリングを行います。プーリングを行うことにより、特徴をより際立たせるという効果がもたらされます。

よく利用されるプーリングの手法にMax Poolingというものがあります。Max Poolingを適用することによって下記の図のように解像度が引き下げられます。CNNでは、畳み込み層とプーリング層とが複数重なっており、異なる空間スケールの情報を学習することができるようになります。

例えば、4 x 4ピクセルの画像にMax Poolingを適用して2 x 2の画像とした場合、次の畳み込み層では1ピクセル隣の情報を取得するということは元の画像での2ピクセル分の情報を取得してることになります。つまり、解像度を下げることにより広い範囲の特徴を学習できるようになるということです。

Max Poolingでは下図のように各領域の最大値を取得することになります。青色のエリアでは最大値9が、黄色のエリアでは6が選択され、最終的に4 x 4 から 2 x 2に解像度が引き下げられます。

CIFAR-10データセットについて

CIFAR-10はairplane, automobile, bird, cat, deer, dog, frog, horse, ship, truckの十種の画像で構成されるデータセットです。画像は32 x 32ピクセルのRGBカラー画像となります。

引用: https://www.cs.toronto.edu/~kriz/cifar.html

PyTorchでのCNNを用いた画像分類モデルを実装する方法

ここからは実際にPyTorchを用いて、CNNを実装する方法を解説していきます。

実行環境

今回は簡単に動作環境を用意できるGoogle Colabを利用します。

プログラムを実装する前にGoogle ColabでGPUを利用するための設定を行います。画面上部にある[ランタイム]から[ランタイムのタイプを変更]クリックし、[ハードウェア アクセラレータ]をGPUに変更し、保存しておいてください。

ライブラリのインポート

最初に必要なライブラリをインポートします。

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

GPUが利用する準備をします。cudaはGPUのことを示しており、GPUが利用できる場合は、deviceにcudaを利用できない場合はcpuをセットしています。

device = "cuda" if torch.cuda.is_available() else "cpu"

前処理の定義

続いて、前処理を定義します。

ToTensor()では、画像をTensorクラスに変換しています。加えて、Channel Lastになっている画像をChannel Firstに変換し、0〜255の整数値を0〜1の浮動小数点数型に変換するところまで行ってくれます。

Normalize()では、平均と標準偏差に0.5を指定することで、値の範囲を[-1, 1]にしています。

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5,))
])

データセットの読み込み

以下の処理では、CIFAR10の学習用データと検証用データをダウンロードし、それぞれ変数に読み込んでいます。

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
validation_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

データローダーの初期化

ここでは学習用のデータローダーと検証用のデータローダーをそれぞれ初期化しています。

  • 第1引数: データセット
  • batch_size: ミニバッチサイズ(今回は32としています)
  • shuffle引数: 画像をランダムに読み込むかどうか(Trueの場合はランダム、Falseの場合は前から順番に読み込みます)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False)

モデルの定義

__init__

こちらのメソッドでモデルの初期化を行っています。num_classesには分類するクラス数を渡します。今回は画像を10種類に分類するタスクですので、初期化時に10を指定することになります。

モデルの各層をSequentialを利用して定義しています。

一つ目のConv2dでは入力チャネル数にRGBの3を出力チャネル数に64を指定しています。カーネルサイズは5としており、それにあわせてpaddingは2をセットしています。そして、出力値に活性化関数ReLUを適用しています。さらに出力値に対してカーネルサイズを2としてMax Poolingを適用しています。

上記の畳み込み、Max Pooling、ReLUを複数重ねてネットワークを定義しています。

最後に10クラスに分類するようにLinearで全結合層を定義しています。Linearのin_featuresには1次元のベクトルを渡す必要があるのため、4 x 4 x 128としています。(Max Poolingを3回適用しているので、32 x 32 > 16 x 16 > 4x 4となっています。そしてチャンネル数は最後の畳み込みのout_channelsの値128となっています。)

forward

forwardメソッドでは順伝播の処理を定義しています。ここではxを順番に置き換えて処理を行っています。まず、入力値を上記で定義したfeaturesに渡しています。

そして、classifierに値を渡すためには一次元のベクトルにする必要がありますので、viewを利用して一次元のベクトルに変換しています。(x.size(0)はチャネル数, -1とすることで残りを一次元に押し込めています)

最後にclassifierで予測を行います。

class CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(in_features=4 * 4 * 128, out_features=num_classes)
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

モデル・損失関数・最適化関数の初期化

ここではモデル、損失関数、最適化関数(オプティマイザ)をそれぞれ初期化します。

CNN(10)でモデルの初期化を行っています。引数には分類するクラス数10を指定しています。model.to(device)でGPUにモデルを送っています。

今回のタスクは多クラス分類であるため、多クラス分類で一般的に利用されるクロスエントロピーを損失関数として用います。

最適化関数には一般的によく使われるAdamを指定しています。

model = CNN(10)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

学習・検証処理の定義

ここからは実際に学習、検証を行う処理を実装していきます。
まずは、エポックごとの損失の値とaccuracyを保存しておくリストを用意します。

  • num_epocs: エポック数
  • losses: エポック毎の損失の値を格納するリスト(学習時)
  • accs: エポック毎のaccuracyの値を格納するリスト(学習時)
  • val_losses: エポック毎の損失の値を格納するリスト(検証時)
  • val_accs: エポック毎のaccuracyの値を格納するリスト(検証時)

学習

エポック数ループし、合計15回学習と検証を行っていきます。

  • running_loss: エポック毎の損失の値※ミニバッチのループ毎に計算が行われ、足し合わされるので最終的にミニバッチのループ数で割り、平均を求める。
  • running_acc: エポック毎のaccuracyの値※ミニバッチのループ毎に計算が行われ、足し合わされるので最終的にミニバッチのループ数で割り、平均を求める。

ミニバッチのループではtrain_dataloaderからミニバッチサイズ(32件)ずつデータが読み込まれ、損失の計算、重みの更新が行われます。
imgsには画像が、labelsには教師データが格納され、to(device)でGPUに送っています。
そして、optimizer.zero_grad()で毎回、勾配を初期化しています。

output = model(imgs)で予測を行い、loss = criterion(output, labels)で損失を求めています。loss.backward()では誤差逆伝播を行っています。
running_lossとrunning_accを求め、足し合わせています。optimizer.step()では重みの更新を行っています。

最後にミニバッチのループの数でrunning_lossとrunning_accを割り、1エポックでの損失とaccuracyを求めた上で、lossesとaccsに追加しています。

検証

検証の処理も基本的には学習の処理と同様です。検証時には重みの更新は不要ですので、勾配の初期化、誤差逆伝播、重みの更新を実施していない点が学習時との違いとなります。

各エポック学習と検証が完了したら、最後に学習時と検証時の損失の値とaccuracyを出力しています。

num_epocs = 15
losses = []
accs = []
val_losses = []
val_accs = []
for epoch in range(num_epocs):
    # 学習
    running_loss = 0.0
    running_acc = 0.0
    for imgs, labels in train_dataloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model(imgs)
        loss = criterion(output, labels)
        loss.backward()
        running_loss += loss.item()
        pred = torch.argmax(output, dim=1)
        running_acc += torch.mean(pred.eq(labels).float())
        optimizer.step()
    running_loss /= len(train_dataloader)
    running_acc /= len(train_dataloader)
    losses.append(running_loss)
    accs.append(running_acc)
    
    # 検証
    val_running_loss = 0.0
    val_running_acc = 0.0
    for val_imgs, val_labels in validation_dataloader:
        val_imgs = val_imgs.to(device)
        val_labels = val_labels.to(device)
        val_output = model(val_imgs)
        val_loss = criterion(val_output, val_labels)
        val_running_loss += val_loss.item()
        val_pred = torch.argmax(val_output, dim=1)
        val_running_acc += torch.mean(val_pred.eq(val_labels).float())
    val_running_loss /= len(validation_dataloader)
    val_running_acc /= len(validation_dataloader)
    val_losses.append(val_running_loss)
    val_accs.append(val_running_acc)
    print("epoch: {}, loss: {}, acc: {}    " \
    "val_epoch: {}, val_loss: {}, val_acc: {}".format(epoch, running_loss, running_acc, epoch, val_running_loss, val_running_acc))

エポック毎の損失とaccuracyをプロットして確認

最後にエポック毎に損失とaccuracyがどのように変化しているか確認しておきます。

plt.style.use('ggplot')
plt.plot(losses, label='train loss')
plt.plot(val_losses, label='validation loss')
plt.legend()

train lossを見るとエポックが進むごとにうまく収束していることが分かります。一方、validation lossは3エポックを過ぎたあたりからほぼ横ばいとなっています。

plt.style.use('ggplot')
plt.plot(accs, label='train acc')
plt.plot(val_accs, label='validation acc')
plt.legend()

accuracyを見てもlossと同じようにtrainではエポックが進む毎に上昇しているにもかかわらず、validationでは途中から横ばいとなっています。※上の画像のtrain loss, validation loss はそれぞれtrain acc, validation accの誤りです。

lossとaccの変化から今回の学習では過学習を起こしていることが分かります。つまり、学習データに対してはうまくフィットできているが、未知のデータに対してはうまく予測できない状態となっています。

そこで次からはこの過学習を抑える方法としてデータ拡張(Data Augmentation)をご紹介します。

PyTorchにおけるデータ拡張(Data Augmentation)の実現方法

データ拡張について

複雑なモデルで学習を行うためには十分なデータ量が必要です。データ拡張(Data Augmentation)は既存のデータを水増しするテクニックです。学習させるデータの量を増やすことで上記で発生していた過学習を防げる可能性があります。

画像のデータ拡張には様々な方法があります。例えば以下のような手法が考えられます。

  • データを水平に反転する
  • 一定の角度回転する
  • 明度や色調を変更する
  • ズームイン、ズームアウトする
  • 背景色を変更する etc…

PyTorchでは、前処理に利用できる様々なデータ拡張の関数が用意されています。こちらを利用することで簡単にデータ拡張が実現できます。今回はデータをランダムに水平に反転するRandomHorizontalFlip、ランダムに色調を変更するColorJitter、ランダムに画像を回転するRandomRotationを利用します。

ここで一点注意点ですが、PyTorchにおけるデータ拡張では物理的に10,000件のデータが20,000件に増えることはありません。エポック毎にランダムにデータ拡張を適用したり、しなかったりすることで、実質的にデータ拡張が実現されることになります。

データ拡張の実装

ここからはデータ拡張の実装を行います。データ拡張の実装といってもPyTorchでは簡単に実現できます。以下の例では上記で実装したプログラムの一部を修正していきたいと思います。

transformを定義している箇所を変更して、学習時のみデータ拡張を適用するようにします。

変更前

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5,))
])

変更後

# 前処理(検証用)
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5,))
])

# 前処理(学習用)
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
    transforms.RandomRotation(10),                                
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5,))
])

続いて、データセットを読み込むんでいる箇所を変更します。各transformに上で定義したものを利用するように変更します。

変更前

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
validation_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

変更後

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
validation_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transform)

これだけでデータ拡張を行う準備は完了です。再度、学習と検証を実施して、lossとaccの変化を確認してみましょう。

エポック毎の損失とaccuracyをプロットして確認

最後にエポック毎に損失とaccuracyがどのように変化しているか確認しておきます。

plt.style.use('ggplot')
plt.plot(losses, label='train loss')
plt.plot(val_losses, label='validation loss')
plt.legend()
plt.style.use('ggplot')
plt.plot(accs, label='train acc')
plt.plot(val_accs, label='validation acc')
plt.legend()

今回はエポックが進むにつれ、train loss, validation lossともにうまく収束していることが分かります。またaccも双方順調に上昇しています。このことから、データ拡張を行ったことで、過学習が抑えられることが確認できました。

これでCNNの実装は完了です。

まとめ

オススメの学習サービス

機械学習で用いる数学の知識を基礎から学びたい初心者の方にはUdemyの以下の動画がオススメです。

興味がある方は是非こちらもチェックしてください。

Udemy – 【キカガク流】人工知能・機械学習 脱ブラックボックス講座 – 初級編 –

Udemy – 【キカガク流】人工知能・機械学習 脱ブラックボックス講座 – 中級編 –

今回は、CNNについてご紹介しました。本記事を通して、CNNや畳み込み、Max Poolingがどのようなものかなんとなく理解いただけたのではないでしょうか。また、PyTorchを利用して、実際にCNNを実装する方法とデータ拡張を行う方法についてみていきました。CNNは深層学習で画像を扱う場合によく利用されますので、興味のある方は本記事を参考に実際にプログラミングを行っていただけると幸いです。

人気急上昇PyTorchで学ぶディープラーニング入門

2021年5月1日

文系でも分かるディープラーニングのための数学

2021年5月1日
AI学習サービス

AI・機械学習が学べるオススメのサービス3選

2021年7月22日