diff --git a/nmt/train.py b/nmt/train.py
index 21f11b8d6..691bcd9bb 100644
--- a/nmt/train.py
+++ b/nmt/train.py
@@ -447,8 +447,8 @@ def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
       sent_id=0,
       tgt_eos=hparams.eos,
       subword_option=hparams.subword_option)
-  utils.print_out(" src: %s" % src_data[decode_id])
-  utils.print_out(" ref: %s" % tgt_data[decode_id])
+  utils.print_out(b" src: " + utils.format_sentence(src_data[decode_id], hparams.subword_option))
+  utils.print_out(b" ref: " + utils.format_sentence(tgt_data[decode_id], hparams.subword_option))
   utils.print_out(b" nmt: " + translation)
 
   # Summary
diff --git a/nmt/utils/misc_utils.py b/nmt/utils/misc_utils.py
index a680a5cf2..5fabc7ec4 100644
--- a/nmt/utils/misc_utils.py
+++ b/nmt/utils/misc_utils.py
@@ -164,8 +164,6 @@ def format_bpe_text(symbols, delimiter=b"@@"):
   """Convert a sequence of bpe words into sentence."""
   words = []
   word = b""
-  if isinstance(symbols, str):
-    symbols = symbols.encode()
   delimiter_len = len(delimiter)
   for symbol in symbols:
     if len(symbol) >= delimiter_len and symbol[-delimiter_len:] == delimiter:
@@ -181,3 +179,32 @@ def format_spm_text(symbols):
   """Decode a text in SPM (https://github.com/google/sentencepiece) format."""
   return u"".join(format_text(symbols).decode("utf-8").split()).replace(
       u"\u2581", u" ").strip().encode("utf-8")
+
+
+def format_sentence(sentence, subword_option):
+  """Format a sentence for display according to the subword scheme.
+
+  Args:
+    sentence: either an untokenized sentence (text or byte string, split on
+      whitespace here) or an iterable of symbols that is passed through as is.
+    subword_option: "bpe" selects format_bpe_text, "spm" selects
+      format_spm_text; any other value falls back to format_text.
+
+  Returns:
+    Whatever the selected formatter returns for the tokenized sentence.
+  """
+  if isinstance(sentence, bytes):
+    # Byte string (also plain `str` on Python 2): tokenize as is. Calling
+    # encode() here would trigger an implicit ASCII decode on Python 2.
+    sentence = sentence.split()
+  elif hasattr(sentence, "encode"):
+    # Text string: encode first because the formatters work on byte symbols.
+    sentence = sentence.encode("utf-8").split()
+  # Anything else is treated as an already tokenized sequence. Note that
+  # split() without a separator never yields empty tokens, unlike split(b" ").
+
+  if subword_option == "bpe":  # BPE
+    return format_bpe_text(sentence)
+  if subword_option == "spm":  # SPM
+    return format_spm_text(sentence)
+  return format_text(sentence)
diff --git a/nmt/utils/nmt_utils.py b/nmt/utils/nmt_utils.py
index 72f71b5c2..f912f8e08 100644
--- a/nmt/utils/nmt_utils.py
+++ b/nmt/utils/nmt_utils.py
@@ -99,11 +99,4 @@ def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
   if tgt_eos and tgt_eos in output:
     output = output[:output.index(tgt_eos)]
 
-  if subword_option == "bpe":  # BPE
-    translation = utils.format_bpe_text(output)
-  elif subword_option == "spm":  # SPM
-    translation = utils.format_spm_text(output)
-  else:
-    translation = utils.format_text(output)
-
-  return translation
+  return utils.format_sentence(output, subword_option)