やまメモ

やまげんによるやまげんのためのメモ

やっとChainerが動いた!!

一日調べまくってやっとChainerが動きました.

ソースコードはこちら↴

# -*- coding: utf-8 -*-
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions


class Model(Chain):
    """Four-layer fully connected regressor.

    Three ReLU hidden layers feed a linear output layer.

    Args:
        n_units: Width of each of the three hidden layers.
        n_out: Number of output values per example.
    """

    def __init__(self, n_units, n_out):
        # in_size is None so each Linear layer takes its input width from the
        # first batch it receives (Chainer's lazy initialization).
        super(Model, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_units),
            l3=L.Linear(None, n_units),
            l4=L.Linear(None, n_out),
        )

    def __call__(self, x):
        """Return the model's unbounded prediction for the input batch ``x``."""
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        h3 = F.relu(self.l3(h2))
        # No activation on the output layer: this is a regression model.
        # A ReLU here would clamp predictions at 0. It would also give zero
        # gradient whenever the pre-activation is negative, which can stall
        # training permanently.
        return self.l4(h3)



def _make_dataset(xp):
    """Return (x, t) pairs covering the 9x9 multiplication table.

    For every i, j in 1..9, x is the float32 array [i, j] and t is the
    float32 array [i * j], both allocated with the array module ``xp``.
    """
    return [(xp.array([i, j], dtype=xp.float32),
             xp.array([i * j], dtype=xp.float32))
            for i in range(1, 10)
            for j in range(1, 10)]


def _train(model, optimizer, data, epoch):
    """Train ``model`` one example at a time for ``epoch`` passes over ``data``.

    After each pass, print the mean squared error averaged over all examples
    in that pass.
    """
    for ep in range(epoch):
        total = 0
        for x, t in data:
            model.cleargrads()
            y = model(Variable(x.reshape((1, 2))))
            loss = F.mean_squared_error(y, Variable(t.reshape((1, 1))))
            loss.backward()
            optimizer.update()
            # Accumulate on the device; only the epoch summary is printed.
            total += loss.data
        print("epoch = ", ep, " / ", end="")
        # Report the epoch average, not just the last example's loss.
        print("loss = ", total / len(data))


def _interactive(model, xp):
    """Repeatedly read two integers from stdin and print the model's prediction.

    The loop ends at end of input (EOF). A non-integer entry is rejected
    and the user is prompted again.
    """
    while True:
        print("数値: ")
        try:
            one = int(input())
            two = int(input())
        except EOFError:
            break
        except ValueError:
            print("整数を入力してください")
            continue
        d = xp.array([one, two], dtype=xp.float32)
        answer = model(Variable(d.reshape((1, 2))))
        print("答え = ", answer.data)


def main():
    """Train Model to reproduce the multiplication table, then query it interactively."""

    epoch = 200

    model = Model(64, 1)
    optimizer = optimizers.Adam()
    optimizer.use_cleargrads()
    optimizer.setup(model)

    # Use the GPU only when CUDA is actually usable. Hard-coding True makes
    # the script crash on CPU-only machines.
    is_gpu = cuda.available
    if is_gpu:
        gpu_device = 0
        cuda.get_device(gpu_device).use()
        model.to_gpu(gpu_device)
        xp = cuda.cupy
    else:
        xp = np

    data = _make_dataset(xp)
    _train(model, optimizer, data, epoch)

    # Test phase
    _interactive(model, xp)


if __name__ == '__main__':
    main()

入力には九九の掛け算の2つの数([1, 1], [1, 2], …, [6, 7], …, [9, 9])を、
正解データにはその積([1], [2], …, [42], …, [81])を与えている.