Retrieval-based Deep Learning with TensorFlow v1.0+ and Python

By | June 11, 2017

Hi there. In this post I will cover code from a GitHub repo that I forked (detailed in this post) that trains a machine learning model on IRC chat logs (the Ubuntu Dialog Corpus) to select the correct response out of a set of potential responses, given a context. The code was written last year against an older version of TensorFlow (TF), so it will not run as-is on the latest version. I have made some tweaks and got the code to work with TensorFlow 1.0+ (Python 3).

A little about TensorFlow: it is the successor to DistBelief, a proprietary Google system, and was open sourced in 2015 with a Python API. I won’t go into depth about the internals here because the original blog post on wildml.com has done a fine job of that already.

The model trains on 1 million “chats” (see GitHub for how the data is generated), half real and half fake (positive and negative examples). It learns the contexts and their meanings by focusing on co-occurrence, and it has the machinery to decide what is important and what is not. This example uses a Long Short-Term Memory (LSTM) recurrent neural network (initially proposed 20 years ago), which is effective in cases where sequential order and time are important. Models of this kind are often used for information retrieval and chatbot applications.

You can view the changes on GitHub. Let’s go over the updated lines of code. Lines starting with “-” were taken out; lines starting with “+” were substituted in.

udc_model.py


-    batch_size = targets.get_shape().as_list()[0]
+    # https://github.com/dennybritz/chatbot-retrieval/issues/32#issuecomment-265819109
+    if mode == tf.contrib.learn.ModeKeys.EVAL:
+        batch_size = targets.get_shape().as_list()[0]

This line would throw an error without the if statement. Googling the error led me straight to the GitHub issues thread with the fix.
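
Here is a minimal sketch of why a guard like this helps. It assumes the error comes from targets being absent (None) outside evaluation mode, which I have not traced through TF itself. The mode strings and the FakeTargets class are hypothetical stand-ins so the snippet runs without TensorFlow.

```python
# Stand-ins for the tf.contrib.learn.ModeKeys constants (assumed, not imported).
EVAL = "eval"
INFER = "infer"


class FakeTargets:
    """Hypothetical stand-in exposing only the two calls the guard touches."""

    def __init__(self, shape):
        self._shape = list(shape)

    def get_shape(self):
        return self

    def as_list(self):
        return self._shape


def eval_batch_size(targets, mode):
    """Read the batch size only in evaluation mode; otherwise leave it unset."""
    if mode == EVAL:
        return targets.get_shape().as_list()[0]
    return None


# With targets present in evaluation mode, the first dimension is returned.
print(eval_batch_size(FakeTargets([16, 1]), EVAL))
# With no targets at all, the guard skips the attribute access entirely.
print(eval_batch_size(None, INFER))
```

Without the mode check, the second call would try to call get_shape() on None and fail.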


-          tf.concat(0, all_contexts),
-          tf.concat(0, all_context_lens),
-          tf.concat(0, all_utterances),
-          tf.concat(0, all_utterance_lens),
-          tf.concat(0, all_targets))
+          tf.concat(all_contexts,0),
+          tf.concat(all_context_lens,0),
+          tf.concat(all_utterances,0),
+          tf.concat(all_utterance_lens,0),
+          tf.concat(all_targets,0))

Going through a few GitHub issue threads quickly showed that the argument order of tf.concat was reversed in TF 1.0: the list of tensors now comes first and the axis second, which matches the order NumPy uses for numpy.concatenate.
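
If you have many old-style calls to port, a small helper can make the swap explicit. This is only a sketch: normalize_concat_args is a hypothetical name of mine, not part of TensorFlow, and it assumes the axis is always a plain Python int.

```python
def normalize_concat_args(first, second):
    """Return (values, axis) for either calling convention.

    Old style: concat(axis, values). New style: concat(values, axis).
    Assumes the axis is a plain int and the values are a list or tuple.
    """
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(first, int) and not isinstance(first, bool):
        return second, first
    return first, second


# Old-style arguments (axis first) are swapped into the new order.
values, axis = normalize_concat_args(0, ["contexts_a", "contexts_b"])
print(values, axis)

# New-style arguments pass through unchanged.
values, axis = normalize_concat_args(["contexts_a", "contexts_b"], 0)
print(values, axis)
```

In the real code the result would be passed on as tf.concat(values, axis); strings are used above only so the snippet runs without TensorFlow.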


-      split_probs = tf.split(0, 10, probs)
-      shaped_probs = tf.concat(1, split_probs)
+      # fixed parameter order
+      split_probs = tf.split(probs, 10, 0)
+      shaped_probs = tf.concat(split_probs, 1)

Ditto here: tf.split now takes the tensor first, then the number of splits, then the axis.
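
To see what the split-then-concat pair does to the shape, here is a plain-Python sketch with nested lists instead of tensors. It uses 3 candidates per example rather than 10 to keep it short, and split_rows and concat_columns are hypothetical helpers that only mimic splitting along axis 0 and concatenating along axis 1.

```python
def split_rows(rows, num_splits):
    """Split a list of rows into num_splits equal blocks (like a split on axis 0)."""
    if len(rows) % num_splits != 0:
        raise ValueError("number of rows must divide evenly")
    size = len(rows) // num_splits
    return [rows[i * size:(i + 1) * size] for i in range(num_splits)]


def concat_columns(blocks):
    """Join blocks side by side, row by row (like a concat on axis 1)."""
    num_rows = len(blocks[0])
    return [[value for block in blocks for value in block[r]]
            for r in range(num_rows)]


# Two examples, three candidates each, stacked block by block:
# first the scores for the correct responses, then two blocks of distractors.
probs = [[0.9], [0.8], [0.1], [0.2], [0.3], [0.4]]

split_probs = split_rows(probs, 3)
shaped_probs = concat_columns(split_probs)

# Each row now holds all candidate scores for one example,
# with the correct response in column 0.
print(shaped_probs)
```

This is also why the summaries further down treat split_probs[0] as the scores for the correct responses.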


-      tf.histogram_summary("eval_correct_probs_hist", split_probs[0])
-      tf.scalar_summary("eval_correct_probs_average", tf.reduce_mean(split_probs[0]))
-      tf.histogram_summary("eval_incorrect_probs_hist", split_probs[1])
-      tf.scalar_summary("eval_incorrect_probs_average", tf.reduce_mean(split_probs[1]))
+      tf.summary.histogram("eval_correct_probs_hist", split_probs[0])
+      tf.summary.scalar("eval_correct_probs_average", tf.reduce_mean(split_probs[0]))
+      tf.summary.histogram("eval_incorrect_probs_hist", split_probs[1])
+      tf.summary.scalar("eval_incorrect_probs_average", tf.reduce_mean(split_probs[1]))

Here the functions have simply been renamed. What used to be tf.something_summary is now tf.summary.something, so the summary functions live together in the “summary” module (which lives in “tf”). So they organized it more.
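
Since this is a mechanical rename, it can be scripted. The sketch below covers only the two functions that appear in this file; modernize_summary_calls is a hypothetical helper of mine, and plain string replacement is a deliberate simplification that would miss aliased imports.

```python
# Old call prefix -> new call prefix, limited to the two renames shown above.
SUMMARY_RENAMES = {
    "tf.histogram_summary(": "tf.summary.histogram(",
    "tf.scalar_summary(": "tf.summary.scalar(",
}


def modernize_summary_calls(source):
    """Rewrite old-style summary calls in a string of source code."""
    for old, new in SUMMARY_RENAMES.items():
        source = source.replace(old, new)
    return source


old_line = 'tf.scalar_summary("eval_correct_probs_average", avg)'
print(modernize_summary_calls(old_line))
```

Running something like this over a file is quicker than fixing each call by hand, though the result still needs a read-through.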

udc_predict.py


-  estimator._targets_info = tf.contrib.learn.estimators.tensor_signature.TensorSignature(tf.constant(0, shape=[1,1]))
+  # old line: estimator._targets_info = tf.contrib.learn.estimators.tensor_signature.TensorSignature(tf.constant(0, shape=[1,1]))
+  estimator._targets_info = tf.contrib.learn.Estimator(model_fn=model_fn, model_dir=FLAGS.model_dir)

Same idea as above here: the old call no longer worked under TF 1.0+, so I swapped it out.


-    print("{}: {:g}".format(r, prob[0,0])) 
+    # print("{}: {:g}".format(r, prob[0,0]))
+    for k in list(prob):
+      print(r + ", " + str(k[0]))

If I remember correctly, the print statement wasn’t working because “prob” is a generator, and a generator cannot be indexed the way prob[0,0] tries to. One way to see what a generator contains is to apply the list() function to it. By the way, this is the part of the code where you’ll want to start customizing for your own test cases: edit “POTENTIAL_RESPONSES” so that it contains responses that will tell you whether your model is accurate. To get your own data into the same form as this example so you can try it out, see this post where I have done just that.
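
To illustrate the generator point, here is a small sketch without TensorFlow. fake_predict is a hypothetical stand-in for the prediction call, and the scores are made up.

```python
def fake_predict(scores):
    """Hypothetical stand-in for a predict call that yields one row per input."""
    for score in scores:
        yield [score]


prob = fake_predict([0.91, 0.07])

# Indexing a generator fails, which is what broke the original print statement.
try:
    prob[0]
    indexing_failed = False
except TypeError:
    indexing_failed = True

# list() pulls every row out of the generator so it can be looped over or indexed.
rows = list(prob)
for row in rows:
    print("response score: " + str(row[0]))

# A generator can only be consumed once; a second pass yields nothing.
leftover = list(prob)
```

Keep the result of list() in a variable if you need the values more than once.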

Those are the major tweaks. I adjusted the batch size to see if I could speed up training, since it was taking hours, full nights even, before I could tell whether my adjustments had made any improvement (Google now rents TPUs for this). Long story short, my dataset was not as good as I had thought it would be (it needed more data), and my results were little better than chance. I am working on creating a much larger dataset that could produce some interesting results, so do check back later on!

Also, if you don’t want to generate the dataset yourself, you can simply download it from Google Drive here. The GitHub repo I linked to earlier lets you modify the parameters the data is created with, but the instructions aren’t exactly concise and clear.

Future direction and other examples

I believe that models that learn from contextual language text will have more potential applications in the future, and the same goes for the mountains of data needed to power them. There are offshoots of this type of modeling that are continuing to improve accuracy. Here is a paper that makes use of similar data and yields models that can reason about the language data they have been fed. These results are not specific to the English language. I was reading a paper (link missing) about researchers who trained their model on a massive amount of data from a Chinese forum and were able to select stunningly appropriate responses to topic starters (though this is far less impressive than continuing a conversation for some time).

Hopefully this post provides a solid starting point for those who are new to either machine learning or Python (or both) along with updated working code. For a deeper understanding of the machine learning itself, I refer you back to the original wildml.com blog post, this open Stanford class, this open Oxford class on Github, and Andrej Karpathy’s “The Unreasonable Effectiveness of Recurrent Neural Networks”. That’s all I have for now, thanks for reading.
