nnablaのサンプルコード

タグ:

SONYのニューラルネットワークライブラリhttps://nnabla.org/ja/について、メモなんぞを。

学習内容は 0.0の時は0.9、0.1の時は0.8…といった感じでy = 0.9 – xをかえすだけの動作です。

サンプルなので、学習と評価データはプログラムに埋め込んであります。

import sys

import numpy as np

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
import nnabla.logger as L

BATCH_SIZE = 10
LIST_X = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
LIST_Y = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]


def build(in_x, in_y):
    """ニューラルネットワークの作成
    """

    with nn.parameter_scope("affine1"):
        h = PF.affine(in_x, 2)
        h = F.relu(h)

    with nn.parameter_scope("affine2"):
        h = PF.affine(h, 2)
        h = F.relu(h)

    with nn.parameter_scope("affine3"):
        h = PF.affine(h, 1)
        f = F.relu(h)

    return f


def train(f, in_x, in_y, epoch=1000, modelfile="model.h5"):
    """作成したネットワークfをepoch回学習して、学習結果をmodelfile名で保存
    """

    h = F.squared_error(f, in_y)

    loss = F.mean(h)

    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    for n in range(epoch):

        in_x.d = np.reshape(np.array(LIST_X), (BATCH_SIZE, 1))
        in_y.d = np.reshape(np.array(LIST_Y), (BATCH_SIZE, 1))

        loss.forward()
        solver.zero_grad()
        loss.backward()
        solver.update()

        if n % 10 == 0:
            L.info("%8d : %0.2f" % (n, loss.d))

    nn.save_parameters(modelfile)


def inference(f, in_x, in_y, modelfile="model.h5"):
    """学習結果を保存したmodelfile名をネットワークに適用して、推論を行う
    """

    nn.load_parameters(modelfile)

    for v in LIST_X:
        in_x.d = np.reshape(np.array([v]), (1, 1))
        f.forward()
        L.info("%0.1f = %0.1f" % (v, f.d[0]))


def main():

    x = nn.Variable(shape=(BATCH_SIZE, 1))
    y = nn.Variable(shape=(BATCH_SIZE, 1))

    f = build(x, y)
    train(f, x, y)
    inference(f, x, y)


if __name__ == "__main__":
    main()

# [EOF]