diff --git a/.gitignore b/.gitignore index e038e74..74268d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,64 +1,8 @@ -# Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class -# C extensions -*.so +.idea/ +.env/ -# Distribution / packaging -.Python -env/ -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -*.egg-info/ -.installed.cfg -*.egg - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*,cover -.hypothesis/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -#Ipython Notebook -.ipynb_checkpoints -/.project -/.pydevproject +data/ diff --git a/README.md b/README.md index edabccd..c094c0f 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,5 @@ a Deep Learning based Speller See https://medium.com/@majortal/deep-spelling-9ffef96a24f6#.2c9pu8nlm - -Additional details: - -I used this AMI to train the system: -https://aws.amazon.com/marketplace/pp/B06VSPXKDX -On a p2.xlarge instance (currently at $0.9 per Hour) +## Python 3.5 +updated code to run in Python 3.5 diff --git a/keras_spell.py b/keras_spell.py index 7442d9d..4e0d7ee 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -1,65 +1,44 @@ -# encoding: utf-8 -''' -Created on Nov 26, 2015 - -@author: tal - -Based in part on: -Learn math - https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py - -See https://medium.com/@majortal/deep-spelling-9ffef96a24f6#.2c9pu8nlm -''' - -from __future__ import print_function, division, unicode_literals - import os import errno from collections import Counter -from hashlib import sha256 import re import json import itertools import logging import requests import numpy as np -from numpy.random import choice as random_choice, randint as random_randint, shuffle as random_shuffle, seed as random_seed, rand -from numpy import zeros as np_zeros # pylint:disable=no-name-in-module +import random from keras.models import Sequential, load_model from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, Dropout, recurrent from keras.callbacks import Callback -# Set a logger for the module -LOGGER = logging.getLogger(__name__) # Every log will use the module name -LOGGER.addHandler(logging.StreamHandler()) -LOGGER.setLevel(logging.DEBUG) +# setup console logging +logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG) + +random.seed(123) # Reproducibility -random_seed(123) # Reproducibility class Configuration(object): """Dump stuff here""" + CONFIG = Configuration() -#pylint:disable=attribute-defined-outside-init -# Parameters for the model: CONFIG.input_layers = 2 CONFIG.output_layers = 2 CONFIG.amount_of_dropout = 0.2 CONFIG.hidden_size = 500 -CONFIG.initialization = "he_normal" # : Gaussian initialization scaled by fan-in (He et al., 2014) +CONFIG.initialization = "he_normal" # : Gaussian initialization scaled by fan-in (He et al., 2014) CONFIG.number_of_chars = 100 CONFIG.max_input_len = 60 CONFIG.inverted = True # parameters for the training: -CONFIG.batch_size = 100 # As the model changes in size, play with the batch size to best fit the process in memory -CONFIG.epochs = 500 # due to mini-epochs. 
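A note on this hunk: the patch drops the old `DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True)).hexdigest()` line, which would raise `TypeError` on Python 3 because `hashlib` requires bytes rather than `str`. If a configuration fingerprint is still wanted, a minimal sketch of the Python 3 form looks like this (the `Configuration` fields below are placeholders, not the full set from the script):

import json
from hashlib import sha256


class Configuration(object):
    """Stand-in for the CONFIG object in keras_spell.py."""


CONFIG = Configuration()
CONFIG.batch_size = 100   # placeholder fields for the sketch
CONFIG.epochs = 500

# json.dumps() returns str; hashing needs bytes on Python 3, hence .encode().
DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True).encode("utf-8")).hexdigest()
print(DIGEST)  # stable fingerprint of the hyper-parameters
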
-CONFIG.steps_per_epoch = 1000 # This is a mini-epoch. Using News 2013 an epoch would need to be ~60K. +CONFIG.batch_size = 100 # As the model changes in size, play with the batch size to best fit the process in memory +CONFIG.epochs = 500 # due to mini-epochs. +CONFIG.steps_per_epoch = 1000 # This is a mini-epoch. Using News 2013 an epoch would need to be ~60K. CONFIG.validation_steps = 10 CONFIG.number_of_iterations = 10 -#pylint:enable=attribute-defined-outside-init - -DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True)).hexdigest() # Parameters for the dataset MIN_INPUT_LEN = 5 @@ -67,8 +46,8 @@ class Configuration(object): CHARS = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .") PADDING = "☕" -DATA_FILES_PATH = "~/Downloads/data" -DATA_FILES_FULL_PATH = os.path.expanduser(DATA_FILES_PATH) +DATA_FILES_PATH = "data" +DATA_FILES_FULL_PATH = os.path.join(os.path.dirname(__file__), DATA_FILES_PATH) DATA_FILES_URL = "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.en.shuffled.gz" NEWS_FILE_NAME_COMPRESSED = os.path.join(DATA_FILES_FULL_PATH, "news.2013.en.shuffled.gz") # 1.1 GB NEWS_FILE_NAME_ENGLISH = "news.2013.en.shuffled" @@ -84,9 +63,9 @@ class Configuration(object): # Some cleanup: NORMALIZE_WHITESPACE_REGEX = re.compile(r'[^\S\n]+', re.UNICODE) # match all whitespace except newlines RE_DASH_FILTER = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\-]', re.UNICODE) -RE_APOSTROPHE_FILTER = re.compile(r''|[ʼ՚'‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(unichr(768), unichr(769), unichr(832), - unichr(833), unichr(2387), unichr(5151), - unichr(5152), unichr(65344), unichr(8242)), +RE_APOSTROPHE_FILTER = re.compile(r''|[ʼ՚'‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(chr(768), chr(769), chr(832), + chr(833), chr(2387), chr(5151), + chr(5152), chr(65344), chr(8242)), re.UNICODE) RE_LEFT_PARENTH_FILTER = re.compile(r'[\(\[\{\⁽\₍\❨\❪\﹙\(]', re.UNICODE) RE_RIGHT_PARENTH_FILTER = re.compile(r'[\)\]\}\⁾\₎\❩\❫\﹚\)]', re.UNICODE) @@ -98,68 +77,74 @@ class Configuration(object): def download_the_news_data(): """Download the news data""" - LOGGER.info("Downloading") - try: - os.makedirs(os.path.dirname(NEWS_FILE_NAME_COMPRESSED)) - except OSError as exception: - if exception.errno != errno.EEXIST: - raise - with open(NEWS_FILE_NAME_COMPRESSED, "wb") as output_file: - response = requests.get(DATA_FILES_URL, stream=True) - total_length = response.headers.get('content-length') - downloaded = percentage = 0 - print("»"*100) - total_length = int(total_length) - for data in response.iter_content(chunk_size=4096): - downloaded += len(data) - output_file.write(data) - new_percentage = 100 * downloaded // total_length - if new_percentage > percentage: - print("☑", end="") - percentage = new_percentage - print() + if not os.path.isfile(NEWS_FILE_NAME_COMPRESSED): + logging.info("download_the_news_data: Downloading " + DATA_FILES_URL) + try: + os.makedirs(os.path.dirname(NEWS_FILE_NAME_COMPRESSED)) + except OSError as exception: + if exception.errno != errno.EEXIST: + raise + with open(NEWS_FILE_NAME_COMPRESSED, "wb") as output_file: + response = requests.get(DATA_FILES_URL, stream=True) + total_length = response.headers.get('content-length') + downloaded = percentage = 0 + print("»"*100) + total_length = int(total_length) + for data in response.iter_content(chunk_size=4096): + downloaded += len(data) + output_file.write(data) + new_percentage = 100 * downloaded // total_length + if new_percentage > percentage: + print("☑", end="") + percentage = new_percentage + print() def 
uncompress_data(): """Uncompress the data files""" import gzip - with gzip.open(NEWS_FILE_NAME_COMPRESSED, 'rb') as compressed_file: - with open(NEWS_FILE_NAME_COMPRESSED[:-3], 'wb') as outfile: - outfile.write(compressed_file.read()) + out_filename = NEWS_FILE_NAME_COMPRESSED[:-3] + if not os.path.isfile(out_filename): + logging.info("uncompress_data: " + out_filename) + with gzip.open(NEWS_FILE_NAME_COMPRESSED, 'rb') as compressed_file: + with open(out_filename, 'wb') as outfile: + outfile.write(compressed_file.read()) + def add_noise_to_string(a_string, amount_of_noise): """Add some artificial spelling mistakes to the string""" - if rand() < amount_of_noise * len(a_string): + if random.random() < amount_of_noise * len(a_string): # Replace a character with a random character - random_char_position = random_randint(len(a_string)) - a_string = a_string[:random_char_position] + random_choice(CHARS[:-1]) + a_string[random_char_position + 1:] - if rand() < amount_of_noise * len(a_string): + random_char_position = random.randint(0, len(a_string)) + a_string = a_string[:random_char_position] + random.choice(CHARS[:-1]) + a_string[random_char_position + 1:] + if random.random() < amount_of_noise * len(a_string): # Delete a character - random_char_position = random_randint(len(a_string)) + random_char_position = random.randint(0, len(a_string)) a_string = a_string[:random_char_position] + a_string[random_char_position + 1:] - if len(a_string) < CONFIG.max_input_len and rand() < amount_of_noise * len(a_string): + if len(a_string) < CONFIG.max_input_len and random.random() < amount_of_noise * len(a_string): # Add a random character - random_char_position = random_randint(len(a_string)) - a_string = a_string[:random_char_position] + random_choice(CHARS[:-1]) + a_string[random_char_position:] - if rand() < amount_of_noise * len(a_string): + random_char_position = random.randint(0, len(a_string)) + a_string = a_string[:random_char_position] + random.choice(CHARS[:-1]) + a_string[random_char_position:] + if random.random() < amount_of_noise * len(a_string): # Transpose 2 characters - random_char_position = random_randint(len(a_string) - 1) + random_char_position = random.randint(0, len(a_string) - 2) a_string = (a_string[:random_char_position] + a_string[random_char_position + 1] + a_string[random_char_position] + a_string[random_char_position + 2:]) return a_string + def _vectorize(questions, answers, ctable): """Vectorize the data as numpy arrays""" len_of_questions = len(questions) - X = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) - for i in xrange(len(questions)): + X = np.zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) + for i in range(len(questions)): sentence = questions.pop() for j, c in enumerate(sentence): try: X[i, j, ctable.char_indices[c]] = 1 except KeyError: pass # Padding - y = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) - for i in xrange(len(answers)): + y = np.zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) + for i in range(len(answers)): sentence = answers.pop() for j, c in enumerate(sentence): try: @@ -168,6 +153,7 @@ def _vectorize(questions, answers, ctable): pass # Padding return X, y + def slice_X(X, start=None, stop=None): """This takes an array-like, or a list of array-likes, and outputs: @@ -196,6 +182,7 @@ def slice_X(X, start=None, stop=None): else: return X[start:stop] + def vectorize(questions, answers, chars=None): """Vectorize the questions and 
expected answers""" print('Vectorization...') @@ -266,7 +253,7 @@ def size(self): def encode(self, C, maxlen): """Encode as one-hot""" - X = np_zeros((maxlen, len(self.chars)), dtype=np.bool) # pylint:disable=no-member + X = np.zeros((maxlen, len(self.chars)), dtype=np.bool) # pylint:disable=no-member for i, c in enumerate(C): X[i, self.char_indices[c]] = 1 return X @@ -277,6 +264,7 @@ def decode(self, X, calc_argmax=True): X = X.argmax(axis=-1) return ''.join(self.indices_char[x] for x in X if x) + def generator(file_name): """Returns a tuple (inputs, targets) All arrays should contain the same number of samples. @@ -288,9 +276,9 @@ def generator(file_name): while True: with open(file_name) as answers: for answer in answers: - batch_of_answers.append(answer.strip().decode('utf-8')) + batch_of_answers.append(answer.strip()) if len(batch_of_answers) == CONFIG.batch_size: - random_shuffle(batch_of_answers) + np.random.shuffle(batch_of_answers) batch_of_questions = [] for answer_index, answer in enumerate(batch_of_answers): question, answer = generate_question(answer) @@ -302,23 +290,27 @@ def generator(file_name): yield X, y batch_of_answers = [] + def print_random_predictions(model, ctable, X_val, y_val): """Select 10 samples from the validation set at random so we can visualize errors""" print() - for _ in range(10): - ind = random_randint(0, len(X_val)) - rowX, rowy = X_val[np.array([ind])], y_val[np.array([ind])] # pylint:disable=no-member - preds = model.predict_classes(rowX, verbose=0) - q = ctable.decode(rowX[0]) - correct = ctable.decode(rowy[0]) - guess = ctable.decode(preds[0], calc_argmax=False) - if CONFIG.inverted: - print('Q', q[::-1]) # inverted back! - else: - print('Q', q) - print('A', correct) - print(Colors.green + '☑' + Colors.close if correct == guess else Colors.red + '☒' + Colors.close, guess) - print('---') + try: + for _ in range(10): + ind = random.randint(0, len(X_val)) + rowX, rowy = X_val[np.asarray([ind])], y_val[np.asarray([ind])] # pylint:disable=no-member + preds = model.predict_classes(rowX, verbose=0) + q = ctable.decode(rowX[0]) + correct = ctable.decode(rowy[0]) + guess = ctable.decode(preds[0], calc_argmax=False) + if CONFIG.inverted: + print('Q', q[::-1]) # inverted back! + else: + print('Q', q) + print('A', correct) + print(Colors.green + '☑' + Colors.close if correct == guess else Colors.red + '☒' + Colors.close, guess) + print('---') + except Exception: + pass print() @@ -332,9 +324,11 @@ def on_epoch_end(self, epoch, logs=None): print_random_predictions(self.model, ctable, X_val, y_val) self.model.save(SAVED_MODEL_FILE_NAME.format(epoch)) + ON_EPOCH_END_CALLBACK = OnEpochEndCallback() -def itarative_train(model): + +def iterative_train(model): """ Iterative training of the model - To allow for finite RAM... 
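A porting detail worth flagging in `add_noise_to_string` above: `numpy.random.randint(n)` draws from `[0, n)`, whereas `random.randint(0, n)` includes `n` itself, so the replace/delete/insert positions can now land one index past where the original code allowed (a small behaviour change at the end of the string, not a crash). `random.randrange(n)` is the exact drop-in for the old numpy call. A standalone illustration, not part of the patch:

import random

import numpy as np

s = "hello"

np_pos = np.random.randint(len(s))   # numpy: upper bound excluded -> 0..4
py_pos = random.randint(0, len(s))   # stdlib: upper bound included -> 0..5
safe_pos = random.randrange(len(s))  # stdlib equivalent of the numpy call -> 0..4

assert 0 <= np_pos < len(s)
assert 0 <= safe_pos < len(s)
assert 0 <= py_pos <= len(s)
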
@@ -359,6 +353,7 @@ def iterate_training(model, X_train, y_train, X_val, y_val, ctable): validation_data=(X_val, y_val)) print_random_predictions(model, ctable, X_val, y_val) + def clean_text(text): """Clean the text - remove unwanted chars, fold punctuation etc.""" result = NORMALIZE_WHITESPACE_REGEX.sub(' ', text.strip()) @@ -369,31 +364,33 @@ def clean_text(text): result = RE_BASIC_CLEANER.sub('', result) return result + def preprocesses_data_clean(): """Pre-process the data - step 1 - cleanup""" - with open(NEWS_FILE_NAME_CLEAN, "wb") as clean_data: - for line in open(NEWS_FILE_NAME): - decoded_line = line.decode('utf-8') - cleaned_line = clean_text(decoded_line) - encoded_line = cleaned_line.encode("utf-8") - clean_data.write(encoded_line + b"\n") + if not os.path.isfile(NEWS_FILE_NAME_CLEAN): + logging.info("preprocesses_data_clean") + with open(NEWS_FILE_NAME_CLEAN, "wt") as clean_data: + for line in open(NEWS_FILE_NAME): + clean_data.write(clean_text(line) + "\n") + def preprocesses_data_analyze_chars(): """Pre-process the data - step 2 - analyze the characters""" - counter = Counter() - LOGGER.info("Reading data:") - for line in open(NEWS_FILE_NAME_CLEAN): - decoded_line = line.decode('utf-8') - counter.update(decoded_line) -# data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8') -# LOGGER.info("Read.\nCounting characters:") -# counter = Counter(data.replace("\n", "")) - LOGGER.info("Done.\nWriting to file:") - with open(CHAR_FREQUENCY_FILE_NAME, 'wb') as output_file: - output_file.write(json.dumps(counter)) - most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} - LOGGER.info("The top %s chars are:", CONFIG.number_of_chars) - LOGGER.info("".join(sorted(most_popular_chars))) + if not os.path.isfile(CHAR_FREQUENCY_FILE_NAME): + counter = Counter() + logging.info("preprocesses_data_analyze_chars: Reading data:") + for line in open(NEWS_FILE_NAME_CLEAN): + counter.update(line) + # data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8') + # logging.info("Read.\nCounting characters:") + # counter = Counter(data.replace("\n", "")) + logging.info("Done.\nWriting to file:") + with open(CHAR_FREQUENCY_FILE_NAME, 'wt') as output_file: + output_file.write(json.dumps(counter)) + most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} + logging.info("The top %s chars are:", CONFIG.number_of_chars) + logging.info("".join(sorted(most_popular_chars))) + def read_top_chars(): """Read the top chars we saved to file""" @@ -402,24 +399,28 @@ def read_top_chars(): most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} return most_popular_chars + def preprocesses_data_filter(): """Pre-process the data - step 3 - filter only sentences with the right chars""" - most_popular_chars = read_top_chars() - LOGGER.info("Reading and filtering data:") - with open(NEWS_FILE_NAME_FILTERED, "wb") as output_file: - for line in open(NEWS_FILE_NAME_CLEAN): - decoded_line = line.decode('utf-8') - if decoded_line and not bool(set(decoded_line) - most_popular_chars): - output_file.write(line) - LOGGER.info("Done.") + if not os.path.isfile(NEWS_FILE_NAME_FILTERED): + most_popular_chars = read_top_chars() + logging.info("preprocesses_data_filter: Reading and filtering data:") + with open(NEWS_FILE_NAME_FILTERED, "wt") as output_file: + for line in open(NEWS_FILE_NAME_CLEAN): + decoded_line = line + if decoded_line and not bool(set(decoded_line) - most_popular_chars): + output_file.write(line) + logging.info("Done.") 
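For context on the character-frequency steps that follow: `preprocesses_data_analyze_chars` writes the `Counter` out with `json.dumps`, and `read_top_chars` rebuilds it before calling `most_common`. That round-trip works because `Counter` is a `dict` subclass. A small, file-free sketch (the sample lines are made up):

import json
from collections import Counter

counter = Counter()
for line in ["the cat sat\n", "on the mat\n"]:  # stands in for NEWS_FILE_NAME_CLEAN
    counter.update(line)                        # counts individual characters

serialised = json.dumps(counter)             # Counter serialises as a plain JSON object
restored = Counter(json.loads(serialised))   # rebuild it so most_common() still works

top_chars = {key for key, _count in restored.most_common(5)}
print(sorted(top_chars))
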
+ def read_filtered_data(): """Read the filtered data corpus""" - LOGGER.info("Reading filtered data:") - lines = open(NEWS_FILE_NAME_FILTERED).read().decode('utf-8').split("\n") - LOGGER.info("Read filtered data - %s lines", len(lines)) + logging.info("Reading filtered data:") + lines = open(NEWS_FILE_NAME_FILTERED).read().split("\n") + logging.info("Read filtered data - %s lines", len(lines)) return lines + def preprocesses_split_lines(): """Preprocess the text by splitting the lines between min-length and max_length I don't like this step: @@ -429,97 +430,105 @@ def preprocesses_split_lines(): Important NGRAMs are cut (though given enough data, that might be moot). I do this to enable batch-learning by padding to a fixed length. """ - LOGGER.info("Reading filtered data:") - answers = set() - with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: - for _line in open(NEWS_FILE_NAME_FILTERED): - line = _line.decode('utf-8') - while len(line) > MIN_INPUT_LEN: - if len(line) <= CONFIG.max_input_len: - answer = line - line = "" - else: - space_location = line.rfind(" ", MIN_INPUT_LEN, CONFIG.max_input_len - 1) - if space_location > -1: - answer = line[:space_location] - line = line[len(answer) + 1:] + if not os.path.isfile(NEWS_FILE_NAME_SPLIT): + logging.info("preprocesses_split_lines: Reading filtered data:") + answers = set() + with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: + for _line in open(NEWS_FILE_NAME_FILTERED): + line = _line + while len(line) > MIN_INPUT_LEN: + if len(line) <= CONFIG.max_input_len: + answer = line + line = "" else: - space_location = line.rfind(" ") # no limits this time - if space_location == -1: - break # we are done with this line + space_location = line.rfind(" ", MIN_INPUT_LEN, CONFIG.max_input_len - 1) + if space_location > -1: + answer = line[:space_location] + line = line[len(answer) + 1:] else: - line = line[space_location + 1:] - continue - answers.add(answer) - output_file.write(answer.encode('utf-8') + b"\n") + space_location = line.rfind(" ") # no limits this time + if space_location == -1: + break # we are done with this line + else: + line = line[space_location + 1:] + continue + answers.add(answer) + output_file.write(answer + "\n") + def preprocesses_split_lines2(): """Preprocess the text by splitting the lines between min-length and max_length Alternative split. """ - LOGGER.info("Reading filtered data:") - answers = set() - for encoded_line in open(NEWS_FILE_NAME_FILTERED): - line = encoded_line.decode('utf-8') - if CONFIG.max_input_len >= len(line) > MIN_INPUT_LEN: - answers.add(line) - LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) - LOGGER.info("Here are some examples:") - for answer in itertools.islice(answers, 10): - LOGGER.info(answer) - with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: - output_file.write("".join(answers).encode('utf-8')) + if not os.path.isfile(NEWS_FILE_NAME_SPLIT): + logging.info("Reading filtered data:") + answers = set() + for encoded_line in open(NEWS_FILE_NAME_FILTERED): + line = encoded_line + if CONFIG.max_input_len >= len(line) > MIN_INPUT_LEN: + answers.add(line) + logging.info("There are %s 'answers' (sub-sentences)", len(answers)) + logging.info("Here are some examples:") + for answer in itertools.islice(answers, 10): + logging.info(answer) + with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: + output_file.write("".join(answers)) + def preprocesses_split_lines3(): """Preprocess the text by selecting only max n-grams Alternative split. 
""" - LOGGER.info("Reading filtered data:") - answers = set() - for encoded_line in open(NEWS_FILE_NAME_FILTERED): - line = encoded_line.decode('utf-8') - if line.count(" ") < 5: - answers.add(line) - LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) - LOGGER.info("Here are some examples:") - for answer in itertools.islice(answers, 10): - LOGGER.info(answer) - with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: - output_file.write("".join(answers).encode('utf-8')) - -def preprocesses_split_lines4(): - """Preprocess the text by selecting only sentences with most-common words AND not too long - Alternative split. - """ - LOGGER.info("Reading filtered data:") - from gensim.models.word2vec import Word2Vec - FILTERED_W2V = "fw2v.bin" - model = Word2Vec.load_word2vec_format(FILTERED_W2V, binary=True) # C text format - print(len(model.wv.index2word)) -# answers = set() -# for encoded_line in open(NEWS_FILE_NAME_FILTERED): -# line = encoded_line.decode('utf-8') -# if line.count(" ") < 5: -# answers.add(line) -# LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) -# LOGGER.info("Here are some examples:") -# for answer in itertools.islice(answers, 10): -# LOGGER.info(answer) -# with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: -# output_file.write("".join(answers).encode('utf-8')) + if not os.path.isfile(NEWS_FILE_NAME_SPLIT): + logging.info("Reading filtered data:") + answers = set() + for encoded_line in open(NEWS_FILE_NAME_FILTERED): + line = encoded_line + if line.count(" ") < 5: + answers.add(line) + logging.info("There are %s 'answers' (sub-sentences)", len(answers)) + logging.info("Here are some examples:") + for answer in itertools.islice(answers, 10): + logging.info(answer) + with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: + output_file.write("".join(answers)) + + +# def preprocesses_split_lines4(): +# """Preprocess the text by selecting only sentences with most-common words AND not too long +# Alternative split. 
+# """ +# logging.info("Reading filtered data:") +# from gensim.models.word2vec import Word2Vec +# FILTERED_W2V = "fw2v.bin" +# model = Word2Vec.load_word2vec_format(FILTERED_W2V, binary=True) # C text format +# print(len(model.wv.index2word)) +# # answers = set() +# # for encoded_line in open(NEWS_FILE_NAME_FILTERED): +# # line = encoded_line.decode('utf-8') +# # if line.count(" ") < 5: +# # answers.add(line) +# # logging.info("There are %s 'answers' (sub-sentences)", len(answers)) +# # logging.info("Here are some examples:") +# # for answer in itertools.islice(answers, 10): +# # logging.info(answer) +# # with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: +# # output_file.write("".join(answers)) def preprocess_partition_data(): """Set asside data for validation""" - answers = open(NEWS_FILE_NAME_SPLIT).read().decode('utf-8').split("\n") - print('shuffle', end=" ") - random_shuffle(answers) - print("Done") - # Explicitly set apart 10% for validation data that we never train over - split_at = len(answers) - len(answers) // 10 - with open(NEWS_FILE_NAME_TRAIN, "wb") as output_file: - output_file.write("\n".join(answers[:split_at]).encode('utf-8')) - with open(NEWS_FILE_NAME_VALIDATE, "wb") as output_file: - output_file.write("\n".join(answers[split_at:]).encode('utf-8')) + if not os.path.isfile(NEWS_FILE_NAME_TRAIN): + logging.info("preprocess_partition_data") + answers = open(NEWS_FILE_NAME_SPLIT).read().split("\n") + print('shuffle', end=" ") + np.random.shuffle(answers) + print("Done") + # Explicitly set apart 10% for validation data that we never train over + split_at = len(answers) - len(answers) // 10 + with open(NEWS_FILE_NAME_TRAIN, "wt") as output_file: + output_file.write("\n".join(answers[:split_at])) + with open(NEWS_FILE_NAME_VALIDATE, "wt") as output_file: + output_file.write("\n".join(answers[split_at:])) def generate_question(answer): @@ -530,19 +539,20 @@ def generate_question(answer): answer += PADDING * (CONFIG.max_input_len - len(answer)) return question, answer + def generate_news_data(): """Generate some news data""" print ("Generating Data") - answers = open(NEWS_FILE_NAME_SPLIT).read().decode('utf-8').split("\n") + answers = open(NEWS_FILE_NAME_SPLIT).read().split("\n") questions = [] print('shuffle', end=" ") - random_shuffle(answers) + np.random.shuffle(answers) print("Done") for answer_index, answer in enumerate(answers): question, answer = generate_question(answer) answers[answer_index] = answer assert len(answer) == CONFIG.max_input_len - if random_randint(100000) == 8: # Show some progress + if random.randint(0, 100000) == 8: # Show some progress print (len(answers)) print ("answer: '{}'".format(answer)) print ("question: '{}'".format(question)) @@ -552,6 +562,7 @@ def generate_news_data(): return questions, answers + def train_speller_w_all_data(): """Train the speller if all data fits into RAM""" questions, answers = generate_news_data() @@ -563,23 +574,43 @@ def train_speller_w_all_data(): model = generate_model(y_maxlen, chars) iterate_training(model, X_train, y_train, X_val, y_val, ctable) + def train_speller(from_file=None): """Train the speller""" if from_file: + logging.info("reloading existing model before training " + from_file) model = load_model(from_file) else: + logging.info("training a new model") model = generate_model(CONFIG.max_input_len, chars=read_top_chars()) - itarative_train(model) + iterative_train(model) + if __name__ == '__main__': -# download_the_news_data() -# uncompress_data() -# preprocesses_data_clean() -# 
preprocesses_data_analyze_chars() -# preprocesses_data_filter() -# preprocesses_split_lines() --- Choose this step or: + download_the_news_data() # download data/news.2013.en.shuffled.gz if dne + uncompress_data() + preprocesses_data_clean() + preprocesses_data_analyze_chars() + preprocesses_data_filter() + preprocesses_split_lines() # preprocesses_split_lines2() # preprocesses_split_lines4() -# preprocess_partition_data() -# train_speller(os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5")) - train_speller() + preprocess_partition_data() + + + # get the latest epoch of any previous runs and use it as a starting model if it exists + previous_models = [os.path.join(DATA_FILES_FULL_PATH, f) for f in os.listdir(DATA_FILES_FULL_PATH) + if os.path.isfile(os.path.join(DATA_FILES_FULL_PATH, f)) and f.startswith("keras_spell_e")] + last_model = None + latest_model_number = -1 + for model in previous_models: + epoch = int(model.split("_")[-1].split(".")[0][1:]) + if epoch > latest_model_number: + latest_model_number = epoch + last_model = model + + # run with previous model or anew? + if last_model is not None and os.path.isfile(last_model): + train_speller(last_model) + else: + train_speller() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ff497b7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +numpy +keras +tensorflow +requests +h5py
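The new `__main__` block above resumes training by scanning DATA_FILES_FULL_PATH for files whose names start with `keras_spell_e` and picking the highest epoch number. A self-contained sketch of that resume logic (the helper name and the regex are mine; the `keras_spell_e<epoch>.h5` pattern follows the `startswith` check and the saved-model naming used in the script):

import os
import re


def latest_checkpoint(directory, prefix="keras_spell_e", suffix=".h5"):
    """Return the path of the checkpoint with the highest epoch number, or None."""
    best_path, best_epoch = None, -1
    for name in os.listdir(directory):
        match = re.fullmatch(re.escape(prefix) + r"(\d+)" + re.escape(suffix), name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(directory, name)
    return best_path


# Usage sketch:
#   model_path = latest_checkpoint(DATA_FILES_FULL_PATH)
#   train_speller(model_path) if model_path else train_speller()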