FC2カウンター FPGAの部屋 カーブ、直線用白線間走行用畳み込みニューラルネットワーク8(学習)
FC2ブログ

FPGAやCPLDの話題やFPGA用のツールの話題などです。 マニアックです。 日記も書きます。

FPGAの部屋

FPGAの部屋の有用と思われるコンテンツのまとめサイトを作りました。Xilinx ISEの初心者の方には、FPGAリテラシーおよびチュートリアルのページをお勧めいたします。

カーブ、直線用白線間走行用畳み込みニューラルネットワーク8(学習)

カーブ、直線用白線間走行用畳み込みニューラルネットワーク7(テスト用データの作成)”の続き。

前回は、”カーブ、直線用白線間走行用畳み込みニューラルネットワーク2(本格的なデータ収集)”で収集したテスト用の画像を 1062 枚に増やした。そして、それを 白黒変換し、白線の部分だけを 56 ピクセル X 10 行の画像を一枚の画像から 25 個切り出して、MNISTデータ形式のテスト用画像ファイルとテスト用ラベル・ファイルを作成した。今回はこれで、トレーニング用とテスト用のデータセットがそろったので、Jupyter Notebook を使用して学習してみよう。

なお、”ゼロから作るDeep Learning ――Pythonで学ぶディープラーニングの理論と実装”のPython コードを全面的に使わせて頂いている。畳み込みニューラルネットワークの層構成は、畳み込み層 - ReLU - プーリング層 - 全結合層 - ReLU - 全結合層 - SoftMax だった。

まずは、”白線追従走行用畳み込みニューラルネットワークの製作6(学習1)”を参考にして、dataset_straight をコピーしてdataset_curve フォルダを作成した。更に、ch7_straight をコピーして、ch7_curve フォルダを作成した。
curve_tracing_cnn_39_171218.png

dataset_curve フォルダの直線走行用のMNISTデータセットを削除して、”カーブ、直線用白線間走行用畳み込みニューラルネットワーク5(トレーニング用データの生成)”で作成した train_curve_run_image と train_curve_run_label をコピーした。
また、”カーブ、直線用白線間走行用畳み込みニューラルネットワーク7(テスト用データの作成)”で作成した test_curve_run_image と test_curve_run_label をコピーした。
curve_tracing_cnn_40_171218.png

白線追従走行用畳み込みニューラルネットワークの製作6(学習1)”の straight_dataset.py を curve_dataset.py に変更し、train_num と test_num を書き換えた。
curve_tracing_cnn_41_171218.png

次に、Jupyter Notebook を立ち上げて、deep-learning-from-scratch/ch07_curve に移動して、DLFS_Chap7_integer.ipynb を起動した。
curve_tracing_cnn_42_171218.png

DLFS_Chap7_integer.ipynb を編集して、

from dataset_curve.curve_dataset import load_mnist

に変更した。
curve_tracing_cnn_43_171218.png

Jupyter Notebook 上の train_convnet.py を貼っておく。畳み込み層のフィルタ数は 2 となっている。

# train_convnet.py
# 2017/08/08 白線追従走行用CNNに変更 by marsee
# 元になったコードは、https://github.com/oreilly-japan/deep-learning-from-scratch にあります。
# 改変したコードもMITライセンスとします。 2017/08/08 by marsee
# カーブ・データ用に修正 2017/12/18 by marsee

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
import matplotlib.pyplot as plt
from dataset_curve.curve_dataset import load_mnist
from simple_convnet import SimpleConvNet
from common.trainer import Trainer

# データの読み込み
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=False)

all_x_train = x_train.shape[0];
all_x_test = x_test.shape[0];
print(all_x_train)
print(all_x_test)

# 処理に時間のかかる場合はデータを削減 
#x_train, t_train = x_train[:5000], t_train[:5000]
#x_test, t_test = x_test[:1000], t_test[:1000]

max_epochs = 20

network = SimpleConvNet(input_dim=(1,10,56), 
                        conv_param = {'filter_num': 2, 'filter_size': 5, 'pad': 0, 'stride': 1},
                        #conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
                        hidden_size=100, output_size=3, weight_init_std=0.01)
                        
trainer = Trainer(network, x_train, t_train, x_test, t_test,
                  epochs=max_epochs, mini_batch_size=100,
                  optimizer='Adam', optimizer_param={'lr': 0.001},
                  evaluate_sample_num_per_epoch=100)
trainer.train()

# パラメータの保存
network.save_params("params.pkl")
print("Saved Network Parameters!")


これで、Jupyter Notebook 上の train_convnet.py で学習したところ、0.95472693032 の精度が確保できた。

train_convnet.py を実行した結果の最初と最後を貼っておく。

Converting train_curve_run_image to NumPy Array ...
Done
Converting train_curve_run_label to NumPy Array ...
Done
Converting test_curve_run_image to NumPy Array ...
Done
Converting test_curve_run_label to NumPy Array ...
Done
Creating pickle file ...
Done!
34650
26550
pool_output_size =156
train loss:1.0981899475
=== epoch:1, train acc:1.0, test acc:1.0 ===
train loss:1.0980374466
train loss:1.0966096818
train loss:1.09627088932
train loss:1.09695438209


train loss:0.12427020397
train loss:0.0725574218453
train loss:0.0593783238407
train loss:0.0805719249781
train loss:0.0529893267479
train loss:0.0247235638705
train loss:0.0481303088163
train loss:0.0534950709832
train loss:0.10404449012
train loss:0.082864550859
train loss:0.091640690661
train loss:0.102096129491
train loss:0.0732472519834
train loss:0.068286296823
train loss:0.0726208094243
train loss:0.0522118761663
=============== Final Test Accuracy ===============
test acc:0.95472693032
Saved Network Parameters!

  1. 2017年12月18日 04:43 |
  2. DNN
  3. | トラックバック:0
  4. | コメント:0

コメント

コメントの投稿


管理者にだけ表示を許可する

トラックバック URL
http://marsee101.blog.fc2.com/tb.php/4007-9178e6d1
この記事にトラックバックする(FC2ブログユーザー)