From edc0d2c7cb6c46b1465b98753d97f22e91774dcb Mon Sep 17 00:00:00 2001 From: pltrdy Date: Tue, 6 Oct 2020 18:23:39 +0200 Subject: [PATCH] experimental exclusion --- bin/rouge_cmd.py | 12 ++++--- rouge/rouge.py | 83 +++++++++++++++++++++++++++++++++++++----------- 2 files changed, 73 insertions(+), 22 deletions(-) diff --git a/bin/rouge_cmd.py b/bin/rouge_cmd.py index 435e2b3..df29e64 100755 --- a/bin/rouge_cmd.py +++ b/bin/rouge_cmd.py @@ -19,6 +19,11 @@ def main(): help="Ignore empty hypothesis") parser.add_argument('hypothesis', type=str, help='Text of file path') parser.add_argument('reference', type=str, help='Text or file path') + parser.add_argument( + '-exclude', '-x', + type=str, + default=None, + help='Text or file path') parser.add_argument("--metrics", nargs="+", type=str.upper, choices=METRICS_CHOICES.keys(), help="Metrics to use (default=all)") @@ -36,12 +41,11 @@ def main(): if args.file: hyp, ref = args.hypothesis, args.reference - assert(os.path.isfile(hyp)) - assert(os.path.isfile(ref)) files_rouge = FilesRouge(metrics, stats) scores = files_rouge.get_scores( - hyp, ref, avg=args.avg, ignore_empty=args.ignore_empty) + hyp, ref, exclude_path=args.exclude, + avg=args.avg, ignore_empty=args.ignore_empty) print(json.dumps(scores, indent=2)) else: @@ -50,7 +54,7 @@ def main(): assert(isinstance(ref, str)) rouge = Rouge(metrics, stats) - scores = rouge.get_scores(hyp, ref, avg=args.avg) + scores = rouge.get_scores(hyp, ref, exclude=exclude, avg=args.avg) print(json.dumps(scores, indent=2)) diff --git a/rouge/rouge.py b/rouge/rouge.py index e3a53dc..6d5e40b 100644 --- a/rouge/rouge.py +++ b/rouge/rouge.py @@ -6,16 +6,39 @@ import os +def exclude_ngrams(sequence, n, exclude_ngrams): + from nltk import ngrams + sg = ngrams(sequence.split(), n) + for g in sg: + if g in exclude_ngrams: + sequence = sequence.replace(" ".join(g), "").strip() + return sequence + + +def exclude_sequences(hyps, refs, excludes, max_n=10, min_n=2): + assert len(hyps) == len(refs) + assert len(excludes) == len(hyps) + assert max_n > min_n + + xhyps = [] + xrefs = [] + for i, (h, r, e) in enumerate(zip(hyps, refs, excludes)): + for n in range(max_n, min_n, -1): + eg = list(ngrams(e.split(), n)) + h = exclude_ngrams(h, n, eg) + r = exclude_ngrams(r, n, eg) + xhyps.append(h) + xrefs.append(r) + return xhyps, xrefs + + class FilesRouge: def __init__(self, *args, **kwargs): """See the `Rouge` class for args """ self.rouge = Rouge(*args, **kwargs) - def _check_files(self, hyp_path, ref_path): - assert(os.path.isfile(hyp_path)) - assert(os.path.isfile(ref_path)) - + def _check_files(self, *paths): def line_count(path): count = 0 with open(path, "rb") as f: @@ -23,28 +46,36 @@ def line_count(path): count += 1 return count - hyp_lc = line_count(hyp_path) - ref_lc = line_count(ref_path) - assert(hyp_lc == ref_lc) + lc = None + for path in paths: + if path is None: + continue - def get_scores(self, hyp_path, ref_path, avg=False, ignore_empty=False): + assert os.path.isfile(path), "AssertionError on " + path + _lc = line_count(path) + if lc is None: + lc = _lc + else: + assert lc == _lc, "AssertionError on " + path + + def get_scores(self, hyp_path, ref_path, exclude_path=None, **kwargs): """Calculate ROUGE scores between each pair of lines (hyp_file[i], ref_file[i]). Args: * hyp_path: hypothesis file path * ref_path: references file path - * avg (False): whether to get an average scores or a list + * **kwargs: see `Rouge.get_scores` for doc """ - self._check_files(hyp_path, ref_path) - - with io.open(hyp_path, encoding="utf-8", mode="r") as hyp_file: - hyps = [line[:-1] for line in hyp_file] + self._check_files(hyp_path, ref_path, exclude_path) - with io.open(ref_path, encoding="utf-8", mode="r") as ref_file: - refs = [line[:-1] for line in ref_file] + def read(path): + with io.open(path, encoding="utf-8", mode="r") as f: + return [line.strip() for line in f] - return self.rouge.get_scores(hyps, refs, avg=avg, - ignore_empty=ignore_empty) + hyps = read(hyp_path) + refs = read(ref_path) + exclude = read(exclude_path) if exclude_path is not None else None + return self.rouge.get_scores(hyps, refs, exclude=exclude, **kwargs) class Rouge: @@ -88,7 +119,23 @@ def __init__(self, metrics=None, stats=None, return_lengths=False, else: self.stats = Rouge.DEFAULT_STATS - def get_scores(self, hyps, refs, avg=False, ignore_empty=False): + def get_scores(self, hyps, refs, avg=False, ignore_empty=False, + exclude=None): + + if exclude is not None: + print("(excluding)") + hyps, refs = exclude_sequences(hyps, refs, exclude) + ignore_empty = True + hyps = [ + h if h.replace(".", "").strip() != "" + else "" + for h in hyps + ] + refs = [ + r if r.replace(".", "").strip() != "" + else "" + for r in refs + ] if isinstance(hyps, six.string_types): hyps, refs = [hyps], [refs]