Memory NN Example

"""Example training a memory neural net on the bAbI dataset.

Uses Keras and is based on https://keras.io/examples/babi_memnn/.
"""

from __future__ import print_function

from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Input, Activation, Dense, Permute, Dropout
from tensorflow.keras.layers import add, dot, concatenate
from tensorflow.keras.layers import LSTM
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.utils import get_file
from tensorflow.keras.preprocessing.sequence import pad_sequences

from filelock import FileLock
import os
import argparse
import tarfile
import numpy as np
import re

from ray import train, tune


def tokenize(sent):
    """Return the tokens of a sentence including punctuation.

    >>> tokenize("Bob dropped the apple. Where is the apple?")
    ["Bob", "dropped", "the", "apple", ".", "Where", "is", "the", "apple", "?"]
    """
    return [x.strip() for x in re.split(r"(\W+)", sent) if x and x.strip()]


def parse_stories(lines, only_supporting=False):
    """Parse stories provided in the bAbi tasks format

    If only_supporting is true, only the sentences
    that support the answer are kept.
    """
    data = []
    story = []
    for line in lines:
        line = line.decode("utf-8").strip()
        nid, line = line.split(" ", 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if "\t" in line:
            q, a, supporting = line.split("\t")
            q = tokenize(q)
            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]
            data.append((substory, q, a))
            story.append("")
        else:
            sent = tokenize(line)
            story.append(sent)
    return data
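

# Each line in a bAbI task file starts with a statement id; question lines
# additionally carry the answer and the supporting-fact ids, separated by
# tabs. An illustrative snippet (not the actual downloaded file):
#
#   1 Mary moved to the bathroom.
#   2 John went to the hallway.
#   3 Where is Mary?\tbathroom\t1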


def get_stories(f, only_supporting=False, max_length=None):
    """Given a file name, read the file,
    retrieve the stories,
    and then convert the sentences into a single story.

    If max_length is supplied,
    any stories longer than max_length tokens will be discarded.
    """

    def flatten(data):
        return sum(data, [])

    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    data = [
        (flatten(story), q, answer)
        for story, q, answer in data
        if not max_length or len(flatten(story)) < max_length
    ]
    return data
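

# A single entry returned by ``get_stories`` looks roughly like this
# (illustrative, matching the parsing above):
#
#   (['Mary', 'moved', 'to', 'the', 'bathroom', '.',
#     'John', 'went', 'to', 'the', 'hallway', '.'],
#    ['Where', 'is', 'Mary', '?'],
#    'bathroom')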


def vectorize_stories(word_idx, story_maxlen, query_maxlen, data):
    inputs, queries, answers = [], [], []
    for story, query, answer in data:
        inputs.append([word_idx[w] for w in story])
        queries.append([word_idx[w] for w in query])
        answers.append(word_idx[answer])
    return (
        pad_sequences(inputs, maxlen=story_maxlen),
        pad_sequences(queries, maxlen=query_maxlen),
        np.array(answers),
    )
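

# ``pad_sequences`` left-pads each list of word indices with zeros up to
# ``maxlen`` (its default ``padding="pre"``), which is why index 0 is reserved
# for masking and real words are numbered from 1. For example, [4, 7, 2]
# padded to maxlen=5 becomes [0, 0, 4, 7, 2].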


def read_data(finish_fast=False):
    # Get the file
    try:
        path = get_file(
            "babi-tasks-v1-2.tar.gz",
            origin="https://s3.amazonaws.com/text-datasets/"
            "babi_tasks_1-20_v1-2.tar.gz",
        )
    except Exception:
        print(
            "Error downloading dataset, please download it manually:\n"
            "$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2"  # noqa: E501
            ".tar.gz\n"
            "$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz"  # noqa: E501
        )
        raise

    # Choose challenge
    challenges = {
        # QA1 with 10,000 samples
        "single_supporting_fact_10k": "tasks_1-20_v1-2/en-10k/qa1_"
        "single-supporting-fact_{}.txt",
        # QA2 with 10,000 samples
        "two_supporting_facts_10k": "tasks_1-20_v1-2/en-10k/qa2_"
        "two-supporting-facts_{}.txt",
    }
    challenge_type = "single_supporting_fact_10k"
    challenge = challenges[challenge_type]

    with tarfile.open(path) as tar:
        train_stories = get_stories(tar.extractfile(challenge.format("train")))
        test_stories = get_stories(tar.extractfile(challenge.format("test")))
    if finish_fast:
        train_stories = train_stories[:64]
        test_stories = test_stories[:64]
    return train_stories, test_stories


class MemNNModel(tune.Trainable):
    def build_model(self):
        """Helper method for creating the model"""
        vocab = set()
        for story, q, answer in self.train_stories + self.test_stories:
            vocab |= set(story + q + [answer])
        vocab = sorted(vocab)

        # Reserve 0 for masking via pad_sequences
        vocab_size = len(vocab) + 1
        story_maxlen = max(len(x) for x, _, _ in self.train_stories + self.test_stories)
        query_maxlen = max(len(x) for _, x, _ in self.train_stories + self.test_stories)

        word_idx = {c: i + 1 for i, c in enumerate(vocab)}
        self.inputs_train, self.queries_train, self.answers_train = vectorize_stories(
            word_idx, story_maxlen, query_maxlen, self.train_stories
        )
        self.inputs_test, self.queries_test, self.answers_test = vectorize_stories(
            word_idx, story_maxlen, query_maxlen, self.test_stories
        )
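        # After vectorization, ``inputs_*`` and ``queries_*`` are integer
        # matrices of shape (num_samples, story_maxlen) and
        # (num_samples, query_maxlen); ``answers_*`` is a 1-D array of word
        # indices.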

        # placeholders
        input_sequence = Input((story_maxlen,))
        question = Input((query_maxlen,))

        # encoders
        # embed the input sequence into a sequence of vectors
        input_encoder_m = Sequential()
        input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64))
        input_encoder_m.add(Dropout(self.config.get("dropout", 0.3)))
        # output: (samples, story_maxlen, embedding_dim)

        # embed the input into a sequence of vectors of size query_maxlen
        input_encoder_c = Sequential()
        input_encoder_c.add(Embedding(input_dim=vocab_size, output_dim=query_maxlen))
        input_encoder_c.add(Dropout(self.config.get("dropout", 0.3)))
        # output: (samples, story_maxlen, query_maxlen)

        # embed the question into a sequence of vectors
        question_encoder = Sequential()
        question_encoder.add(
            Embedding(input_dim=vocab_size, output_dim=64, input_length=query_maxlen)
        )
        question_encoder.add(Dropout(self.config.get("dropout", 0.3)))
        # output: (samples, query_maxlen, embedding_dim)

        # encode input sequence and questions (which are indices)
        # to sequences of dense vectors
        input_encoded_m = input_encoder_m(input_sequence)
        input_encoded_c = input_encoder_c(input_sequence)
        question_encoded = question_encoder(question)

        # compute a "match" between the first input vector sequence
        # and the question vector sequence
        # shape: `(samples, story_maxlen, query_maxlen)`
        match = dot([input_encoded_m, question_encoded], axes=(2, 2))
        match = Activation("softmax")(match)

        # add the match matrix with the second input vector sequence
        response = add(
            [match, input_encoded_c]
        )  # (samples, story_maxlen, query_maxlen)
        response = Permute((2, 1))(response)  # (samples, query_maxlen, story_maxlen)

        # concatenate the match matrix with the question vector sequence
        answer = concatenate([response, question_encoded])

        # the original paper uses a matrix multiplication.
        # we choose to use an RNN instead.
        answer = LSTM(32)(answer)  # (samples, 32)

        # one regularization layer -- more would probably be needed.
        answer = Dropout(self.config.get("dropout", 0.3))(answer)
        answer = Dense(vocab_size)(answer)  # (samples, vocab_size)
        # we output a probability distribution over the vocabulary
        answer = Activation("softmax")(answer)

        # build the final model
        model = Model([input_sequence, question], answer)
        return model

    def setup(self, config):
        with FileLock(os.path.expanduser("~/.tune.lock")):
            self.train_stories, self.test_stories = read_data(config["finish_fast"])
        model = self.build_model()
        rmsprop = RMSprop(
            learning_rate=self.config.get("lr", 1e-3),
            rho=self.config.get("rho", 0.9),
        )
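        # ``sparse_categorical_crossentropy`` is used because the targets are
        # integer word indices rather than one-hot vectors over the vocabulary.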
        model.compile(
            optimizer=rmsprop,
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        self.model = model

    def step(self):
        # train
        self.model.fit(
            [self.inputs_train, self.queries_train],
            self.answers_train,
            batch_size=self.config.get("batch_size", 32),
            epochs=self.config.get("epochs", 1),
            validation_data=([self.inputs_test, self.queries_test], self.answers_test),
            verbose=0,
        )
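        # Evaluate on the training inputs; the resulting ``mean_accuracy`` is
        # what Tune records and what the PBT scheduler uses to compare trials.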
        _, accuracy = self.model.evaluate(
            [self.inputs_train, self.queries_train], self.answers_train, verbose=0
        )
        return {"mean_accuracy": accuracy}

    def save_checkpoint(self, checkpoint_dir):
        file_path = os.path.join(checkpoint_dir, "model")
        self.model.save(file_path)

    def load_checkpoint(self, checkpoint_dir):
        # See https://stackoverflow.com/a/42763323
        del self.model
        file_path = os.path.join(checkpoint_dir, "model")
        self.model = load_model(file_path)


if __name__ == "__main__":
    import ray
    from ray.tune.schedulers import PopulationBasedTraining

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        ray.init(num_cpus=2)

    perturbation_interval = 2
    pbt = PopulationBasedTraining(
        perturbation_interval=perturbation_interval,
        hyperparam_mutations={
            "dropout": lambda: np.random.uniform(0, 1),
            "lr": lambda: 10 ** np.random.randint(-10, 0),
            "rho": lambda: np.random.uniform(0, 1),
        },
    )
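
    # Every ``perturbation_interval`` iterations, PBT copies the checkpoint of
    # a better-performing trial into an underperforming one and then perturbs
    # or resamples the hyperparameters listed in ``hyperparam_mutations``.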

    tuner = tune.Tuner(
        MemNNModel,
        run_config=train.RunConfig(
            name="pbt_babi_memnn",
            stop={"training_iteration": 4 if args.smoke_test else 100},
            checkpoint_config=train.CheckpointConfig(
                checkpoint_frequency=perturbation_interval,
                checkpoint_score_attribute="mean_accuracy",
                num_to_keep=2,
            ),
        ),
        tune_config=tune.TuneConfig(
            scheduler=pbt,
            metric="mean_accuracy",
            mode="max",
            num_samples=2,
        ),
        param_space={
            "finish_fast": args.smoke_test,
            "batch_size": 32,
            "epochs": 1,
            "dropout": 0.3,
            "lr": 0.01,
            "rho": 0.9,
        },
    )
    tuner.fit()
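
    # ``Tuner.fit`` returns a ``ResultGrid``. A minimal sketch for inspecting
    # the best trial afterwards (assumes the return value is captured instead
    # of being discarded as above):
    #
    #   results = tuner.fit()
    #   best = results.get_best_result(metric="mean_accuracy", mode="max")
    #   print(best.config)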