From 56fe0bb24a06fb4ff6370dbd7e82c8315ce87795 Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 13:59:18 +1200 Subject: [PATCH 01/10] fixes for python 3.5 --- .gitignore | 2 + README.md | 8 +- keras_spell.py | 362 +++++++++++++++++++++++------------------------ requirements.txt | 4 + 4 files changed, 186 insertions(+), 190 deletions(-) create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore index e038e74..55bebb4 100644 --- a/.gitignore +++ b/.gitignore @@ -62,3 +62,5 @@ target/ .ipynb_checkpoints /.project /.pydevproject + +data diff --git a/README.md b/README.md index edabccd..c094c0f 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,5 @@ a Deep Learning based Speller See https://medium.com/@majortal/deep-spelling-9ffef96a24f6#.2c9pu8nlm - -Additional details: - -I used this AMI to train the system: -https://aws.amazon.com/marketplace/pp/B06VSPXKDX -On a p2.xlarge instance (currently at $0.9 per Hour) +## Python 3.5 +updated code to run in Python 3.5 diff --git a/keras_spell.py b/keras_spell.py index 7442d9d..ed08e10 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -1,29 +1,13 @@ -# encoding: utf-8 -''' -Created on Nov 26, 2015 - -@author: tal - -Based in part on: -Learn math - https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py - -See https://medium.com/@majortal/deep-spelling-9ffef96a24f6#.2c9pu8nlm -''' - -from __future__ import print_function, division, unicode_literals - import os import errno from collections import Counter -from hashlib import sha256 import re import json import itertools import logging import requests import numpy as np -from numpy.random import choice as random_choice, randint as random_randint, shuffle as random_shuffle, seed as random_seed, rand -from numpy import zeros as np_zeros # pylint:disable=no-name-in-module +import random from keras.models import Sequential, load_model from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, Dropout, recurrent @@ -34,7 +18,7 @@ LOGGER.addHandler(logging.StreamHandler()) LOGGER.setLevel(logging.DEBUG) -random_seed(123) # Reproducibility +random.seed(123) # Reproducibility class Configuration(object): """Dump stuff here""" @@ -57,9 +41,6 @@ class Configuration(object): CONFIG.steps_per_epoch = 1000 # This is a mini-epoch. Using News 2013 an epoch would need to be ~60K. 
CONFIG.validation_steps = 10 CONFIG.number_of_iterations = 10 -#pylint:enable=attribute-defined-outside-init - -DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True)).hexdigest() # Parameters for the dataset MIN_INPUT_LEN = 5 @@ -67,8 +48,8 @@ class Configuration(object): CHARS = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .") PADDING = "☕" -DATA_FILES_PATH = "~/Downloads/data" -DATA_FILES_FULL_PATH = os.path.expanduser(DATA_FILES_PATH) +DATA_FILES_PATH = "data" +DATA_FILES_FULL_PATH = os.path.join(os.path.dirname(__file__), DATA_FILES_PATH) DATA_FILES_URL = "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.en.shuffled.gz" NEWS_FILE_NAME_COMPRESSED = os.path.join(DATA_FILES_FULL_PATH, "news.2013.en.shuffled.gz") # 1.1 GB NEWS_FILE_NAME_ENGLISH = "news.2013.en.shuffled" @@ -84,9 +65,9 @@ class Configuration(object): # Some cleanup: NORMALIZE_WHITESPACE_REGEX = re.compile(r'[^\S\n]+', re.UNICODE) # match all whitespace except newlines RE_DASH_FILTER = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\-]', re.UNICODE) -RE_APOSTROPHE_FILTER = re.compile(r''|[ʼ՚'‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(unichr(768), unichr(769), unichr(832), - unichr(833), unichr(2387), unichr(5151), - unichr(5152), unichr(65344), unichr(8242)), +RE_APOSTROPHE_FILTER = re.compile(r''|[ʼ՚'‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(chr(768), chr(769), chr(832), + chr(833), chr(2387), chr(5151), + chr(5152), chr(65344), chr(8242)), re.UNICODE) RE_LEFT_PARENTH_FILTER = re.compile(r'[\(\[\{\⁽\₍\❨\❪\﹙\(]', re.UNICODE) RE_RIGHT_PARENTH_FILTER = re.compile(r'[\)\]\}\⁾\₎\❩\❫\﹚\)]', re.UNICODE) @@ -98,51 +79,54 @@ class Configuration(object): def download_the_news_data(): """Download the news data""" - LOGGER.info("Downloading") - try: - os.makedirs(os.path.dirname(NEWS_FILE_NAME_COMPRESSED)) - except OSError as exception: - if exception.errno != errno.EEXIST: - raise - with open(NEWS_FILE_NAME_COMPRESSED, "wb") as output_file: - response = requests.get(DATA_FILES_URL, stream=True) - total_length = response.headers.get('content-length') - downloaded = percentage = 0 - print("»"*100) - total_length = int(total_length) - for data in response.iter_content(chunk_size=4096): - downloaded += len(data) - output_file.write(data) - new_percentage = 100 * downloaded // total_length - if new_percentage > percentage: - print("☑", end="") - percentage = new_percentage - print() + if not os.path.isfile(NEWS_FILE_NAME_COMPRESSED): + LOGGER.info("Downloading") + try: + os.makedirs(os.path.dirname(NEWS_FILE_NAME_COMPRESSED)) + except OSError as exception: + if exception.errno != errno.EEXIST: + raise + with open(NEWS_FILE_NAME_COMPRESSED, "wb") as output_file: + response = requests.get(DATA_FILES_URL, stream=True) + total_length = response.headers.get('content-length') + downloaded = percentage = 0 + print("»"*100) + total_length = int(total_length) + for data in response.iter_content(chunk_size=4096): + downloaded += len(data) + output_file.write(data) + new_percentage = 100 * downloaded // total_length + if new_percentage > percentage: + print("☑", end="") + percentage = new_percentage + print() def uncompress_data(): """Uncompress the data files""" import gzip - with gzip.open(NEWS_FILE_NAME_COMPRESSED, 'rb') as compressed_file: - with open(NEWS_FILE_NAME_COMPRESSED[:-3], 'wb') as outfile: - outfile.write(compressed_file.read()) + out_filename = NEWS_FILE_NAME_COMPRESSED[:-3] + if not os.path.isfile(out_filename): + with gzip.open(NEWS_FILE_NAME_COMPRESSED, 'rb') as compressed_file: + with open(out_filename, 
'wb') as outfile: + outfile.write(compressed_file.read()) def add_noise_to_string(a_string, amount_of_noise): """Add some artificial spelling mistakes to the string""" - if rand() < amount_of_noise * len(a_string): + if random() < amount_of_noise * len(a_string): # Replace a character with a random character - random_char_position = random_randint(len(a_string)) - a_string = a_string[:random_char_position] + random_choice(CHARS[:-1]) + a_string[random_char_position + 1:] - if rand() < amount_of_noise * len(a_string): + random_char_position = random.randint(len(a_string)) + a_string = a_string[:random_char_position] + random.choice(CHARS[:-1]) + a_string[random_char_position + 1:] + if random() < amount_of_noise * len(a_string): # Delete a character - random_char_position = random_randint(len(a_string)) + random_char_position = random.randint(len(a_string)) a_string = a_string[:random_char_position] + a_string[random_char_position + 1:] - if len(a_string) < CONFIG.max_input_len and rand() < amount_of_noise * len(a_string): + if len(a_string) < CONFIG.max_input_len and random() < amount_of_noise * len(a_string): # Add a random character - random_char_position = random_randint(len(a_string)) - a_string = a_string[:random_char_position] + random_choice(CHARS[:-1]) + a_string[random_char_position:] - if rand() < amount_of_noise * len(a_string): + random_char_position = random.randint(len(a_string)) + a_string = a_string[:random_char_position] + random.choice(CHARS[:-1]) + a_string[random_char_position:] + if random() < amount_of_noise * len(a_string): # Transpose 2 characters - random_char_position = random_randint(len(a_string) - 1) + random_char_position = random.randint(len(a_string) - 1) a_string = (a_string[:random_char_position] + a_string[random_char_position + 1] + a_string[random_char_position] + a_string[random_char_position + 2:]) return a_string @@ -150,16 +134,16 @@ def add_noise_to_string(a_string, amount_of_noise): def _vectorize(questions, answers, ctable): """Vectorize the data as numpy arrays""" len_of_questions = len(questions) - X = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) - for i in xrange(len(questions)): + X = np.zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) + for i in range(len(questions)): sentence = questions.pop() for j, c in enumerate(sentence): try: X[i, j, ctable.char_indices[c]] = 1 except KeyError: pass # Padding - y = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) - for i in xrange(len(answers)): + y = np.zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) + for i in range(len(answers)): sentence = answers.pop() for j, c in enumerate(sentence): try: @@ -266,7 +250,7 @@ def size(self): def encode(self, C, maxlen): """Encode as one-hot""" - X = np_zeros((maxlen, len(self.chars)), dtype=np.bool) # pylint:disable=no-member + X = np.zeros((maxlen, len(self.chars)), dtype=np.bool) # pylint:disable=no-member for i, c in enumerate(C): X[i, self.char_indices[c]] = 1 return X @@ -288,9 +272,9 @@ def generator(file_name): while True: with open(file_name) as answers: for answer in answers: - batch_of_answers.append(answer.strip().decode('utf-8')) + batch_of_answers.append(answer.strip()) if len(batch_of_answers) == CONFIG.batch_size: - random_shuffle(batch_of_answers) + np.random.shuffle(batch_of_answers) batch_of_questions = [] for answer_index, answer in enumerate(batch_of_answers): question, answer = generate_question(answer) @@ -306,8 +290,8 @@ def 
print_random_predictions(model, ctable, X_val, y_val): """Select 10 samples from the validation set at random so we can visualize errors""" print() for _ in range(10): - ind = random_randint(0, len(X_val)) - rowX, rowy = X_val[np.array([ind])], y_val[np.array([ind])] # pylint:disable=no-member + ind = random.randint(0, len(X_val)) + rowX, rowy = X_val[np.asarray([ind])], y_val[np.asarray([ind])] # pylint:disable=no-member preds = model.predict_classes(rowX, verbose=0) q = ctable.decode(rowX[0]) correct = ctable.decode(rowy[0]) @@ -369,31 +353,31 @@ def clean_text(text): result = RE_BASIC_CLEANER.sub('', result) return result + def preprocesses_data_clean(): """Pre-process the data - step 1 - cleanup""" - with open(NEWS_FILE_NAME_CLEAN, "wb") as clean_data: - for line in open(NEWS_FILE_NAME): - decoded_line = line.decode('utf-8') - cleaned_line = clean_text(decoded_line) - encoded_line = cleaned_line.encode("utf-8") - clean_data.write(encoded_line + b"\n") + if not os.path.isfile(NEWS_FILE_NAME_CLEAN): + with open(NEWS_FILE_NAME_CLEAN, "wt") as clean_data: + for line in open(NEWS_FILE_NAME): + clean_data.write(clean_text(line) + "\n") + def preprocesses_data_analyze_chars(): """Pre-process the data - step 2 - analyze the characters""" - counter = Counter() - LOGGER.info("Reading data:") - for line in open(NEWS_FILE_NAME_CLEAN): - decoded_line = line.decode('utf-8') - counter.update(decoded_line) -# data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8') -# LOGGER.info("Read.\nCounting characters:") -# counter = Counter(data.replace("\n", "")) - LOGGER.info("Done.\nWriting to file:") - with open(CHAR_FREQUENCY_FILE_NAME, 'wb') as output_file: - output_file.write(json.dumps(counter)) - most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} - LOGGER.info("The top %s chars are:", CONFIG.number_of_chars) - LOGGER.info("".join(sorted(most_popular_chars))) + if not os.path.isfile(CHAR_FREQUENCY_FILE_NAME): + counter = Counter() + LOGGER.info("Reading data:") + for line in open(NEWS_FILE_NAME_CLEAN): + counter.update(line) + # data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8') + # LOGGER.info("Read.\nCounting characters:") + # counter = Counter(data.replace("\n", "")) + LOGGER.info("Done.\nWriting to file:") + with open(CHAR_FREQUENCY_FILE_NAME, 'wt') as output_file: + output_file.write(json.dumps(counter)) + most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} + LOGGER.info("The top %s chars are:", CONFIG.number_of_chars) + LOGGER.info("".join(sorted(most_popular_chars))) def read_top_chars(): """Read the top chars we saved to file""" @@ -404,19 +388,20 @@ def read_top_chars(): def preprocesses_data_filter(): """Pre-process the data - step 3 - filter only sentences with the right chars""" - most_popular_chars = read_top_chars() - LOGGER.info("Reading and filtering data:") - with open(NEWS_FILE_NAME_FILTERED, "wb") as output_file: - for line in open(NEWS_FILE_NAME_CLEAN): - decoded_line = line.decode('utf-8') - if decoded_line and not bool(set(decoded_line) - most_popular_chars): - output_file.write(line) - LOGGER.info("Done.") + if not os.path.isfile(NEWS_FILE_NAME_FILTERED): + most_popular_chars = read_top_chars() + LOGGER.info("Reading and filtering data:") + with open(NEWS_FILE_NAME_FILTERED, "wt") as output_file: + for line in open(NEWS_FILE_NAME_CLEAN): + decoded_line = line + if decoded_line and not bool(set(decoded_line) - most_popular_chars): + output_file.write(line) + LOGGER.info("Done.") def 
read_filtered_data(): """Read the filtered data corpus""" LOGGER.info("Reading filtered data:") - lines = open(NEWS_FILE_NAME_FILTERED).read().decode('utf-8').split("\n") + lines = open(NEWS_FILE_NAME_FILTERED).read().split("\n") LOGGER.info("Read filtered data - %s lines", len(lines)) return lines @@ -429,97 +414,103 @@ def preprocesses_split_lines(): Important NGRAMs are cut (though given enough data, that might be moot). I do this to enable batch-learning by padding to a fixed length. """ - LOGGER.info("Reading filtered data:") - answers = set() - with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: - for _line in open(NEWS_FILE_NAME_FILTERED): - line = _line.decode('utf-8') - while len(line) > MIN_INPUT_LEN: - if len(line) <= CONFIG.max_input_len: - answer = line - line = "" - else: - space_location = line.rfind(" ", MIN_INPUT_LEN, CONFIG.max_input_len - 1) - if space_location > -1: - answer = line[:space_location] - line = line[len(answer) + 1:] + if not os.path.isfile(NEWS_FILE_NAME_SPLIT): + LOGGER.info("Reading filtered data:") + answers = set() + with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: + for _line in open(NEWS_FILE_NAME_FILTERED): + line = _line + while len(line) > MIN_INPUT_LEN: + if len(line) <= CONFIG.max_input_len: + answer = line + line = "" else: - space_location = line.rfind(" ") # no limits this time - if space_location == -1: - break # we are done with this line + space_location = line.rfind(" ", MIN_INPUT_LEN, CONFIG.max_input_len - 1) + if space_location > -1: + answer = line[:space_location] + line = line[len(answer) + 1:] else: - line = line[space_location + 1:] - continue - answers.add(answer) - output_file.write(answer.encode('utf-8') + b"\n") + space_location = line.rfind(" ") # no limits this time + if space_location == -1: + break # we are done with this line + else: + line = line[space_location + 1:] + continue + answers.add(answer) + output_file.write(answer + "\n") def preprocesses_split_lines2(): """Preprocess the text by splitting the lines between min-length and max_length Alternative split. """ - LOGGER.info("Reading filtered data:") - answers = set() - for encoded_line in open(NEWS_FILE_NAME_FILTERED): - line = encoded_line.decode('utf-8') - if CONFIG.max_input_len >= len(line) > MIN_INPUT_LEN: - answers.add(line) - LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) - LOGGER.info("Here are some examples:") - for answer in itertools.islice(answers, 10): - LOGGER.info(answer) - with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: - output_file.write("".join(answers).encode('utf-8')) + if not os.path.isfile(NEWS_FILE_NAME_SPLIT): + LOGGER.info("Reading filtered data:") + answers = set() + for encoded_line in open(NEWS_FILE_NAME_FILTERED): + line = encoded_line + if CONFIG.max_input_len >= len(line) > MIN_INPUT_LEN: + answers.add(line) + LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) + LOGGER.info("Here are some examples:") + for answer in itertools.islice(answers, 10): + LOGGER.info(answer) + with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: + output_file.write("".join(answers)) + def preprocesses_split_lines3(): """Preprocess the text by selecting only max n-grams Alternative split. 
""" - LOGGER.info("Reading filtered data:") - answers = set() - for encoded_line in open(NEWS_FILE_NAME_FILTERED): - line = encoded_line.decode('utf-8') - if line.count(" ") < 5: - answers.add(line) - LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) - LOGGER.info("Here are some examples:") - for answer in itertools.islice(answers, 10): - LOGGER.info(answer) - with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: - output_file.write("".join(answers).encode('utf-8')) - -def preprocesses_split_lines4(): - """Preprocess the text by selecting only sentences with most-common words AND not too long - Alternative split. - """ - LOGGER.info("Reading filtered data:") - from gensim.models.word2vec import Word2Vec - FILTERED_W2V = "fw2v.bin" - model = Word2Vec.load_word2vec_format(FILTERED_W2V, binary=True) # C text format - print(len(model.wv.index2word)) -# answers = set() -# for encoded_line in open(NEWS_FILE_NAME_FILTERED): -# line = encoded_line.decode('utf-8') -# if line.count(" ") < 5: -# answers.add(line) -# LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) -# LOGGER.info("Here are some examples:") -# for answer in itertools.islice(answers, 10): -# LOGGER.info(answer) -# with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: -# output_file.write("".join(answers).encode('utf-8')) + if not os.path.isfile(NEWS_FILE_NAME_SPLIT): + LOGGER.info("Reading filtered data:") + answers = set() + for encoded_line in open(NEWS_FILE_NAME_FILTERED): + line = encoded_line + if line.count(" ") < 5: + answers.add(line) + LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) + LOGGER.info("Here are some examples:") + for answer in itertools.islice(answers, 10): + LOGGER.info(answer) + with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: + output_file.write("".join(answers)) + + +# def preprocesses_split_lines4(): +# """Preprocess the text by selecting only sentences with most-common words AND not too long +# Alternative split. 
+# """ +# LOGGER.info("Reading filtered data:") +# from gensim.models.word2vec import Word2Vec +# FILTERED_W2V = "fw2v.bin" +# model = Word2Vec.load_word2vec_format(FILTERED_W2V, binary=True) # C text format +# print(len(model.wv.index2word)) +# # answers = set() +# # for encoded_line in open(NEWS_FILE_NAME_FILTERED): +# # line = encoded_line.decode('utf-8') +# # if line.count(" ") < 5: +# # answers.add(line) +# # LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) +# # LOGGER.info("Here are some examples:") +# # for answer in itertools.islice(answers, 10): +# # LOGGER.info(answer) +# # with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: +# # output_file.write("".join(answers)) def preprocess_partition_data(): """Set asside data for validation""" - answers = open(NEWS_FILE_NAME_SPLIT).read().decode('utf-8').split("\n") - print('shuffle', end=" ") - random_shuffle(answers) - print("Done") - # Explicitly set apart 10% for validation data that we never train over - split_at = len(answers) - len(answers) // 10 - with open(NEWS_FILE_NAME_TRAIN, "wb") as output_file: - output_file.write("\n".join(answers[:split_at]).encode('utf-8')) - with open(NEWS_FILE_NAME_VALIDATE, "wb") as output_file: - output_file.write("\n".join(answers[split_at:]).encode('utf-8')) + if not os.path.isfile(NEWS_FILE_NAME_TRAIN): + answers = open(NEWS_FILE_NAME_SPLIT).read().split("\n") + print('shuffle', end=" ") + np.random.shuffle(answers) + print("Done") + # Explicitly set apart 10% for validation data that we never train over + split_at = len(answers) - len(answers) // 10 + with open(NEWS_FILE_NAME_TRAIN, "wt") as output_file: + output_file.write("\n".join(answers[:split_at])) + with open(NEWS_FILE_NAME_VALIDATE, "wt") as output_file: + output_file.write("\n".join(answers[split_at:])) def generate_question(answer): @@ -536,13 +527,13 @@ def generate_news_data(): answers = open(NEWS_FILE_NAME_SPLIT).read().decode('utf-8').split("\n") questions = [] print('shuffle', end=" ") - random_shuffle(answers) + np.random.shuffle(answers) print("Done") for answer_index, answer in enumerate(answers): question, answer = generate_question(answer) answers[answer_index] = answer assert len(answer) == CONFIG.max_input_len - if random_randint(100000) == 8: # Show some progress + if random.randint(100000) == 8: # Show some progress print (len(answers)) print ("answer: '{}'".format(answer)) print ("question: '{}'".format(question)) @@ -572,14 +563,17 @@ def train_speller(from_file=None): itarative_train(model) if __name__ == '__main__': -# download_the_news_data() -# uncompress_data() -# preprocesses_data_clean() -# preprocesses_data_analyze_chars() -# preprocesses_data_filter() -# preprocesses_split_lines() --- Choose this step or: + download_the_news_data() + uncompress_data() + preprocesses_data_clean() + preprocesses_data_analyze_chars() + preprocesses_data_filter() + preprocesses_split_lines() # preprocesses_split_lines2() # preprocesses_split_lines4() -# preprocess_partition_data() -# train_speller(os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5")) - train_speller() + preprocess_partition_data() + existing_model = os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5") + if os.path.isfile(existing_model): + train_speller(existing_model) + else: + train_speller() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..32eb854 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +numpy +keras +tensorflow +requests From d651643197575982427f97231e0e731c2e78b3fb Mon Sep 17 00:00:00 
2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 14:10:06 +1200 Subject: [PATCH 02/10] added logging fixes --- keras_spell.py | 66 +++++++++++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/keras_spell.py b/keras_spell.py index ed08e10..35e4d7d 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -13,10 +13,8 @@ from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, Dropout, recurrent from keras.callbacks import Callback -# Set a logger for the module -LOGGER = logging.getLogger(__name__) # Every log will use the module name -LOGGER.addHandler(logging.StreamHandler()) -LOGGER.setLevel(logging.DEBUG) +# setup console logging +logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG) random.seed(123) # Reproducibility @@ -80,7 +78,7 @@ class Configuration(object): def download_the_news_data(): """Download the news data""" if not os.path.isfile(NEWS_FILE_NAME_COMPRESSED): - LOGGER.info("Downloading") + logging.info("download_the_news_data: Downloading " + DATA_FILES_URL) try: os.makedirs(os.path.dirname(NEWS_FILE_NAME_COMPRESSED)) except OSError as exception: @@ -106,10 +104,12 @@ def uncompress_data(): import gzip out_filename = NEWS_FILE_NAME_COMPRESSED[:-3] if not os.path.isfile(out_filename): + logging.info("uncompress_data: " + out_filename) with gzip.open(NEWS_FILE_NAME_COMPRESSED, 'rb') as compressed_file: with open(out_filename, 'wb') as outfile: outfile.write(compressed_file.read()) + def add_noise_to_string(a_string, amount_of_noise): """Add some artificial spelling mistakes to the string""" if random() < amount_of_noise * len(a_string): @@ -318,6 +318,7 @@ def on_epoch_end(self, epoch, logs=None): ON_EPOCH_END_CALLBACK = OnEpochEndCallback() + def itarative_train(model): """ Iterative training of the model @@ -366,18 +367,19 @@ def preprocesses_data_analyze_chars(): """Pre-process the data - step 2 - analyze the characters""" if not os.path.isfile(CHAR_FREQUENCY_FILE_NAME): counter = Counter() - LOGGER.info("Reading data:") + logging.info("preprocesses_data_analyze_chars: Reading data:") for line in open(NEWS_FILE_NAME_CLEAN): counter.update(line) # data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8') - # LOGGER.info("Read.\nCounting characters:") + # logging.info("Read.\nCounting characters:") # counter = Counter(data.replace("\n", "")) - LOGGER.info("Done.\nWriting to file:") + logging.info("Done.\nWriting to file:") with open(CHAR_FREQUENCY_FILE_NAME, 'wt') as output_file: output_file.write(json.dumps(counter)) most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} - LOGGER.info("The top %s chars are:", CONFIG.number_of_chars) - LOGGER.info("".join(sorted(most_popular_chars))) + logging.info("The top %s chars are:", CONFIG.number_of_chars) + logging.info("".join(sorted(most_popular_chars))) + def read_top_chars(): """Read the top chars we saved to file""" @@ -386,25 +388,28 @@ def read_top_chars(): most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} return most_popular_chars + def preprocesses_data_filter(): """Pre-process the data - step 3 - filter only sentences with the right chars""" if not os.path.isfile(NEWS_FILE_NAME_FILTERED): most_popular_chars = read_top_chars() - LOGGER.info("Reading and filtering data:") + logging.info("preprocesses_data_filter: Reading and filtering data:") with open(NEWS_FILE_NAME_FILTERED, "wt") as output_file: for line in open(NEWS_FILE_NAME_CLEAN): decoded_line = line if 
decoded_line and not bool(set(decoded_line) - most_popular_chars): output_file.write(line) - LOGGER.info("Done.") + logging.info("Done.") + def read_filtered_data(): """Read the filtered data corpus""" - LOGGER.info("Reading filtered data:") + logging.info("Reading filtered data:") lines = open(NEWS_FILE_NAME_FILTERED).read().split("\n") - LOGGER.info("Read filtered data - %s lines", len(lines)) + logging.info("Read filtered data - %s lines", len(lines)) return lines + def preprocesses_split_lines(): """Preprocess the text by splitting the lines between min-length and max_length I don't like this step: @@ -415,7 +420,7 @@ def preprocesses_split_lines(): I do this to enable batch-learning by padding to a fixed length. """ if not os.path.isfile(NEWS_FILE_NAME_SPLIT): - LOGGER.info("Reading filtered data:") + logging.info("preprocesses_split_lines: Reading filtered data:") answers = set() with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: for _line in open(NEWS_FILE_NAME_FILTERED): @@ -439,21 +444,22 @@ def preprocesses_split_lines(): answers.add(answer) output_file.write(answer + "\n") + def preprocesses_split_lines2(): """Preprocess the text by splitting the lines between min-length and max_length Alternative split. """ if not os.path.isfile(NEWS_FILE_NAME_SPLIT): - LOGGER.info("Reading filtered data:") + logging.info("Reading filtered data:") answers = set() for encoded_line in open(NEWS_FILE_NAME_FILTERED): line = encoded_line if CONFIG.max_input_len >= len(line) > MIN_INPUT_LEN: answers.add(line) - LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) - LOGGER.info("Here are some examples:") + logging.info("There are %s 'answers' (sub-sentences)", len(answers)) + logging.info("Here are some examples:") for answer in itertools.islice(answers, 10): - LOGGER.info(answer) + logging.info(answer) with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: output_file.write("".join(answers)) @@ -463,16 +469,16 @@ def preprocesses_split_lines3(): Alternative split. """ if not os.path.isfile(NEWS_FILE_NAME_SPLIT): - LOGGER.info("Reading filtered data:") + logging.info("Reading filtered data:") answers = set() for encoded_line in open(NEWS_FILE_NAME_FILTERED): line = encoded_line if line.count(" ") < 5: answers.add(line) - LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) - LOGGER.info("Here are some examples:") + logging.info("There are %s 'answers' (sub-sentences)", len(answers)) + logging.info("Here are some examples:") for answer in itertools.islice(answers, 10): - LOGGER.info(answer) + logging.info(answer) with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: output_file.write("".join(answers)) @@ -481,7 +487,7 @@ def preprocesses_split_lines3(): # """Preprocess the text by selecting only sentences with most-common words AND not too long # Alternative split. 
# """ -# LOGGER.info("Reading filtered data:") +# logging.info("Reading filtered data:") # from gensim.models.word2vec import Word2Vec # FILTERED_W2V = "fw2v.bin" # model = Word2Vec.load_word2vec_format(FILTERED_W2V, binary=True) # C text format @@ -491,16 +497,17 @@ def preprocesses_split_lines3(): # # line = encoded_line.decode('utf-8') # # if line.count(" ") < 5: # # answers.add(line) -# # LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) -# # LOGGER.info("Here are some examples:") +# # logging.info("There are %s 'answers' (sub-sentences)", len(answers)) +# # logging.info("Here are some examples:") # # for answer in itertools.islice(answers, 10): -# # LOGGER.info(answer) +# # logging.info(answer) # # with open(NEWS_FILE_NAME_SPLIT, "wt") as output_file: # # output_file.write("".join(answers)) def preprocess_partition_data(): """Set asside data for validation""" if not os.path.isfile(NEWS_FILE_NAME_TRAIN): + logging.info("preprocess_partition_data") answers = open(NEWS_FILE_NAME_SPLIT).read().split("\n") print('shuffle', end=" ") np.random.shuffle(answers) @@ -543,6 +550,7 @@ def generate_news_data(): return questions, answers + def train_speller_w_all_data(): """Train the speller if all data fits into RAM""" questions, answers = generate_news_data() @@ -554,14 +562,18 @@ def train_speller_w_all_data(): model = generate_model(y_maxlen, chars) iterate_training(model, X_train, y_train, X_val, y_val, ctable) + def train_speller(from_file=None): """Train the speller""" if from_file: + logging.info("reloading existing model before training " + from_file) model = load_model(from_file) else: + logging.info("training a new model") model = generate_model(CONFIG.max_input_len, chars=read_top_chars()) itarative_train(model) + if __name__ == '__main__': download_the_news_data() uncompress_data() From 0410a95596608166c8ec98f94e6d80e1631f8b34 Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 14:17:57 +1200 Subject: [PATCH 03/10] cleaned up gitignore --- .gitignore | 64 +++--------------------------------------------------- 1 file changed, 3 insertions(+), 61 deletions(-) diff --git a/.gitignore b/.gitignore index 55bebb4..74268d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,66 +1,8 @@ -# Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class -# C extensions -*.so +.idea/ +.env/ -# Distribution / packaging -.Python -env/ -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -*.egg-info/ -.installed.cfg -*.egg - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. 
-*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*,cover -.hypothesis/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -#Ipython Notebook -.ipynb_checkpoints -/.project -/.pydevproject - -data +data/ From 9dd66d80f1e444d5f5e005a0074d95d8d51ed8da Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 14:52:08 +1200 Subject: [PATCH 04/10] fixed syntax errors --- keras_spell.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/keras_spell.py b/keras_spell.py index 35e4d7d..89efa1c 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -112,21 +112,21 @@ def uncompress_data(): def add_noise_to_string(a_string, amount_of_noise): """Add some artificial spelling mistakes to the string""" - if random() < amount_of_noise * len(a_string): + if random.random() < amount_of_noise * len(a_string): # Replace a character with a random character - random_char_position = random.randint(len(a_string)) + random_char_position = random.randint(0, len(a_string)) a_string = a_string[:random_char_position] + random.choice(CHARS[:-1]) + a_string[random_char_position + 1:] - if random() < amount_of_noise * len(a_string): + if random.random() < amount_of_noise * len(a_string): # Delete a character - random_char_position = random.randint(len(a_string)) + random_char_position = random.randint(0, len(a_string)) a_string = a_string[:random_char_position] + a_string[random_char_position + 1:] - if len(a_string) < CONFIG.max_input_len and random() < amount_of_noise * len(a_string): + if len(a_string) < CONFIG.max_input_len and random.random() < amount_of_noise * len(a_string): # Add a random character - random_char_position = random.randint(len(a_string)) + random_char_position = random.randint(0, len(a_string)) a_string = a_string[:random_char_position] + random.choice(CHARS[:-1]) + a_string[random_char_position:] - if random() < amount_of_noise * len(a_string): + if random.random() < amount_of_noise * len(a_string): # Transpose 2 characters - random_char_position = random.randint(len(a_string) - 1) + random_char_position = random.randint(0, len(a_string) - 1) a_string = (a_string[:random_char_position] + a_string[random_char_position + 1] + a_string[random_char_position] + a_string[random_char_position + 2:]) return a_string @@ -358,6 +358,7 @@ def clean_text(text): def preprocesses_data_clean(): """Pre-process the data - step 1 - cleanup""" if not os.path.isfile(NEWS_FILE_NAME_CLEAN): + logging.info("preprocesses_data_clean") with open(NEWS_FILE_NAME_CLEAN, "wt") as clean_data: for line in open(NEWS_FILE_NAME): clean_data.write(clean_text(line) + "\n") @@ -528,10 +529,11 @@ def generate_question(answer): answer += PADDING * (CONFIG.max_input_len - len(answer)) return question, answer + def generate_news_data(): """Generate some news data""" print ("Generating Data") - answers = open(NEWS_FILE_NAME_SPLIT).read().decode('utf-8').split("\n") + answers = open(NEWS_FILE_NAME_SPLIT).read().split("\n") questions = [] print('shuffle', end=" ") np.random.shuffle(answers) @@ -540,7 +542,7 @@ def generate_news_data(): question, answer = generate_question(answer) answers[answer_index] = answer assert len(answer) == CONFIG.max_input_len - if random.randint(100000) == 8: # Show some progress + if random.randint(0, 100000) == 8: # Show some progress print 
(len(answers)) print ("answer: '{}'".format(answer)) print ("question: '{}'".format(question)) @@ -575,7 +577,7 @@ def train_speller(from_file=None): if __name__ == '__main__': - download_the_news_data() + download_the_news_data() # download data/news.2013.en.shuffled.gz if dne uncompress_data() preprocesses_data_clean() preprocesses_data_analyze_chars() From 542a1d74ccfb488b0821858aed18a2ac3947a1ba Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 14:54:41 +1200 Subject: [PATCH 05/10] fixed bug --- keras_spell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_spell.py b/keras_spell.py index 89efa1c..4f9e930 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -126,7 +126,7 @@ def add_noise_to_string(a_string, amount_of_noise): a_string = a_string[:random_char_position] + random.choice(CHARS[:-1]) + a_string[random_char_position:] if random.random() < amount_of_noise * len(a_string): # Transpose 2 characters - random_char_position = random.randint(0, len(a_string) - 1) + random_char_position = random.randint(0, len(a_string) - 2) a_string = (a_string[:random_char_position] + a_string[random_char_position + 1] + a_string[random_char_position] + a_string[random_char_position + 2:]) return a_string From e80da921fb5ab3ecd0515762c9a6c01006d48d97 Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 14:57:48 +1200 Subject: [PATCH 06/10] fixed typos, spacing and seperations --- keras_spell.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/keras_spell.py b/keras_spell.py index 4f9e930..bf34217 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -18,25 +18,25 @@ random.seed(123) # Reproducibility + class Configuration(object): """Dump stuff here""" + CONFIG = Configuration() -#pylint:disable=attribute-defined-outside-init -# Parameters for the model: CONFIG.input_layers = 2 CONFIG.output_layers = 2 CONFIG.amount_of_dropout = 0.2 CONFIG.hidden_size = 500 -CONFIG.initialization = "he_normal" # : Gaussian initialization scaled by fan-in (He et al., 2014) +CONFIG.initialization = "he_normal" # : Gaussian initialization scaled by fan-in (He et al., 2014) CONFIG.number_of_chars = 100 CONFIG.max_input_len = 60 CONFIG.inverted = True # parameters for the training: -CONFIG.batch_size = 100 # As the model changes in size, play with the batch size to best fit the process in memory -CONFIG.epochs = 500 # due to mini-epochs. -CONFIG.steps_per_epoch = 1000 # This is a mini-epoch. Using News 2013 an epoch would need to be ~60K. +CONFIG.batch_size = 100 # As the model changes in size, play with the batch size to best fit the process in memory +CONFIG.epochs = 500 # due to mini-epochs. +CONFIG.steps_per_epoch = 1000 # This is a mini-epoch. Using News 2013 an epoch would need to be ~60K. 
CONFIG.validation_steps = 10 CONFIG.number_of_iterations = 10 @@ -131,6 +131,7 @@ def add_noise_to_string(a_string, amount_of_noise): a_string[random_char_position + 2:]) return a_string + def _vectorize(questions, answers, ctable): """Vectorize the data as numpy arrays""" len_of_questions = len(questions) @@ -152,6 +153,7 @@ def _vectorize(questions, answers, ctable): pass # Padding return X, y + def slice_X(X, start=None, stop=None): """This takes an array-like, or a list of array-likes, and outputs: @@ -180,6 +182,7 @@ def slice_X(X, start=None, stop=None): else: return X[start:stop] + def vectorize(questions, answers, chars=None): """Vectorize the questions and expected answers""" print('Vectorization...') @@ -261,6 +264,7 @@ def decode(self, X, calc_argmax=True): X = X.argmax(axis=-1) return ''.join(self.indices_char[x] for x in X if x) + def generator(file_name): """Returns a tuple (inputs, targets) All arrays should contain the same number of samples. @@ -286,6 +290,7 @@ def generator(file_name): yield X, y batch_of_answers = [] + def print_random_predictions(model, ctable, X_val, y_val): """Select 10 samples from the validation set at random so we can visualize errors""" print() @@ -316,10 +321,11 @@ def on_epoch_end(self, epoch, logs=None): print_random_predictions(self.model, ctable, X_val, y_val) self.model.save(SAVED_MODEL_FILE_NAME.format(epoch)) + ON_EPOCH_END_CALLBACK = OnEpochEndCallback() -def itarative_train(model): +def iterative_train(model): """ Iterative training of the model - To allow for finite RAM... @@ -344,6 +350,7 @@ def iterate_training(model, X_train, y_train, X_val, y_val, ctable): validation_data=(X_val, y_val)) print_random_predictions(model, ctable, X_val, y_val) + def clean_text(text): """Clean the text - remove unwanted chars, fold punctuation etc.""" result = NORMALIZE_WHITESPACE_REGEX.sub(' ', text.strip()) @@ -573,7 +580,7 @@ def train_speller(from_file=None): else: logging.info("training a new model") model = generate_model(CONFIG.max_input_len, chars=read_top_chars()) - itarative_train(model) + iterative_train(model) if __name__ == '__main__': From 4e718e1cfdf3fe01ccc405c08482ba1413288140 Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 16:24:48 +1200 Subject: [PATCH 07/10] added missing h5py --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 32eb854..ff497b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy keras tensorflow requests +h5py From ec24d6baafc662d3a0b145cdd62593556a984665 Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 23:34:11 +1200 Subject: [PATCH 08/10] added exception handler for bad print_random_predictions code --- keras_spell.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/keras_spell.py b/keras_spell.py index bf34217..2c6b1df 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -294,20 +294,23 @@ def generator(file_name): def print_random_predictions(model, ctable, X_val, y_val): """Select 10 samples from the validation set at random so we can visualize errors""" print() - for _ in range(10): - ind = random.randint(0, len(X_val)) - rowX, rowy = X_val[np.asarray([ind])], y_val[np.asarray([ind])] # pylint:disable=no-member - preds = model.predict_classes(rowX, verbose=0) - q = ctable.decode(rowX[0]) - correct = ctable.decode(rowy[0]) - guess = ctable.decode(preds[0], calc_argmax=False) - if CONFIG.inverted: - print('Q', q[::-1]) # inverted back! 
- else: - print('Q', q) - print('A', correct) - print(Colors.green + '☑' + Colors.close if correct == guess else Colors.red + '☒' + Colors.close, guess) - print('---') + try: + for _ in range(10): + ind = random.randint(0, len(X_val)) + rowX, rowy = X_val[np.asarray([ind])], y_val[np.asarray([ind])] # pylint:disable=no-member + preds = model.predict_classes(rowX, verbose=0) + q = ctable.decode(rowX[0]) + correct = ctable.decode(rowy[0]) + guess = ctable.decode(preds[0], calc_argmax=False) + if CONFIG.inverted: + print('Q', q[::-1]) # inverted back! + else: + print('Q', q) + print('A', correct) + print(Colors.green + '☑' + Colors.close if correct == guess else Colors.red + '☒' + Colors.close, guess) + print('---') + except Exception: + pass print() From 51fdb2aab2499cd3383f572bede00540ec6b1d12 Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 23:41:58 +1200 Subject: [PATCH 09/10] added epoch loader --- keras_spell.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/keras_spell.py b/keras_spell.py index 2c6b1df..bb1147d 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -596,8 +596,21 @@ def train_speller(from_file=None): # preprocesses_split_lines2() # preprocesses_split_lines4() preprocess_partition_data() - existing_model = os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5") - if os.path.isfile(existing_model): - train_speller(existing_model) + + + # get the latest epoch of any previous runs and use it as a starting model if it exists + previous_models = [os.path.join(DATA_FILES_FULL_PATH, f) for f in os.listdir(DATA_FILES_FULL_PATH) + if os.path.isfile(os.path.join(DATA_FILES_FULL_PATH, f)) and f.startswith("keras_spell_e")] + last_model = None + latest_model_number = -1 + for model in previous_models: + epoch = int(model.split("_")[-1].split(".")[0][1:]) + if epoch > latest_model_number: + latest_model = epoch + last_model = model + + # run with previous model or anew? + if last_model is not None and os.path.isfile(last_model): + train_speller(last_model) else: train_speller() From ad072ef7fc94298d6bc75d2d92d4d94c199feef5 Mon Sep 17 00:00:00 2001 From: Rock de Vocht Date: Sun, 8 Apr 2018 23:48:06 +1200 Subject: [PATCH 10/10] fixed bug --- keras_spell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_spell.py b/keras_spell.py index bb1147d..4e0d7ee 100644 --- a/keras_spell.py +++ b/keras_spell.py @@ -606,7 +606,7 @@ def train_speller(from_file=None): for model in previous_models: epoch = int(model.split("_")[-1].split(".")[0][1:]) if epoch > latest_model_number: - latest_model = epoch + latest_model_number = epoch last_model = model # run with previous model or anew?
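
Patches 04 and 05 exist because the stdlib random module that replaced numpy.random in patch 01 has different bounds: numpy.random.randint(n) draws from 0..n-1 (upper bound excluded), while random.randint(a, b) includes both endpoints. A minimal sketch of that difference, using a hypothetical replace_random_char helper and a trimmed character set rather than the module's own add_noise_to_string and CHARS:

import random
import numpy as np

CHARS = list("abcdefghijklmnopqrstuvwxyz .")  # illustrative subset, not the module's CHARS

def replace_random_char(s):
    """Swap one character of s for a random one (illustrative helper only)."""
    pos_np = np.random.randint(len(s))        # numpy: 0 <= pos_np < len(s), upper bound excluded
    pos_std = random.randint(0, len(s) - 1)   # stdlib: both bounds inclusive, so pass len(s) - 1
    assert 0 <= pos_np < len(s) and 0 <= pos_std < len(s)
    return s[:pos_std] + random.choice(CHARS) + s[pos_std + 1:]

print(replace_random_char("deep spelling"))

The same inclusive-versus-exclusive detail is why patch 05 shrinks the transpose position to len(a_string) - 2: with inclusive bounds, the chosen index must leave room for the character after it.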
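
Patches 09 and 10 resume training from the highest-numbered keras_spell_e<N>.h5 checkpoint by splitting the filename on underscores and dots. A sketch of the same idea written with a glob and a regex; the data directory and filename pattern come from the patches, while the find_latest_checkpoint helper itself is not part of the repository:

import os
import re
from glob import glob

def find_latest_checkpoint(data_dir):
    """Return the keras_spell_e<N>.h5 path with the largest epoch N, or None."""
    best_epoch, best_path = -1, None
    for path in glob(os.path.join(data_dir, "keras_spell_e*.h5")):
        match = re.search(r"keras_spell_e(\d+)\.h5$", os.path.basename(path))
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path

# resume if a checkpoint exists, otherwise train from scratch:
# latest = find_latest_checkpoint(DATA_FILES_FULL_PATH)
# train_speller(latest) if latest else train_speller()

Matching the full pattern, rather than only a filename prefix, also keeps int() from failing on stray files that merely start with keras_spell_e.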