Hello and welcome to part 5 of our neural network shenanigans series. Lately, we've been working on doing classification with a generative model. I'll call this a "classification generator model," since that's basically what we're after: a generative model that could take in a primer that is an image of a number, and hopefully produce its classification.
We've done 30,000 steps for this "classification generator" model, with a loss graph of:
Looks like this still has some learning to do, but let's see how it even does! To begin, let's just generate some output and see what it looks like. Let's edit the sample.py file so that it writes to out.txt this time rather than out.py, since the output isn't a Python file anymore. Then let's run it:
python sample.py -n=10000 --prime=[
This should output some to the console, but also to out.txt. Let's see how that went:
[00000000000000000000000000000000] [00000000000000000000000000000]] ::[0000010000]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000001111000000000000] [0000000000011111100000000000] [0000000000110001110000000000] [0000000000100000110000000000] [0000000000100000010000000000] [0000000000000000010000000000] [0000000000000000110000000000] [0000000000000001110000000000] [0000000000000111100000000000] [0000000000001111110000000000] [0000000000011111111110000000] [0000000000110000000010000000] [0000000000000000000010000000] [0000000000000000000010000000] [0000000010000000000010000000] [0000000001000000000110000000] [0000000001000000001100000000] [0000000001100000011000000000] [0000000000110001110000000000] [0000000000011111000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0001000000]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000110000000000000] [0000000000000111000000000000] [0000000000000111000000000000] [0000000000000110000000000000] [0000000000000110000000000000] [0000000000000111000000000000] [0000000000000111000000000000] [0000000000000111000000000000] [0000000000000111000000000000] [0000000000001111100000000000] [0000000000001111000000000000] [0000000000001111000000000000] [0000000000001111000000000000] [0000000000001111000000000000] [0000000000000111000000000000] [0000000000000111000000000000] [0000000000000111100000000000] [0000000000000111100000000000] [0000000000000111100000000000] [0000000000000011100000000000] [0000000000000001000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0100000000]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] 
[0000000000000000000000000000] [0000000000000000001110000000] [0000000000000000011110000000] [0000000000000000011100000000] [0000000000000000111100000000] [0000000000000000111000000000] [0000000000000001111000000000] [0000000000000001111000000000] [0000000000000011110000000000] [0000000000000011110000000000] [0000000000000111100000000000] [0000000000000111000000000000] [0000000000001111000000000000] [0000000000001111000000000000] [0000000000111110000000000000] [0000000000111100000000000000] [0000000001111100000000000000] [0000000001111000000000000000] [0000000001110000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0100000000]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000011000000000] [0000000000000011110000000000] [0000000011111111111000000000] [0000000011111100011000000000] [0000000000000000011000000000] [0000000000000000011000000000] [0000000000000000011000000000] [0000000000000000011000000000] [0000000000000000011000000000] [0000000000000000011000000000] [0000000000000000011000000000] [0000000000000000011000000000] [0000000011110000111000000000] [0000000011111111111000000000] [0000000011111111110000000000] [0000000011111111111000000000] [0000000001111111111000000000] [0000000000001111111000000000] [0000000000000111111000000000] [0000000000000111111000000000] [0000000000000111110000000000] [0000000000001111111000000000] [0000000000011111110000000000] [0000000000011101110000000000] [0000000000011111110000000000] [0000000000011111100000000000] [0000000000011111100000000000] [0000000000001111000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0000000010]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] 
[0000000000010001000100110000] [0000000000110111110000011000] [0000000001110111000000011000] [0000000001110010000000010000] [0000000011000000000000100000] [0000000011000000000001100000] [0000000110000000000011000000] [0000001100000000000011000000] [0000011000000000000111000000] [0000011000000000000110000000] [0000010000000000001100000000] [0000011000000000011000000000] [0000011000000000111000000000] [0000011000000001110000000000] [0000011000000011100000000000] [0000001100011111000000000000] [0000001111111100000000000000] [0000000111110000000000000000] [0000000011000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[1000000000]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000111111000000000000] [0000000011111111110000000000] [0000000111000000110000000000] [0000001110000000010000000000] [0000001100000000010000000000] [0000001100000000010000000000] [0000001100000000010000000000] [0000000000000000010000000000] [0000000000000000110000000000] [0000000000000000110000000000] [0000000000000001100000000000] [0000000000000001100000000000] [0000000000000001100000000000] [0000000000000011000000000000] [0000000000000011000000000000] [0000000000000110000000000000] [0000000000000110110000000000] [0000000000000111100000000000] [0000000000000011000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0000000100]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000001110000000000] [0000000000000111110000000000] [0000000000001100000000000000] [0000000000011000001000000000] [0000000000110000010000000000] [0000000000110000000000000000] [0000000000100000010000000000] [0000000001100000110000000000] [0000000001100000110000000000] 
[0000000001100001110000000000] [0000000001000111110000000000] [0000000001111111110000000000] [0000000001111000110000000000] [0000000000000000100000000000] [0000000000000000100000000000] [0000000000000000100000000000] [0000000000000001000000000000] [0000000000000001100000000000] [0000000000000001100000000000] [0000000000000000100000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0000100000]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000111111000000] [0000000000000001111111000000] [0000000000000011110011000000] [0000000000000110000011000000] [0000000000000100000011000000] [0000000000000000000110000000] [0000000000000000000110000000] [0000000000000000001100000000] [0000000000000000001100000000] [0000000000000000001100000000] [0000000000000000111000000000] [0000000001110011110000000000] [0000000111111111100000000000] [0000001110011111100000000000] [0000001100001111110000000000] [0000011000111100111000000000] [0000011111110000011000000000] [0000001110000000011000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0010000000]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000001000000000000] [0000000000000011110000000000] [0000000000000111111000000000] [0000000000001110011000000000] [0000000000011100001000000000] [0000000000011000001000000000] [0000000000010000001000000000] [0000000000110000001000000000] [0000000000110000001000000000] [0000000000110000011000000000] [0000000000010001100000000000] [0000000000011111000000000000] [0000000000001111000000000000] [0000000000000110000000000000] [0000000000001110000000000000] [0000000000110011000000000000] 
[0000000001110011000000000000] [0000000001100011000000000000] [0000000011000111000000000000] [0000000011001110000000000000] [0000000011111110000000000000] [0000000001111100000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0000000001]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000001111111000] [0000000000000111111111110000] [0000000000011111111000000000] [0000000001111000000000000000] [0000000011100000000000000000] [0000000111000000000000000000] [0000001111111100000000000000] [0000000111111111110000000000] [0000000000000001110000000000] [0000000000000000111000000000] [0000000000000000111000000000] [0000000000000000111000000000] [0000000000000000111000000000] [0000000000000001110000000000] [0000000000000011110000000000] [0000000000111111000000000000] [0000000000110000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000]] ::[0000010000]:: [[0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000000000000000000000] [0000000000001110000000000000] [0011100000011111111000000000] [0001110000000111111111110000] [0001111000000111110111100000] [0000111000000111000001110000] [0000011000000110000000110000] [0000011110000110000000110000] [0000001111000011000001110000] [0000000111100011000001100000] [0000000001100111000000000000] [0000000001111110000000000
That actually looks pretty darn good to me. Some of the classifications are wrong, but the generated numbers and such are quite good. Let's see if we can use this model now to just generate a classification based on an image's input. To do this, I am just going to quickly modify the sample.py file, calling it testing_mnist.py, which will now become:
# division is imported so correct/HM_TESTS is not floored on Python 2
from __future__ import division, print_function

import tensorflow as tf

import argparse
import os
from six.moves import cPickle

from model import Model

from six import text_type

##
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

HM_TESTS = 100

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# mnist.train, mnist.test, mnist.validation
batch_xs, batch_ys = mnist.validation.next_batch(HM_TESTS)


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--save_dir', type=str, default='save',
                        help='model directory to store checkpointed models')
    parser.add_argument('-n', type=int, default=500,
                        help='number of characters to sample')
    parser.add_argument('--prime', type=text_type, default=u' ',
                        help='prime text')
    parser.add_argument('--sample', type=int, default=1,
                        help='0 to use max at each timestep, 1 to sample at '
                             'each timestep, 2 to sample on spaces')

    args = parser.parse_args()
    sample(args)


def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

            correct = 0
            # here we modify.
            for i, data in enumerate(batch_xs):
                print(i)
                # round the pixels and the one-hot label to 0/1 integers
                data = np.rint(batch_xs[i]).astype(int)
                label = np.rint(batch_ys[i]).astype(int)
                pixels = data.reshape((28,28))

                # stripping spaces gives the same text format we trained on
                str_label = str(label).replace(' ','')
                str_img = str(pixels).replace(' ','')
                sample_me = "{}\n::".format(str_img)

                #gen_to_data = "::{}::\n{}\n\n".format(str_label, str_img)
                classify_data = "{}\n::{}::\n\n".format(str_img, str_label)

                samp = model.sample(sess, chars, vocab, 14, sample_me, args.sample).encode('utf-8')

                try:
                    # whatever sits between the pairs of colons is the prediction
                    prediction = samp.decode('utf-8').split("::")[1].split('::')[0]
                    if prediction == str_label:
                        print("HOLY MOLY WE DID IT!!!!")
                        correct += 1
                        print(str_label)
                except Exception:
                    # a sample we can't parse just counts as a miss
                    pass

            print("Correct rate: {}".format(correct/HM_TESTS))


if __name__ == '__main__':
    main()
Main changes here are importing the dataset, and then iterating over that dataset. We use the string converted and simplified image pixel data of the image to be classified as the primer, and then we set the "n" value to just be 14, since that's all we should need here.
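To make that primer format concrete, here's a minimal pure-Python sketch of the conversion. The helper names image_to_primer and label_to_string are made up for illustration (the script above does the same job with numpy's str() plus a replace), and the tiny 2x3 "image" is only there so the output is easy to eyeball.

```python
def image_to_primer(pixels, width=28):
    """Turn a flat list of floats in [0, 1] into the bracketed 0/1 text format."""
    # Round each pixel to 0 or 1, as np.rint(...).astype(int) does in the script.
    bits = [str(int(round(p))) for p in pixels]
    # One bracketed row per `width` pixels, rows separated by newlines.
    rows = ["[" + "".join(bits[i:i + width]) + "]"
            for i in range(0, len(bits), width)]
    return "[" + "\n".join(rows) + "]"


def label_to_string(one_hot):
    """Turn a one-hot label into text such as [0001000000]."""
    return "[" + "".join(str(int(round(v))) for v in one_hot) + "]"


# Tiny made-up 2x3 "image" instead of a real 28x28 MNIST digit.
tiny = [0.0, 0.9, 0.1,
        1.0, 0.4, 0.6]
primer = image_to_primer(tiny, width=3) + "\n::"
print(primer)
print(label_to_string([0, 0, 0, 1, 0, 0, 0, 0, 0, 0]))
```

A ten-class label comes out 12 characters long once you count the brackets, and the closing :: adds two more, which is where the 14 comes from.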
Then, we run the prediction and attempt to parse out anything that is between pairs of colons, since this would be our prediction. If that prediction matches the string converted and simplified label, we were successful! We're going to run this on 100 validation samples, and just see how we've done.
Running this, I get anything from 1/100 to 6/100 matches, which is unfortunate. Are we just going to accept this failure?! No! We demand better! Let's see what we're messing up on.
After samp is defined, let's just add in:
print(samp.decode('utf-8'))
print()
print()
Running this, I see things like:
[[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000101000000000]
[0000000000000111110000000000]
[0000000000001111000000000000]
[0000000000011100000000000000]
[0000000000111000000000000000]
[0000000000110000011100000000]
[0000000000110000111110000000]
[0000000000111111110000000000]
[0000000000111111000000000000]
[0000000000011110000000000000]
[0000000000111111100000000000]
[0000000000111111110000000000]
[0000000001110001111000000000]
[0000000001110000011100000000]
[0000000011100000011100000000]
[0000000011100000011100000000]
[0000000001100000111000000000]
[0000000001111111110000000000]
[0000000000111111100000000000]
[0000000000011110000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]]
::[0001100000000
Where the prediction is just simply not valid at all. I also saw:
[[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000111000000000000]
[0000000000111111110000000000]
[0000000001111111111000000000]
[0000000001100001111000000000]
[0000000001100000011100000000]
[0000000001100000011100000000]
[0000000001100000011100000000]
[0000000000110000011100000000]
[0000000000111001111111000000]
[0000000000111111111110000000]
[0000000000011111111000000000]
[0000000000111111000000000000]
[0000000000111111000000000000]
[0000000001110111100000000000]
[0000000011100011110000000000]
[0000000011100001111100000000]
[0000000011100000111100000000]
[0000000011100000111100000000]
[0000000011111111111100000000]
[0000000001111111110000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]
[0000000000000000000000000000]]
::[0000000010]:
Which is super unfortunate, because this is a correct prediction that just missed the last colon. So then I am wondering, hmm, we have two things we could fix for sure: 1, come up with a better way to parse out the prediction, and 2, come up with a way to validate whether the prediction was even plausibly legit. If it wasn't, we could try to regenerate it.
A prediction of ::[0000000001111 is not a valid prediction, so we need to re-try.
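Both of those fixes could be sketched with a regular expression. This is only one possible approach, not what the script in this tutorial does: extract_prediction and LABEL_RE are hypothetical names, and the "exactly one 1" rule is my assumption about what counts as a plausible one-hot label.

```python
import re

# Matches "::[dddddddddd]" with exactly ten 0/1 digits. The closing "::" is
# deliberately not required, so a label whose last colon got cut off still parses.
LABEL_RE = re.compile(r"::(\[[01]{10}\])")


def extract_prediction(sample_text):
    """Return the first well-formed one-hot label after '::', or None."""
    match = LABEL_RE.search(sample_text)
    if match is None:
        return None
    label = match.group(1)
    # Assumption: a plausible class has exactly one "1" in it.
    if label.count("1") != 1:
        return None
    return label


# The two cases from above: a good label missing its last colon,
# and a run of digits that never closes its bracket.
print(extract_prediction("[[0000]]\n::[0000000010]:"))
print(extract_prediction("[[0000]]\n::[0001100000000"))
```

A None result is the signal to re-try, so the caller never has to guess whether a string is usable.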
With a quick re-working of the script:
sample_me = "{}\n::".format(str_img)
samp = model.sample(sess, chars, vocab, 14, sample_me, args.sample).encode('utf-8')
print(samp.decode('utf-8'))
print()
print()
sample_me = "{}\n::".format(str_img)
samp = model.sample(sess, chars, vocab, 14, sample_me, args.sample).encode('utf-8')
print(samp.decode('utf-8'))
I just want to see if the samples will always be the same given the same input. They are interestingly *not*, most likely because --sample defaults to 1, which samples at each timestep rather than always taking the max. This might make our lives easier then, since otherwise we'd have to slightly change the primer to hopefully get a legit prediction. We still might want to do that, but, for now, let's just try to keep generating until we get a valid prediction.
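To get a feel for why repeated runs can differ, here's a toy sketch of the first two modes the --sample help text describes: taking the max at each timestep versus sampling at each timestep. This is not the code inside model.sample; the function names and the four-character distribution are made up for illustration, and the "sample on spaces" mode is left out.

```python
import random


def pick_greedy(probs):
    """Index of the most likely character (the "use max" idea)."""
    return max(range(len(probs)), key=lambda i: probs[i])


def pick_sampled(probs, rng):
    """Draw an index at random, weighted by its probability."""
    threshold = rng.random() * sum(probs)
    running = 0.0
    for i, p in enumerate(probs):
        running += p
        if threshold < running:
            return i
    return len(probs) - 1  # guard against floating-point round-off


chars = ["0", "1", "]", ":"]
probs = [0.55, 0.30, 0.10, 0.05]  # made-up next-character distribution
rng = random.Random(0)

# Greedy picks the same character every time; sampling can vary between draws.
print("greedy :", [chars[pick_greedy(probs)] for _ in range(8)])
print("sampled:", [chars[pick_sampled(probs, rng)] for _ in range(8)])
```

If the model behaves along these lines, re-running with the same primer is a cheap way to get a different candidate, which is what the retry loop relies on.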
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

            correct = 0
            # here we modify.
            for i, data in enumerate(batch_xs):
                print(i)
                data = np.rint(batch_xs[i]).astype(int)
                label = np.rint(batch_ys[i]).astype(int)
                pixels = data.reshape((28,28))

                str_label = str(label).replace(' ','')
                str_img = str(pixels).replace(' ','')

                #gen_to_data = "::{}::\n{}\n\n".format(str_label, str_img)
                classify_data = "{}\n::{}::\n\n".format(str_img, str_label)

                valid_prediction = False
                # keep sampling until the prediction has the right length
                while not valid_prediction:
                    try:
                        sample_me = "{}\n::".format(str_img)
                        samp = model.sample(sess, chars, vocab, 14, sample_me, args.sample).encode('utf-8')
                        prediction = samp.decode('utf-8').split("::")[1].split('::')[0]
                        # "[" + 10 digits + "]" is 12 characters
                        if len(prediction) == 12:
                            valid_prediction = True
                    except Exception as e:
                        pass

                if prediction == str_label:
                    print("HOLY MOLY WE DID IT!!!!")
                    correct += 1
                    print(str_label)
                    print()
                    print()
                    print()

            print("Correct rate: {}".format(correct/HM_TESTS))
With this, we'll just simply keep generating until a valid prediction is made, which could take a while. I ran this for a bit, then I decided maybe it would be best to switch around the primers too if we're having a hard time finding a valid generated class. I then modified to:
# Note: this version needs `import random` alongside the other imports at the top of the file.
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

            correct = 0
            # here we modify.
            for i, data in enumerate(batch_xs):
                print(i)
                data = np.rint(batch_xs[i]).astype(int)
                label = np.rint(batch_ys[i]).astype(int)
                pixels = data.reshape((28,28))

                str_label = str(label).replace(' ','')
                str_img = str(pixels).replace(' ','')

                #gen_to_data = "::{}::\n{}\n\n".format(str_label, str_img)
                classify_data = "{}\n::{}::\n\n".format(str_img, str_label)

                valid_prediction = False
                # a few primer variants to pick from on each attempt
                sample_choices = ["{}\n::".format(str_img),
                                  "{}\n".format(str_img),
                                  "{}".format(str_img),
                                  "{}\n:".format(str_img)]

                while not valid_prediction:
                    try:
                        sample_me = random.choice(sample_choices)
                        samp = model.sample(sess, chars, vocab, 14, sample_me, args.sample).encode('utf-8')
                        prediction = samp.decode('utf-8').split("::")[1].split('::')[0]
                        if len(prediction) == 12:
                            valid_prediction = True
                    except Exception as e:
                        pass

                if prediction == str_label:
                    print("HOLY MOLY WE DID IT!!!!")
                    correct += 1
                    print(str_label)
                    print()
                    print()
                    print()
                else:
                    # show the tail of the sample so we can see what went wrong
                    print(samp.decode("utf-8")[-20:])

            print("Correct rate: {}".format(correct/HM_TESTS))
With this, we will still search indefinitely for a valid classification. I think it would likely be wise to limit the while statement to maybe... 10, 20, or *some number* of attempts before giving up and moving on.
# Note: this version needs `import random` alongside the other imports at the top of the file.
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

            correct = 0
            # here we modify.
            for i, data in enumerate(batch_xs):
                print(i)
                data = np.rint(batch_xs[i]).astype(int)
                label = np.rint(batch_ys[i]).astype(int)
                pixels = data.reshape((28,28))

                str_label = str(label).replace(' ','')
                str_img = str(pixels).replace(' ','')

                #gen_to_data = "::{}::\n{}\n\n".format(str_label, str_img)
                classify_data = "{}\n::{}::\n\n".format(str_img, str_label)

                valid_prediction = False
                sample_choices = ["{}\n::".format(str_img),
                                  "{}\n".format(str_img),
                                  "{}".format(str_img),
                                  "{}\n:".format(str_img)]

                attempts = 0
                while not valid_prediction:
                    attempts += 1
                    if attempts > 20:
                        print('Too many tries, moving along.')
                        valid_prediction = True
                        prediction = 'Darn...'
                        # stop here so we don't sample again and overwrite the placeholder
                        break
                    try:
                        sample_me = random.choice(sample_choices)
                        samp = model.sample(sess, chars, vocab, 14, sample_me, args.sample).encode('utf-8')
                        prediction = samp.decode('utf-8').split("::")[1].split('::')[0]
                        if len(prediction) == 12:
                            valid_prediction = True
                    except Exception as e:
                        pass

                if prediction == str_label:
                    print("HOLY MOLY WE DID IT!!!!")
                    correct += 1
                    print(str_label)
                    print()
                    print()
                    print()
                else:
                    print(samp.decode("utf-8")[-20:])

            print("Correct rate: {}".format(correct/HM_TESTS))
With these changes though, I still found accuracy to be not any better, which is quite the bummer.
Looks like the generative model can easily generate numbers and classifications for its own numbers, but it is not very good at classifying other numbers. There are a few more things we can do here to try to boost our numbers, but this doesn't appear to be working at the moment. We also could re-visit this with a larger model, larger sequences, and so on. There are many tweaks we could make; it just takes time to train/test.
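Before sinking training time into those tweaks, it might help to know whether the misses are mostly malformed output or well-formed labels that are simply wrong. Here's a minimal sketch of that kind of tally. The classify_outcome helper and the (prediction, label) pairs are made up for illustration; they are not results from the model.

```python
from collections import Counter


def classify_outcome(prediction, true_label):
    """Bucket one result as 'correct', 'wrong', or 'invalid'."""
    well_formed = (
        len(prediction) == 12
        and prediction[0] == "[" and prediction[-1] == "]"
        and set(prediction[1:-1]) <= {"0", "1"}
        and prediction.count("1") == 1
    )
    if not well_formed:
        return "invalid"
    return "correct" if prediction == true_label else "wrong"


# Hypothetical (prediction, label) pairs, NOT real output from the model.
results = [
    ("[0000000010]", "[0000000010]"),  # matches the label
    ("[0100000000]", "[0000000100]"),  # well formed, different class
    ("[0001100000", "[0001000000]"),   # never closes the bracket
    ("Darn...", "[1000000000]"),       # the gave-up placeholder
]
tally = Counter(classify_outcome(p, t) for p, t in results)
print(dict(tally))
```

If most misses land in the "invalid" bucket, the parsing and retry logic is the place to look; if they land in "wrong", the model itself probably needs more training or capacity.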
Well now I am curious if we can AT LEAST ask the model to draw us numbers. Let's see in the next tutorial.