batch, epochについて学ぶ
1回のパラメータ更新のために利用するデータの数をbatch_size,
データを何周するかをepochと呼ぶみたいです
前回は簡単のためbatch_size = 1, epochs = 1 としましたが、現実的には両方1というのはあまりないと思うので前回のプログラムを修正してbatch_sizeとepochsに対応できるようにします。
dynamic_rnnのinputは [batch_size, max_len, input_size]なshapeなので、batch_size = 5, max_len = 3, input_size = 1なら
xs = [ [[1], [2], [3]], [[4], [5], [6]], [[7], [8], [9]], [[0], [0], [0]], [[1], [1], [1]] ]
な感じで渡してやる必要があります
また、outputは [batch_size, max_len, cell.output_size] なshapeなので今回は上記xsと同じような形になります
output = [ [[1], [4], [7]], [[2], [5], [8]], [[3], [6], [9]], [[9], [9], [9]], [[8], [8], [8]] ]
このとき、[ [1], [4], [7] ] は xsの [ [1], [2], [3] ] に対応していて、xsの1を入力したときの出力が1, 2を入力したときの出力が4, 3を入力したときの出力が7 となっているっぽいです。
今回の例では最後の入力を与えたときの出力がほしいので、outputの右端の列をもらうために、output[:, -1, 0]としています
まとめると、下記な感じになります。
import random import tensorflow as tf sess = tf.Session() def generate_train_data(n): train_x = [] train_y = [] for i in range(n): xs = [random.random() * 10 for j in range(max_len)] ys = sum(xs) train_x.append(xs) train_y.append(ys) return train_x, train_y size = 1 rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=size, activation=tf.nn.leaky_relu) n_batch = 5 epochs = 4 max_len = 3 data_size = 1000 x = tf.placeholder(tf.float32, shape=[n_batch, max_len, size]) output, state = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32) y = tf.placeholder(tf.float32, shape=[n_batch]) loss = tf.reduce_mean(tf.square(y - output[:, -1, 0])) optimizer = tf.train.GradientDescentOptimizer(0.0001) train_step = optimizer.minimize(loss) init = tf.global_variables_initializer() sess.run(init) losses = [] n_batches = int(data_size/n_batch) x_train, y_train = generate_train_data(data_size) for epoch in range(epochs): for i in range(n_batches): min_ix = i * n_batch max_ix = (i+1) * n_batch xs = x_train[min_ix:max_ix] # shape [batch_size, max_len, input_size] # xs = [ [[1], [2], [3]], # [[4], [5], [6]], # [[7], [8], [9]], # [[0], [0], [0]], # [[1], [1], [1]] ] なイメージ xs = [ [[xxx] for xxx in xx] for xx in xs] ys = y_train[min_ix:max_ix] sess.run(train_step, feed_dict={x: xs, y: ys}) if i % 100 == 0: losses.append(sess.run(loss, feed_dict={x: xs, y: ys})) print('vars : ', sess.run(rnn_cell.variables[0])) # RNNのウェイトを表示してみる print('losses : ', losses)
結果は
vars : [[ 1.0746733 ] [ 0.92348057]] losses : [138.32841, 0.81912673, 0.57693189, 0.40475434, 0.29114446, 0.1994734, 0.14868297, 0.098647617]
な感じでした