#!./bin/python # -*-python-*- """ Minimal character-level Vanilla RNN model. Written by Andrej Karpathy (@karpathy) Major improvements by Robert Munafo (mrob.com) BSD License requirements: Python 2.7, Numpy 1.7.0 Original Source: gist.github.com/karpathy/d4dee566867f8291f086 Munafo Version: mrob.com/pub/comp/min-char-rnn.py.txt INSTALLATION These instructions are probably more than most people will need. I have a Mac whose standard Python is 2.6.1, but I got Python 2.7 and some associated packages using a Mac package manager "MacPorts" (www.macports.org). In addition to Python 2.7 itself, I needed to install packages for numpy, pip and virtualenv. All of these were installed into /opt/local/bin which is not normally part of my path. First I created a directory within which I'd be doing my RNN work. I already have a "~/proj" for projects, so I went there. cd ~/proj ; mkdir minchar-rnn ; cd minchar-rnn Like I said above I had to create a virtualenv to work with 2.7 Python modules: /opt/local/bin/virtualenv-2.7 --system-site-packages . However, if you already have Python 2.7 and a matching version of virtualenv, you can probably just do this: virtualenv --system-site-packages . Now "enter" the virtual environment: source ./bin/activate Tell it it upgrade its own pip and numpy: pip install --upgrade pip pip install numpy --upgrade Find out what version of numpy you got: echo 'import numpy ; print (numpy.__version__)' | python Get the tiny Shakespeare input data: wget --no-check-certificate 'https://github.com/karpathy/char-rnn/raw/master/data/tinyshakespeare/input.txt' -O in-shakespeare.txt If your numpy version (tested above) is at least 1.7.0, then you can run the script: ./min-char-rnn.py -input in-shakespeare.txt OPTIONS The command-line options are: -input in-mobydick.txt A filename (NOT an absolute or relative path) from which to read the input data set. The file should be in the current directory. Default value is 'input.txt' (without quotes) -rnn_size 200 hidden_size, the number of 'neurons' in the hidden layer. Default value is 100 -seq_length 80 seq_length, the number of loop steps in the two loops in lossFun. In general, the resulting model will exhibit a "short-term memory" that extends only for about this many characters. Default value is 25 -learning_rate Initial learning rate; default is 0.1 -learning_rate_decay Amount to adjust the learning rate after 'learning_rate_decay_after' epochs. -learning_rate_decay_after After this many epochs, begin diminishing the learning rate. -seed 143 A seed for the random number generator. If this option is given along with -overwrite to ignore any saved data, then you can run the program twice and get the same results. (All other options and the input file must also be the same). This can be useful for debugging. -overwrite Ignore any existing "dat-foo.npz" saved checkpoint data file. A summary of the options is given if you use the option '-h' or '--help', or if you give any unrecognised option. TRAINING While training, it will periodically sample the model and print out brief bits of text. It runs twice as long for each sample it prints: Samples are printed at 100, 200, 400, 800, 1600, 3200, ... iterations through the main loop. It will also log larger samples to a file whose name will start with "out-" and is derived from the input filename. For example if the input filename is "in-mobydick.txt", this log file will be "out-mobydick.txt". Every 4096 times through the main loop, it also write an even larger sample to a file starting with "smp-". This file does *not* keep growing, but is overwritten each time a sample is written. Also every 4096 times through the main loop, the entire model will be saved in a file starting with "dat-" and ending in ".npz". This is a checkpoint and will be automatically loaded when you start the script again using the same input filename. REVISION HISTORY 20151208 First version. Improve output formatting. Generate 'cur-sample.txt'. Write longer samples to an output file. Take command-line parameters for input filename and hidden_size. Show number of parameters at startup. Take 3rd argument for seq_length 20151209 Save and reload from a NPZ file. Uncomment the gradCheck. Add a bunch of instructions. 20151210 xrange->range, parenthesise print arguments (for Py3.x compatibility). Begin 4-space indenting. Set p differently at the start of each new epoch, to avoid one source of overfitting. Add sample_temperature 20151211 Add softmax function and sample_temperature option. Add learning_rate_decay. Convert a few more print statements. 20151212 Use argparse. "-h" now explains all options. Replace softmax function with two better versions. Show Python version when starting. Temperature can be given as an option. Add -seed option. Add -overwrite option. Update the long help (in this block comment, above) to describe new option format. """ import numpy as np import sys import re import os.path import argparse parser = argparse.ArgumentParser(prog='min-char-rnn', description= 'Minimal 1-layer Recurrent Neural Network (RNN) for text streams') parser.add_argument('-input', default="input.txt", help='name of simple plain text file in current directory') parser.add_argument('-rnn_size', type=int, default=100, help='size of hidden layer of neurons') # num_layers is stuck at 1 %%% but I hope to implement 2 # model will always be rnn (no lstm or gru) parser.add_argument('-learning_rate', type=float, default=1e-1, help='initial learning rate') parser.add_argument('-learning_rate_decay', type=float, default=0.97, help='learning rate decay') parser.add_argument('-learning_rate_decay_after', type=int, default=10, help='in number of epochs, when to start decaying the learning rate') # decay_rate (alpha) is only used by ELU, not sure why char-rnn/train.lua even has it # dropout will go here parser.add_argument('-seq_length', type=int, default=25, help='number of timesteps to unroll for') # batch_size # max_epochs # grad_clip # train_frac # val_frac parser.add_argument('-seed', type=int, default=0, help='numpy manual random number generator seed') # - - - Exta options added by Robert Munafo parser.add_argument('-temperature', type=float, default=1.0, help='Softmax temperature for sampling (does not affect training)', dest='o_temp') parser.add_argument('-overwrite', action='store_true', help='overwrite existing data save file, if any', dest='o_overwrite') args = parser.parse_args() print("Python version: {}.{}".format(sys.version_info[0], sys.version_info[1])) print("input file: {}".format(args.input)) print("rnn_size = {}".format(args.rnn_size)) print("learning_rate = {:f}".format(args.learning_rate)) print("learning_rate_decay = {:f}".format(args.learning_rate_decay)) print("learning_rate_decay_after = {}".format(args.learning_rate_decay_after)) print("seq_length = {}".format(args.seq_length)) print("seed = {}".format(args.seed)) print("temperature = {}".format(args.o_temp)) print("overwrite = {}".format(args.o_overwrite)) # Adjustable hyperparameters. NOTE: These will be ignored if there is a # saved-data file ("checkpoint") with a matching filename. hidden_size = args.rnn_size # size of hidden layer of neurons seq_length = args.seq_length # number of steps to unroll the RNN for learning_rate = args.learning_rate learning_rate_decay_after = args.learning_rate_decay_after # in number of epochs, when to start decaying the learning rate # sample_temperature does not affect training; it only affects the # samples that are printed to the screen, to out-filename.txt and to # smp-filename.txt It affects the softmax function call in the # sample() function, but now the softmax in lossFunc. # The default is 1.0 and selects output at random with the normal # softmax rates. # Lower values (like 0.8) are "more conservative" and tend to repeat # text that has been seen during training. # Higher values (like 1.25) are "more risky" and tend to contain more # spelling errors, invented words/phrases, etc. # Anything lower than 0.5 sample_temperature = args.o_temp inpath = args.input seq_length = args.seq_length if re.match(r'^in-', inpath): outpath = re.sub(r'^in-', 'out-', inpath) smp_path = re.sub(r'^in-', 'smp-', inpath) dat_path = re.sub(r'^in-', 'dat-', inpath) else: outpath = 'out-' + inpath smp_path = 'smp-' + inpath dat_path = 'dat-' + inpath dat_path = re.sub(r'\.txt$', '.npz', dat_path) dat_path = re.sub(r'\.rhtf$', '.npz', dat_path) print ("inpath: '{inpath}' outpath: '{outpath}' dat_path: '{dat_path}'" .format(inpath=inpath, outpath=outpath, dat_path=dat_path)) # data I/O data = open(inpath, 'r').read() # should be simple plain text file chars = list(set(data)) data_size, vocab_size = len(data), len(chars) print ("data has {data_size} characters, {vocab_size} unique." .format(data_size=data_size, vocab_size=vocab_size)) char_to_ix = { ch:i for i,ch in enumerate(chars) } ix_to_char = { i:ch for i,ch in enumerate(chars) } # Set up random matrices for the model parameters. NOTE: These will be # replaced (or even resized) below if there is a data file. if args.seed > 0: np.random.seed(args.seed) W_xh = np.random.randn(hidden_size, vocab_size)*0.01 # input to hidden W_hh = np.random.randn(hidden_size, hidden_size)*0.01 # hidden to hidden W_hy = np.random.randn(vocab_size, hidden_size)*0.01 # hidden to output bh = np.zeros((hidden_size, 1)) # hidden bias by = np.zeros((vocab_size, 1)) # output bias def softmax(w, temp=1.0): """ standard softmax function (en.wikipedia.org/wiki/Softmax_function) given an input vector w, calculate a result vector y. For each j in [1..len(w)], y_j = exp(w_j)/SUM, where SUM is the sum of exp(w_k) for all k in [1..len(w)] For example, if w = [0.2, 0.3, 0.5 ] then np.exp(w) = [1.22, 1.35, 1.65] SIGMA = np.sum(np.exp(w) = 4.22 and softmax(w) = [0.29, 0.32, 0.39] The second parameter 'temp' is optional and is the "Softmax temperature" """ e = np.exp(np.array(w)/temp) dist = e / np.sum(e) return dist def softmax_1(x): """ optimised version for temp=1.0 source: gist.github.com/stober/1946926 """ e_x = np.exp(x - np.max(x)) out = e_x / np.sum(e_x) return out def lossFun(inputs, targets, hprev): """ inputs,targets are both list of integers. hprev is Hx1 array of initial hidden state returns the loss, gradients on model parameters, and last hidden state """ xs, hs, ys, ps = {}, {}, {}, {} hs[-1] = np.copy(hprev) loss = 0 # forward pass for t in range(len(inputs)): xs[t] = np.zeros((vocab_size,1)) # encode in 1-of-k representation xs[t][inputs[t]] = 1 # update the hidden state hs[t] = np.tanh(np.dot(W_xh, xs[t]) + np.dot(W_hh, hs[t-1]) + bh) # compute the output vector (unnormalized log prob'ys for next chars) ys[t] = np.dot(W_hy, hs[t]) + by ps[t] = softmax_1(ys[t]) # prob'ys for next chars loss += -np.log(ps[t][targets[t],0]) # softmax (cross-entropy loss) # backward pass: compute gradients going backwards dW_xh = np.zeros_like(W_xh) dW_hh, dW_hy = np.zeros_like(W_hh), np.zeros_like(W_hy) dbh, dby = np.zeros_like(bh), np.zeros_like(by) dhnext = np.zeros_like(hs[0]) for t in reversed(range(len(inputs))): dy = np.copy(ps[t]) dy[targets[t]] -= 1 # backprop into y dW_hy += np.dot(dy, hs[t].T) dby += dy dh = np.dot(W_hy.T, dy) + dhnext # backprop into h dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity dbh += dhraw dW_xh += np.dot(dhraw, xs[t].T) dW_hh += np.dot(dhraw, hs[t-1].T) dhnext = np.dot(W_hh.T, dhraw) for dparam in [dW_xh, dW_hh, dW_hy, dbh, dby]: np.clip(dparam, -5, 5, out=dparam) # mitigate exploding gradients return loss, dW_xh, dW_hh, dW_hy, dbh, dby, hs[len(inputs)-1] def sample(h, seed_ix, n): """ sample a sequence of integers from the model h is memory state, seed_ix is seed letter for first time step """ x = np.zeros((vocab_size, 1)) ; x[seed_ix] = 1 ixes = [] for t in range(n): h = np.tanh(np.dot(W_xh, x) + np.dot(W_hh, h) + bh) y = np.dot(W_hy, h) + by # For each possible next letter, the probability that we'll choose # that letter is proportional to the output value we just computed; # and the sum of all probs must be 1.0. This is the "Softmax" # calculation. p = softmax(y, sample_temperature) # Choose the next letter at random, but weighted by the probabilites # we just computed. ix = np.random.choice(range(vocab_size), p=p.ravel()) # Regenerate the input vector using the letter we just chose. x = np.zeros((vocab_size, 1)) ; x[ix] = 1 ixes.append(ix) return ixes # gradient checking from random import uniform def gradCheck(inputs, target, hprev): global W_xh, W_hh, W_hy, bh, by num_checks, delta = 10, 1e-5 _, dW_xh, dW_hh, dW_hy, dbh, dby, _ = lossFun(inputs, targets, hprev) for param,dparam,name in zip([ W_xh, W_hh, W_hy, bh, by], [dW_xh, dW_hh, dW_hy, dbh, dby], ['W_xh', 'W_hh', 'W_hy', 'bh', 'by']): s0 = dparam.shape s1 = param.shape assert s0 == s1, ("Error dims dont match: {s0} and {s1}." .format(s0=s0, s1=s1)) print (name) for i in range(num_checks): ri = int(uniform(0,param.size)) # evaluate cost at [x + delta] and [x - delta] old_val = param.flat[ri] param.flat[ri] = old_val + delta cg0, _, _, _, _, _, _ = lossFun(inputs, targets, hprev) param.flat[ri] = old_val - delta cg1, _, _, _, _, _, _ = lossFun(inputs, targets, hprev) param.flat[ri] = old_val # reset old value for this parameter # fetch both numerical and analytic gradient grad_analytic = dparam.flat[ri] grad_numerical = (cg0 - cg1) / ( 2 * delta ) rel_error = abs(grad_numerical) + abs(grad_analytic) if (rel_error != 0): rel_error = abs(grad_analytic - grad_numerical) / rel_error print ("{gn:8.5f}, {ga:8.5f} => {re:11.5e} " .format(gn=grad_numerical, ga=grad_analytic, re=rel_error)) # rel_error should be on order of 1e-7 or less def outblock(epoch, i_iter, f_loss, i_nchr): """ generate a block of text suitable for printing to the screen or saving in one of our log files. """ sample_ix = sample(hprev, inputs[0], i_nchr) txt = ''.join(ix_to_char[ix] for ix in sample_ix) outpt = '%s epoch %d iter %d, loss: %7.3f\n%s' % ( '--------------------------------------', epoch, i_iter, f_loss, txt ) outpt = ("{x1} epoch {e} iter {i}, loss={l:7.3f} temp={t:5.3f}\n{txt}" .format(x1='--------------------------------------', e=epoch, i=i_iter, l=float(f_loss), t=sample_temperature, txt=txt)) return outpt n, p = 0, 0 ni = 0 epoch = 0; p2 = 1 mW_xh, mW_hh, mW_hy = np.zeros_like(W_xh), np.zeros_like(W_hh), np.zeros_like(W_hy) mbh, mby = np.zeros_like(bh), np.zeros_like(by) # memory variables for Adagrad smooth_loss = -np.log(1.0/vocab_size)*seq_length # loss at iteration 0 if os.path.isfile(dat_path): if args.o_overwrite: # We'll ignore any extant file print("Ignoring existing '{}'".format(dat_path)) else: # We have a data file from last time loaded = np.load(dat_path) hidden_size = loaded['hidden_size'] seq_length = loaded['seq_length'] W_xh = loaded['W_xh'] W_hh = loaded['W_hh'] W_hy = loaded['W_hy'] bh = loaded['bh'] bh = loaded['bh'] n = loaded['n'] p = loaded['p'] ni = loaded['ni'] epoch = loaded['epoch'] p2 = loaded['p2'] mW_xh = loaded['mW_xh'] mW_hh = loaded['mW_hh'] mW_hy = loaded['mW_hy'] mbh = loaded['mbh'] mby = loaded['mby'] smooth_loss = loaded['smooth_loss'] hprev = loaded['hprev'] print ("reloaded state from '%s', epoch=%d, n=%d, W_hh is %d x %d" % (dat_path, epoch, n, len(W_hh), len(W_hh[0]))) # Recompute derived values learning_rate = args.learning_rate if epoch > learning_rate_decay_after: learning_rate = (args.learning_rate * pow(args.learning_rate_decay, epoch-learning_rate_decay_after)) outpt = ("==== start ==== epoch={e}, n={n}, seq_length={sl}, rnn_size={rs}, {p:8.2e} parameters" .format(e=epoch, n=n, sl=seq_length, rs=hidden_size, p=(4.0 * hidden_size * hidden_size))) print (outpt) f = open(outpath, 'a') f.write(outpt + '\n') f.close() while True: # prepare inputs (we're sweeping from left to right in steps seq_length long) if p+seq_length+1 >= len(data) or n == 0: # we're near the end and don't have enough data left to do the next batch, # or it's our first time hprev = np.zeros((hidden_size,1)) # reset RNN memory epoch += 1 p = 0 # p = epoch % seq_length # reset to a new position near the start if epoch > learning_rate_decay_after: learning_rate = (args.learning_rate * pow(args.learning_rate_decay, epoch-learning_rate_decay_after)) if learning_rate > 0.05: print ( "Starting epoch {e}, decay learning rate by {lrd:4.2f} to {lr:5.3f}" .format(e=epoch, lrd=args.learning_rate_decay, lr=learning_rate)) inputs = [char_to_ix[ch] for ch in data[p:p+seq_length]] targets = [char_to_ix[ch] for ch in data[p+1:p+seq_length+1]] # sample from the model now and then if n > 0 and n % 100 == 0: ni = ni + 1 if ni >= p2: outpt = outblock(epoch, n, smooth_loss, 200) print (outpt) # Make a bigger sample for our output file outpt = outblock(epoch, n, smooth_loss, 1024) f = open(outpath, 'a') f.write(outpt + '\n') f.close() # We'll do this again at the next power of 2 p2 = p2 * 2 if n > 0 and n % 4096 == 0: # Write an even bigger, volatile sample. This allows the user to get # more stuff during the middle and later stages of evolution, when the # normal output file is not being updated as often. outpt = outblock(epoch, n, smooth_loss, 4096) f = open(smp_path, 'w') f.write(outpt) f.close() # forward seq_length characters through the net and fetch gradient loss, dW_xh, dW_hh, dW_hy, dbh, dby, hprev = lossFun(inputs, targets, hprev) smooth_loss = smooth_loss * 0.999 + loss * 0.001 # perform parameter update with Adagrad # %%% There should be a 'validation mode' where we just look at the value # of 'loss' and don't adjust the model parameters. Use this to detect # overfitting. for param, dparam, mem in zip([ W_xh, W_hh, W_hy, bh, by], [dW_xh, dW_hh, dW_hy, dbh, dby], [mW_xh, mW_hh, mW_hy, mbh, mby]): mem += dparam * dparam param += -learning_rate * dparam / np.sqrt(mem + 1e-8) # adagrad update p += seq_length # move data pointer n += 1 # iteration counter if n > 0 and n % 4096 == 0: # Snapshot to file # %%% We should also generate checkpoints, but we need to do model # validation first. np.savez_compressed(dat_path, hidden_size=hidden_size, seq_length=seq_length, W_xh=W_xh, W_hh=W_hh, W_hy=W_hy, bh=bh, by=by, n=n, p=p, ni=ni, epoch=epoch, p2=p2, mW_xh=mW_xh, mW_hh=mW_hh, mW_hy=mW_hy, mbh=mbh, mby=mby, smooth_loss=smooth_loss, hprev=hprev) print ("saved state to '%s', epoch=%d, n=%d, W_hh is %d x %d" % (dat_path, epoch, n, len(W_hh), len(W_hh[0]))) # %%% We want to adapt this to be more useful, perhaps the gradient # stats are a symptom of overfitting. # print ("Gradient check:\n") # gradCheck(inputs, targets, hprev)