FC2カウンター FPGAの部屋 keras_compressor のモデルをVivado HLSで実装する1(model_raw.h5)
FC2ブログ

FPGAやCPLDの話題やFPGA用のツールの話題などです。 マニアックです。 日記も書きます。

FPGAの部屋

FPGAの部屋の有用と思われるコンテンツのまとめサイトを作りました。Xilinx ISEの初心者の方には、FPGAリテラシーおよびチュートリアルのページをお勧めいたします。

keras_compressor のモデルをVivado HLSで実装する1(model_raw.h5)

keras_compressor を試してみる3”までで作成できた畳み込みニューラルネットワークのモデルをVivado HLS のこれまでのスキームで実装してみようと思う。

ホストUbuntu 18.04 に 3 つのモデルをコピーして、まずは大きいので実装はできないと思う圧縮前のモデルのパラメータを見てみよう。

まずは、Docker コンテナからホストPC へモデルをコピーする。
docker ps
でコンテナID を確認して、3 つのモデルを docker cp コマンドでコピーする。
docker cp 7552a77f8bdb:/srv/keras_compressor/example/mnist/model_raw.h5 .
docker cp 7552a77f8bdb:/srv/keras_compressor/example/mnist/model_compressed.h5 .
docker cp 7552a77f8bdb:/srv/keras_compressor/example/mnist/model_finetuned.h5 .

keras_compressor_24_190302.png

ホストPC の ~/Docker/keras_compressor/ ディレクトリの下に keras_compressor を git clone した。
git clone https://github.com/nico-opendata/keras_compressor

ホストPC でJupyter Notebook を起動して、Python3 の新しいノートを作成した(model_check_raw.ipynb)。
model.raw5.h5 の情報を見ていこう。
まずは、keras_compressor/example/mnist/train.py から def model_def を引用する。

# DwangoMediaVillage/keras_compressor/example/mnist/train.py から一部引用
# https://github.com/DwangoMediaVillage/keras_compressor/blob/master/example/mnist/train.py

from keras import backend as K
from keras.callbacks import EarlyStopping
from keras.datasets import mnist
from keras.layers import Conv2D, Dense, Dropout, Flatten, Input, MaxPool2D
from keras.models import Model
from keras.utils.np_utils import to_categorical

def gen_model():
    # from keras mnist tutorial
    img_input = Input(shape=(img_rows, img_cols, 1))

    h = img_input
    h = Conv2D(32, (3, 3), activation='relu')(h)
    h = Dropout(0.25)(h)
    h = Conv2D(64, (3, 3), activation='relu')(h)
    h = MaxPool2D((2, 2))(h)
    h = Dropout(0.25)(h)

    h = Flatten()(h)

    h = Dense(128, activation='relu')(h)
    h = Dropout(0.5)(h)
    h = Dense(class_num, activation='softmax')(h)

    model = Model(img_input, h)
    return model


model_raw.h5 を読み込む。

# 学習済みモデルの読み込み

from keras.models import load_model
import sys, os
sys.path.append("keras_compressor/keras_compressor")
from layers import custom_layers

model_compressed = load_model('model_compressed.h5', custom_layers)


重みとバイアスの再ロード用モジュールを実行する。

# My Mnist CNN(重みとバイアスの再ロード用モジュール)
# Conv2D - ReLU - MaxPooling - Dence - ReLU - Dence
# 2018/05/25 by marsee
# Keras / Tensorflowで始めるディープラーニング入門 https://qiita.com/yampy/items/706d44417c433e68db0d
# のPythonコードを再利用させて頂いている

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12

img_rows, img_cols = 28, 28

(x_train, y_train), (x_test, y_test) = mnist.load_data()

#Kerasのバックエンドで動くTensorFlowとTheanoでは入力チャンネルの順番が違うので場合分けして書いています
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 
x_test /= 


重みを表示する。

model_raw_list = model_raw.get_weights()
print(model_raw_list)


結果を示す。

[array([[[[-0.05172705, -0.15345924,  0.00077587, -0.03371203,
          -0.16821936,  0.12922178, -0.13418826, -0.10867919,
          -0.1458925 ,  0.19463108, -0.14580294, -0.03869295,
           0.10389572,  0.17042054,  0.0100111 ,  0.12097651,
          -0.4023658 , -0.20207658, -0.07195384,  0.03934375,
          -0.11012892, -0.07478225,  0.1226447 ,  0.19403256,
           0.01812533, -0.05677873,  0.17626294, -0.03892957,
          -0.00537937, -0.05291621, -0.16475417, -0.00706133]],

        [[ 0.13856724, -0.15383671, -0.08084076, -0.13828416,
          -0.07712211,  0.04172476, -0.16041471, -0.08714657,
          -0.17163298, -0.03662995, -0.0620219 ,  0.10431749,
          -0.12399048, -0.00673226, -0.03650982, -0.0106194 ,
          -0.18041939, -0.15767479, -0.11838219,  0.10532779,
           0.03665911,  0.02992007,  0.00675609,  0.03340556,
           0.05760363,  0.0406777 , -0.05025062,  0.03133169,
           0.09469289,  0.19778591,  0.04521935,  0.11820332]],

        [[-0.01835519,  0.10497978, -0.08119331, -0.04822454,
           0.1015288 , -0.05404653, -0.20264348,  0.01448757,
           0.05472392,  0.18080442,  0.08911949,  0.15507902,
          -0.14813055,  0.03455162, -0.16724327, -0.13468543,
          -0.08864182, -0.02936704,  0.01190247,  0.11374157,
           0.13765548, -0.06801588, -0.16707923, -0.05098806,
           0.16587973, -0.1300925 ,  0.07113007, -0.17776224,
           0.14063434,  0.18423977, -0.14625272,  0.17271286]]],


       [[[-0.05083156, -0.13264939,  0.04466102,  0.03564763,
           0.03559402,  0.13187517,  0.03594281,  0.05526815,
          -0.11275288,  0.0039167 , -0.07657869, -0.03868581,
           0.12851602,  0.1521333 ,  0.0513791 , -0.12931691,
          -0.09383278, -0.1648976 ,  0.07744151, -0.07568712,
          -0.05860967, -0.05193807,  0.12811543, -0.0282612 ,
           0.07773264,  0.06226772,  0.02389823,  0.00214993,
          -0.12375973, -0.00676751, -0.02455793, -0.03462191]],

        [[-0.03594301,  0.11407857,  0.14229886,  0.04625476,
           0.05015176, -0.05551488, -0.02045105, -0.03296088,
           0.09474239, -0.06525583, -0.01278242,  0.02829993,
          -0.07840222,  0.04641729, -0.03488041,  0.05656511,
           0.13930908, -0.03737781, -0.09818485, -0.09107177,
           0.07095284,  0.06994601,  0.06080497, -0.13724516,
          -0.03498914, -0.11904315,  0.14269935, -0.04210183,
           0.0651255 ,  0.10612902,  0.11397902, -0.03795255]],

        [[ 0.17533973,  0.05431411, -0.08569562,  0.10862813,
          -0.06807359, -0.17448065, -0.0100623 ,  0.10232493,
           0.0435214 ,  0.01783775,  0.02181555,  0.12274784,
          -0.05855373, -0.05113681,  0.03346287,  0.1236548 ,
           0.07271655,  0.11179519,  0.01439611, -0.011888  ,
          -0.02642842, -0.1024743 ,  0.00944871, -0.01964488,
          -0.02744397, -0.11693428,  0.05274292, -0.03606737,
           0.09729386, -0.12262847, -0.07408664, -0.07936568]]],


       [[[-0.2106485 , -0.07639346, -0.10960838,  0.1468544 ,
           0.02582365,  0.10072139,  0.12636843,  0.10368153,
          -0.01542548,  0.00699573,  0.04627146, -0.01649213,
           0.13230959, -0.08335535,  0.05959431, -0.11145467,
          -0.07102275, -0.1388285 ,  0.05205231, -0.09978414,
          -0.03917275,  0.10191721, -0.10038435, -0.12760422,
          -0.2619083 ,  0.10844929, -0.13908163,  0.03273428,
           0.02183258,  0.00879611,  0.27741808, -0.1739349 ]],

        [[-0.19153525, -0.07415977,  0.1546851 , -0.0031653 ,
           0.07678907,  0.05759468,  0.0014401 ,  0.04860839,
           0.1039779 , -0.21813264,  0.07777672, -0.1257387 ,
           0.05847128, -0.09520132,  0.02177527,  0.00469245,
           0.06917801, -0.04790248,  0.11820323,  0.06097089,
           0.07881823,  0.14582054, -0.04223583, -0.17602372,
          -0.11885098,  0.0220329 , -0.006604  ,  0.04856543,
          -0.02801645, -0.01229255,  0.02585949, -0.18842506]],

        [[-0.07175326,  0.12322589,  0.00127919,  0.00317867,
          -0.09002248, -0.1933257 ,  0.1461138 , -0.01340641,
           0.09414906, -0.1098635 ,  0.05602227, -0.08832221,
          -0.00852111,  0.11656286,  0.15776785,  0.15097694,
           0.20661522,  0.02434064, -0.03323808,  0.0474364 ,
          -0.07273006,  0.03282277,  0.05005337,  0.03691453,
          -0.24776074,  0.13606238, -0.19123326,  0.13436365,
          -0.02396554, -0.22014052, -0.09753748, -0.02009463]]]],
      dtype=float32), array([ 0.03587808, -0.0171744 ,  0.00955198,  0.01177248,  0.00439308,
        0.0267545 ,  0.06779006, -0.02031287,  0.00851296,  0.08016433,
       -0.00722928,  0.00845261,  0.04088434, -0.00948897, -0.03233781,
       -0.04233009,  0.02995742,  0.03874496,  0.01290099, -0.0231648 ,
       -0.03201413, -0.01488906,  0.01443897,  0.02196893,  0.07889408,
        0.02526813,  0.01673093, -0.01149702, -0.04930302, -0.01055223,
        0.05971717,  0.07986403], dtype=float32), array([[[[ 0.01784169,  0.07252679, -0.01097492, ...,  0.0018453 ,
          -0.14215174,  0.05041484],
         [ 0.00849118,  0.12599792, -0.04117574, ...,  0.02723508,
          -0.10625021,  0.06377354],
         [ 0.1441614 ,  0.08159676,  0.10669465, ..., -0.07041598,
          -0.13106433, -0.08898731],
         ...,
         [ 0.05605402, -0.03201892, -0.06403774, ...,  0.08875347,
          -0.34901655,  0.01745879],
         [ 0.07731646, -0.16763942, -0.00132929, ...,  0.06760995,
          -0.11536116, -0.03162678],
         [-0.21285664, -0.10200371, -0.07402381, ..., -0.06327637,
           0.10867406,  0.16707096]],

        [[-0.2639166 , -0.03229538,  0.02680427, ..., -0.06278129,
          -0.0928833 , -0.0907499 ],
         [-0.13494554,  0.15743831, -0.07642376, ...,  0.1670157 ,
          -0.09430328,  0.05296668],
         [ 0.0333319 ,  0.03267855,  0.05561309, ...,  0.0517549 ,
          -0.06607336,  0.02380313],
         ...,
         [-0.08460605, -0.00817327, -0.05458642, ..., -0.0304523 ,
          -0.2983937 ,  0.0888538 ],
         [ 0.01115145, -0.12597817,  0.06745421, ..., -0.05120533,
          -0.03551681,  0.0256981 ],
         [-0.0872641 , -0.02846055, -0.00096143, ..., -0.00637156,
           0.05518444, -0.05756345]],

        [[-0.22022615,  0.02097548, -0.11983979, ..., -0.11052338,
          -0.09645115, -0.12325586],
         [ 0.0753931 ,  0.0689197 ,  0.07239848, ...,  0.01090579,
          -0.3111424 , -0.07718782],
         [-0.07048219, -0.03452173, -0.02032233, ..., -0.01574312,
           0.12788951, -0.07588881],
         ...,
         [-0.09354184,  0.17358212, -0.14240177, ...,  0.02514045,
          -0.09468765, -0.01826777],
         [-0.08374058,  0.04152744, -0.00278586, ...,  0.1042031 ,
          -0.02293781,  0.0224688 ],
         [-0.11540057,  0.09110897, -0.00968612, ..., -0.01108299,
           0.1154867 , -0.00337872]]],


       [[[-0.22131965,  0.06954905,  0.02248521, ..., -0.13253921,
          -0.05250412,  0.085272  ],
         [-0.04595165,  0.10061093,  0.04049781, ..., -0.05506902,
           0.06715798,  0.04935758],
         [ 0.00560335, -0.13551575,  0.11727943, ...,  0.04368603,
          -0.07839978, -0.07421607],
         ...,
         [ 0.09735671, -0.00297785, -0.01820614, ...,  0.05630978,
          -0.10141616, -0.05311462],
         [ 0.10309629, -0.3161246 , -0.01552058, ..., -0.05780705,
          -0.19297424, -0.1756806 ],
         [-0.07637326,  0.07960378,  0.12787989, ..., -0.07312059,
           0.00733983,  0.000507  ]],

        [[-0.24390881,  0.03017478, -0.08301172, ..., -0.07392175,
          -0.12365779,  0.0033332 ],
         [-0.09552263,  0.05114197,  0.10412158, ...,  0.06717188,
          -0.00373708,  0.02406436],
         [ 0.05042576,  0.06596102,  0.03903994, ..., -0.0583111 ,
           0.08278285,  0.03121348],
         ...,
         [ 0.03345596,  0.09832835, -0.02513286, ...,  0.04554279,
          -0.15410991, -0.04104587],
         [ 0.05118542, -0.3417036 ,  0.02022707, ..., -0.00925362,
           0.03248382,  0.08202723],
         [-0.09507918, -0.08170375,  0.02168931, ..., -0.10024032,
           0.02402998, -0.21379665]],

        [[-0.13424525, -0.02660633, -0.1192525 , ..., -0.23914535,
          -0.28084764, -0.05949372],
         [-0.19278674, -0.13323346,  0.07695939, ...,  0.00986434,
          -0.12384682, -0.13519521],
         [ 0.02250148, -0.13858531, -0.04828302, ...,  0.01930327,
           0.16163033, -0.11828827],
         ...,
         [-0.11395966,  0.0870876 , -0.10622517, ...,  0.04610986,
          -0.1493755 , -0.27299204],
         [-0.10000117, -0.12624545, -0.01311766, ..., -0.02880889,
           0.01535376, -0.08321954],
         [-0.08176904,  0.03700143,  0.03589568, ..., -0.11301193,
           0.06526493, -0.14029837]]],


       [[[-0.16750205,  0.12537089,  0.16480623, ..., -0.12945914,
          -0.0090812 ,  0.08250518],
         [ 0.0551733 , -0.14380467, -0.13980682, ..., -0.04423543,
           0.13000865,  0.04397005],
         [-0.03908797, -0.3436229 , -0.08763045, ..., -0.0439879 ,
          -0.18065412, -0.0766596 ],
         ...,
         [ 0.08520151,  0.0185745 ,  0.00787074, ..., -0.00529092,
          -0.02693314, -0.07193764],
         [-0.05285215, -0.14865172, -0.0673898 , ...,  0.0778336 ,
           0.06110919, -0.03228816],
         [-0.11807112,  0.05031039,  0.09319225, ..., -0.07421203,
          -0.02937514,  0.04078325]],

        [[-0.1349165 , -0.04630463,  0.0185292 , ..., -0.14944914,
           0.21822254,  0.09291119],
         [-0.17330258, -0.2771658 ,  0.03532882, ...,  0.11392233,
           0.08581509,  0.05589819],
         [-0.08194944, -0.3013157 , -0.11831472, ..., -0.03284019,
          -0.02248873,  0.04296595],
         ...,
         [-0.03969822,  0.06953567, -0.03289285, ..., -0.05971542,
           0.07247194, -0.03298815],
         [-0.03705103, -0.15690652, -0.1390137 , ...,  0.1131255 ,
           0.02823978,  0.01641914],
         [-0.02423166,  0.12832679,  0.03331943, ..., -0.2086502 ,
           0.02671161,  0.00349825]],

        [[ 0.03606407,  0.00960674, -0.06802542, ..., -0.22029497,
           0.16772458,  0.03192182],
         [-0.11913268, -0.2662513 , -0.06664034, ..., -0.12726213,
          -0.04418406, -0.20206876],
         [-0.0651824 , -0.16072227,  0.04822908, ...,  0.09448915,
           0.0536258 ,  0.09910108],
         ...,
         [-0.2386213 , -0.02251608, -0.14554805, ...,  0.04479846,
          -0.01842348, -0.1160126 ],
         [-0.13344261, -0.11262694, -0.08565197, ...,  0.16454437,
          -0.10231322, -0.11221529],
         [-0.06490589, -0.01683243,  0.02383618, ..., -0.11094074,
           0.07580288, -0.04840912]]]], dtype=float32), array([-0.03850411, -0.01573794,  0.01115942,  0.00562205, -0.03604732,
        0.01286564, -0.09528935, -0.04337158, -0.01698404, -0.05879769,
       -0.03624293, -0.08855402, -0.04848624, -0.03195901, -0.08869974,
        0.00507461, -0.07211261, -0.05120531, -0.07258753, -0.02589275,
        0.00240357, -0.03014836, -0.07204135, -0.01561245, -0.03739467,
       -0.04205899, -0.06722149,  0.00948338, -0.03016038, -0.02061644,
       -0.05402645,  0.14428563, -0.05207796, -0.05976514, -0.04870901,
        0.00583873, -0.00821861, -0.0367915 , -0.04091152, -0.0452308 ,
        0.01322065, -0.0670034 , -0.03599038, -0.06793729, -0.07481442,
       -0.06112335, -0.00378744, -0.06627614, -0.07054048,  0.02497522,
       -0.03339541, -0.02779962, -0.06606241, -0.0774965 , -0.02701756,
       -0.02294137, -0.06889407, -0.03851727, -0.01658753,  0.01253654,
       -0.06667125, -0.0335856 ,  0.00278443, -0.04023726], dtype=float32), array([[ 0.00470384, -0.01098892,  0.0251282 , ..., -0.02190203,
         0.01992498, -0.00856046],
       [-0.00689279,  0.00944513, -0.01233803, ...,  0.02524373,
        -0.02891207, -0.00187604],
       [-0.00194416,  0.03971022,  0.03803337, ...,  0.01431566,
         0.00224361,  0.02952734],
       ...,
       [ 0.03151323, -0.00314217, -0.00652592, ..., -0.02395282,
         0.0161369 ,  0.02760857],
       [-0.04572763,  0.05457359,  0.00654751, ..., -0.00405126,
        -0.00480561, -0.05019789],
       [ 0.01498569,  0.0050351 ,  0.04433047, ..., -0.00691412,
         0.03109068, -0.01302621]], dtype=float32), array([-0.06168257, -0.1514682 , -0.1322928 , -0.10744902, -0.08265022,
       -0.05077443, -0.06567107, -0.09612435, -0.16938643, -0.08640631,
       -0.17030963, -0.0943355 , -0.00991632, -0.08720349, -0.12731756,
       -0.07198199, -0.11136244, -0.08368985, -0.13953699, -0.131075  ,
       -0.08466823, -0.00631496, -0.17218405, -0.06344838, -0.17180654,
       -0.14032175, -0.07136092, -0.09621119, -0.00660794, -0.01606259,
       -0.09382402, -0.04687085, -0.10574532, -0.16411512, -0.01561454,
       -0.06328947, -0.15161262, -0.08436941, -0.10497447, -0.12699552,
       -0.12868118, -0.09164497, -0.11735129, -0.17970178, -0.11265838,
       -0.076789  , -0.11856248, -0.06805629, -0.11970215, -0.10988984,
       -0.01188129, -0.11470675, -0.11072908, -0.0443239 , -0.13116506,
       -0.13980559, -0.11000474,  0.00931341, -0.06132468, -0.12936714,
       -0.11785036, -0.14282347, -0.11910468, -0.16563804, -0.1154892 ,
       -0.06983681, -0.09040678, -0.11911937, -0.13075021, -0.07354662,
       -0.17754656, -0.03138041, -0.10065349, -0.10409548, -0.08683839,
       -0.09837727, -0.04784932, -0.00906859, -0.01060055, -0.10975196,
       -0.11331895, -0.14100014, -0.13861252, -0.09429768, -0.1434664 ,
       -0.07264388, -0.03577153, -0.00947432, -0.05689292, -0.17511971,
       -0.00466694, -0.14445004, -0.1057132 , -0.0294685 , -0.14415818,
       -0.09408505, -0.05248842, -0.07065452, -0.12279639, -0.04660085,
       -0.15480113, -0.17639175, -0.0758699 , -0.10553791, -0.05692659,
       -0.08024313, -0.1117903 , -0.11112501, -0.04388554, -0.08886621,
       -0.15003382, -0.09736416, -0.12964201, -0.06801744, -0.08971801,
       -0.11670808, -0.12153988, -0.00981606, -0.09409956, -0.10864237,
       -0.13626133, -0.05466016, -0.11123326, -0.07980989, -0.16771536,
       -0.10403304, -0.1670245 , -0.21162453], dtype=float32), array([[-0.39988974, -0.12601139, -0.46622646, ..., -0.44695276,
        -0.46335962,  0.13559648],
       [ 0.07339579, -0.45668814, -0.333825  , ..., -0.7270869 ,
         0.17933416, -0.5707154 ],
       [ 0.06677974, -0.5544925 ,  0.1668607 , ..., -0.46047652,
        -0.4375593 , -0.05378945],
       ...,
       [-0.05012963, -0.112903  , -0.1406179 , ...,  0.16494776,
         0.20007576, -0.21182767],
       [-0.5350573 , -0.40777698,  0.19234389, ...,  0.19444552,
        -0.33254635, -0.54401803],
       [-0.52111495, -0.26989737,  0.20298809, ...,  0.08232445,
         0.14694539, -0.36222363]], dtype=float32), array([ 0.04120593,  0.23183034,  0.00655218, -0.07728817, -0.09058365,
       -0.0624351 , -0.14979938, -0.04618856, -0.06826726, -0.0130224 ],
      dtype=float32)]


model_raw.summary()


結果を示す。

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 26, 26, 32)        320       
_________________________________________________________________
dropout_1 (Dropout)          (None, 26, 26, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 24, 24, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 12, 12, 64)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 12, 12, 64)        0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 9216)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               1179776   
_________________________________________________________________
dropout_3 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1290      
=================================================================
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
_________________________________________________________________


model.summary() で取得した各層の情報を元に畳み込み層の中間出力を取り出す。

# Convolution layerの中間出力を取り出す 
from keras.models import Model

conv_layer_name = 'conv2d_1'

conv_layer = model_raw.get_layer(conv_layer_name)
conv_layer_wb = conv_layer.get_weights()

conv_layer_model = Model(inputs=model_raw.input,
                                 outputs=model_raw.get_layer(conv_layer_name).output)
conv_output = conv_layer_model.predict(x_test, verbose=1)


結果を示す。

10000/10000 [==============================] - 2s 227us/step


conv1 の重みやバイアスの配列の構成を見てみよう。

conv_layer_weight = conv_layer_wb[0]
conv_layer_bias = conv_layer_wb[1]

print(conv_layer_weight.shape)
print(conv_layer_weight.T.shape)
print(conv_layer_bias.shape)


結果を示す。

(3, 3, 1, 32)
(32, 1, 3, 3)
(32,)


重みの値を示す。

print("conv_layer_weight.T = {0}".format(conv_layer_weight.T))


結果を示す。

conv_layer_weight.T = [[[[-0.05172705 -0.05083156 -0.2106485 ]
   [ 0.13856724 -0.03594301 -0.19153525]
   [-0.01835519  0.17533973 -0.07175326]]]


 [[[-0.15345924 -0.13264939 -0.07639346]
   [-0.15383671  0.11407857 -0.07415977]
   [ 0.10497978  0.05431411  0.12322589]]]


 [[[ 0.00077587  0.04466102 -0.10960838]
   [-0.08084076  0.14229886  0.1546851 ]
   [-0.08119331 -0.08569562  0.00127919]]]


 [[[-0.03371203  0.03564763  0.1468544 ]
   [-0.13828416  0.04625476 -0.0031653 ]
   [-0.04822454  0.10862813  0.00317867]]]


 [[[-0.16821936  0.03559402  0.02582365]
   [-0.07712211  0.05015176  0.07678907]
   [ 0.1015288  -0.06807359 -0.09002248]]]


 [[[ 0.12922178  0.13187517  0.10072139]
   [ 0.04172476 -0.05551488  0.05759468]
   [-0.05404653 -0.17448065 -0.1933257 ]]]


 [[[-0.13418826  0.03594281  0.12636843]
   [-0.16041471 -0.02045105  0.0014401 ]
   [-0.20264348 -0.0100623   0.1461138 ]]]


 [[[-0.10867919  0.05526815  0.10368153]
   [-0.08714657 -0.03296088  0.04860839]
   [ 0.01448757  0.10232493 -0.01340641]]]


 [[[-0.1458925  -0.11275288 -0.01542548]
   [-0.17163298  0.09474239  0.1039779 ]
   [ 0.05472392  0.0435214   0.09414906]]]


 [[[ 0.19463108  0.0039167   0.00699573]
   [-0.03662995 -0.06525583 -0.21813264]
   [ 0.18080442  0.01783775 -0.1098635 ]]]


 [[[-0.14580294 -0.07657869  0.04627146]
   [-0.0620219  -0.01278242  0.07777672]
   [ 0.08911949  0.02181555  0.05602227]]]


 [[[-0.03869295 -0.03868581 -0.01649213]
   [ 0.10431749  0.02829993 -0.1257387 ]
   [ 0.15507902  0.12274784 -0.08832221]]]


 [[[ 0.10389572  0.12851602  0.13230959]
   [-0.12399048 -0.07840222  0.05847128]
   [-0.14813055 -0.05855373 -0.00852111]]]


 [[[ 0.17042054  0.1521333  -0.08335535]
   [-0.00673226  0.04641729 -0.09520132]
   [ 0.03455162 -0.05113681  0.11656286]]]


 [[[ 0.0100111   0.0513791   0.05959431]
   [-0.03650982 -0.03488041  0.02177527]
   [-0.16724327  0.03346287  0.15776785]]]


 [[[ 0.12097651 -0.12931691 -0.11145467]
   [-0.0106194   0.05656511  0.00469245]
   [-0.13468543  0.1236548   0.15097694]]]


 [[[-0.4023658  -0.09383278 -0.07102275]
   [-0.18041939  0.13930908  0.06917801]
   [-0.08864182  0.07271655  0.20661522]]]


 [[[-0.20207658 -0.1648976  -0.1388285 ]
   [-0.15767479 -0.03737781 -0.04790248]
   [-0.02936704  0.11179519  0.02434064]]]


 [[[-0.07195384  0.07744151  0.05205231]
   [-0.11838219 -0.09818485  0.11820323]
   [ 0.01190247  0.01439611 -0.03323808]]]


 [[[ 0.03934375 -0.07568712 -0.09978414]
   [ 0.10532779 -0.09107177  0.06097089]
   [ 0.11374157 -0.011888    0.0474364 ]]]


 [[[-0.11012892 -0.05860967 -0.03917275]
   [ 0.03665911  0.07095284  0.07881823]
   [ 0.13765548 -0.02642842 -0.07273006]]]


 [[[-0.07478225 -0.05193807  0.10191721]
   [ 0.02992007  0.06994601  0.14582054]
   [-0.06801588 -0.1024743   0.03282277]]]


 [[[ 0.1226447   0.12811543 -0.10038435]
   [ 0.00675609  0.06080497 -0.04223583]
   [-0.16707923  0.00944871  0.05005337]]]


 [[[ 0.19403256 -0.0282612  -0.12760422]
   [ 0.03340556 -0.13724516 -0.17602372]
   [-0.05098806 -0.01964488  0.03691453]]]


 [[[ 0.01812533  0.07773264 -0.2619083 ]
   [ 0.05760363 -0.03498914 -0.11885098]
   [ 0.16587973 -0.02744397 -0.24776074]]]


 [[[-0.05677873  0.06226772  0.10844929]
   [ 0.0406777  -0.11904315  0.0220329 ]
   [-0.1300925  -0.11693428  0.13606238]]]


 [[[ 0.17626294  0.02389823 -0.13908163]
   [-0.05025062  0.14269935 -0.006604  ]
   [ 0.07113007  0.05274292 -0.19123326]]]


 [[[-0.03892957  0.00214993  0.03273428]
   [ 0.03133169 -0.04210183  0.04856543]
   [-0.17776224 -0.03606737  0.13436365]]]


 [[[-0.00537937 -0.12375973  0.02183258]
   [ 0.09469289  0.0651255  -0.02801645]
   [ 0.14063434  0.09729386 -0.02396554]]]


 [[[-0.05291621 -0.00676751  0.00879611]
   [ 0.19778591  0.10612902 -0.01229255]
   [ 0.18423977 -0.12262847 -0.22014052]]]


 [[[-0.16475417 -0.02455793  0.27741808]
   [ 0.04521935  0.11397902  0.02585949]
   [-0.14625272 -0.07408664 -0.09753748]]]


 [[[-0.00706133 -0.03462191 -0.1739349 ]
   [ 0.11820332 -0.03795255 -0.18842506]
   [ 0.17271286 -0.07936568 -0.02009463]]]]


バイアスの値を示す。

print("conv_layer_bias = {0}".format(conv_layer_bias))


結果を示す。

conv_layer_bias = [ 0.03587808 -0.0171744   0.00955198  0.01177248  0.00439308  0.0267545
  0.06779006 -0.02031287  0.00851296  0.08016433 -0.00722928  0.00845261
  0.04088434 -0.00948897 -0.03233781 -0.04233009  0.02995742  0.03874496
  0.01290099 -0.0231648  -0.03201413 -0.01488906  0.01443897  0.02196893
  0.07889408  0.02526813  0.01673093 -0.01149702 -0.04930302 -0.01055223
  0.05971717  0.07986403]


統計情報を取得する。

import numpy as np

print("np.max(conv_layer_weight) = {0}".format(np.max(conv_layer_weight)))
print("np.min(conv_layer_weight) = {0}".format(np.min(conv_layer_weight)))
abs_conv_layer_weight = np.absolute(conv_layer_weight)
print("np.max(abs_conv_layer_weight) = {0}".format(np.max(abs_conv_layer_weight)))
print("np.min(abs_conv_layer_weight) = {0}".format(np.min(abs_conv_layer_weight)))

print("np.max(conv_layer_bias) = {0}".format(np.max(conv_layer_bias)))
print("np.min(conv_layer_bias) = {0}".format(np.min(conv_layer_bias)))
abs_conv_layer_bias = np.absolute(conv_layer_bias)
print("np.max(abs_conv_layer_bias) = {0}".format(np.max(abs_conv_layer_bias)))
print("np.min(abs_conv_layer_bias) = {0}".format(np.min(abs_conv_layer_bias)))

print("conv_output = {0}".format(conv_output.shape))
print("np.std(conv_output) = {0}".format(np.std(conv_output)))
print("np.max(conv_output) = {0}".format(np.max(conv_output)))
print("np.min(conv_output) = {0}".format(np.min(conv_output)))

abs_conv_output = np.absolute(conv_output)
print("np.max(abs_conv) = {0}".format(np.max(abs_conv_output)))
print("np.min(abs_conv) = {0}".format(np.min(abs_conv_output)))


結果を示す。

np.max(conv_layer_weight) = 0.2774180769920349
np.min(conv_layer_weight) = -0.4023658037185669
np.max(abs_conv_layer_weight) = 0.4023658037185669
np.min(abs_conv_layer_weight) = 0.0007758717983961105
np.max(conv_layer_bias) = 0.08016432821750641
np.min(conv_layer_bias) = -0.04930301755666733
np.max(abs_conv_layer_bias) = 0.08016432821750641
np.min(abs_conv_layer_bias) = 0.0043930779211223125
conv_output = (10000, 26, 26, 32)
np.std(conv_output) = 0.05693778768181801
np.max(conv_output) = 0.5177762508392334
np.min(conv_output) = 0.0
np.max(abs_conv) = 0.5177762508392334
np.min(abs_conv) = 0.0


畳み込み層 1 層目の重みのグラフ。

%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt

# Convolution layerのweightのグラフ
conv_layer_weight_f = conv_layer_weight.flatten()
plt.plot(conv_layer_weight_f)
plt.title('conv_layer_weight')
plt.show()


keras_compressor_32_190302.png

畳み込み層 1 層目のバイアスのグラフ。

# Convolution layerのbiasのグラフ
conv_layer_bias_f = conv_layer_bias.flatten()
plt.plot(conv_layer_bias_f)
plt.title('conv_layer_bias')
plt.show()


keras_compressor_33_190302.png

Jupyter Notebook の画像を貼っておく。
keras_compressor_25_190302.png
keras_compressor_26_190302.png
keras_compressor_27_190302.png
keras_compressor_28_190302.png
keras_compressor_29_190302.png
keras_compressor_30_190302.png
keras_compressor_31_190302.png
  1. 2019年03月02日 20:59 |
  2. DNN
  3. | トラックバック:0
  4. | コメント:0

コメント

コメントの投稿


管理者にだけ表示を許可する

トラックバック URL
http://marsee101.blog.fc2.com/tb.php/4459-c020d866
この記事にトラックバックする(FC2ブログユーザー)