From 5d68638637976ebc66bff82c71ab3446a9d60949 Mon Sep 17 00:00:00 2001
From: Hany
Date: Thu, 18 Sep 2025 05:44:32 +0400
Subject: [PATCH] Decouple NER tag vocab from torchtext and add conversion CLI
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Ref: https://github.com/SinaLab/sinatools/issues/10

- add sinatools/ner/tag_vocab.py with shims that load legacy torchtext
  pickles and re-serialize them into the native Vocab class (lazy-imported
  to avoid circular deps)
- switch sinatools/ner/__init__.py and helpers.py to use the new loader so
  the runtime works without torchtext
- expose a convert_tag_vocab console script (and register it in setup.py)
  to rewrite existing checkpoints in place for environments that no longer
  ship torchtext
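
Example usage (illustrative; the checkpoint directory name is the one from
the --help text below):

    # Report whether the pickled vocab still references torchtext.
    convert_tag_vocab models/sinatools/Wj27012000.tar --dry-run

    # Rewrite tag_vocab.pkl in place, keeping the original as *.legacy.
    convert_tag_vocab models/sinatools/Wj27012000.tar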
---
 setup.py                               |  30 +--
 sinatools/CLI/ner/convert_tag_vocab.py |  74 ++++++++
 sinatools/ner/__init__.py              |   8 +-
 sinatools/ner/helpers.py               |   7 +-
 sinatools/ner/tag_vocab.py             | 241 +++++++++++++++++++++++++
 5 files changed, 336 insertions(+), 24 deletions(-)
 create mode 100644 sinatools/CLI/ner/convert_tag_vocab.py
 create mode 100644 sinatools/ner/tag_vocab.py

diff --git a/setup.py b/setup.py
index 6daca9a..383ccb1 100644
--- a/setup.py
+++ b/setup.py
@@ -69,20 +69,22 @@
              'sinatools.CLI.DataDownload.get_appdatadir:main'),
             ('download_files='
              'sinatools.CLI.DataDownload.download_files:main'),
-        ('corpus_entity_extractor='
-         'sinatools.CLI.ner.corpus_entity_extractor:main'),
-        ('text_dublication_detector='
-         'sinatools.CLI.utils.text_dublication_detector:main'),
-        ('evaluate_synonyms='
-         'sinatools.CLI.synonyms.evaluate_synonyms:main'),
-        ('extend_synonyms='
-         'sinatools.CLI.synonyms.extend_synonyms:main'),
-        ('semantic_relatedness='
-         'sinatools.CLI.semantic_relatedness.compute_relatedness:main'),
-        ('relation_extractor='
-         'sinatools.CLI.relations.relation_extractor:main'),
-    ],
-    },
+            ('corpus_entity_extractor='
+             'sinatools.CLI.ner.corpus_entity_extractor:main'),
+            ('text_dublication_detector='
+             'sinatools.CLI.utils.text_dublication_detector:main'),
+            ('evaluate_synonyms='
+             'sinatools.CLI.synonyms.evaluate_synonyms:main'),
+            ('extend_synonyms='
+             'sinatools.CLI.synonyms.extend_synonyms:main'),
+            ('semantic_relatedness='
+             'sinatools.CLI.semantic_relatedness.compute_relatedness:main'),
+            ('relation_extractor='
+             'sinatools.CLI.relations.relation_extractor:main'),
+            ('convert_tag_vocab='
+             'sinatools.CLI.ner.convert_tag_vocab:main'),
+        ],
+    },
     data_files=[('sinatools', ['sinatools/environment.yml'])],
     package_data={'sinatools': ['data/*.pickle', 'environment.yml']},
     install_requires=requirements,

diff --git a/sinatools/CLI/ner/convert_tag_vocab.py b/sinatools/CLI/ner/convert_tag_vocab.py
new file mode 100644
index 0000000..c270c10
--- /dev/null
+++ b/sinatools/CLI/ner/convert_tag_vocab.py
@@ -0,0 +1,74 @@
+"""CLI helper to convert legacy torchtext tag vocabularies."""
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+from typing import Optional
+
+from sinatools.ner.tag_vocab import convert_tag_vocab_file
+
+
+def _build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Convert a legacy tag_vocab.pkl to the native SinaTools format",
+    )
+    parser.add_argument(
+        "model_dir",
+        type=Path,
+        help="Directory that contains tag_vocab.pkl (e.g. models/sinatools/Wj27012000.tar)",
+    )
+    parser.add_argument(
+        "--backup-suffix",
+        default=".legacy",
+        help="Suffix used for the backup copy (default: .legacy)",
+    )
+    parser.add_argument(
+        "--no-backup",
+        action="store_true",
+        help="Skip creating a backup copy before rewriting tag_vocab.pkl",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Only report whether conversion is needed without writing files",
+    )
+    return parser
+
+
+def main(argv: Optional[list[str]] = None) -> int:
+    parser = _build_parser()
+    args = parser.parse_args(argv)
+
+    model_dir: Path = args.model_dir
+    pickle_path = model_dir / "tag_vocab.pkl"
+
+    if not pickle_path.exists():
+        parser.error(f"No tag_vocab.pkl found at {pickle_path}")
+
+    backup_suffix = None if args.no_backup else args.backup_suffix
+
+    try:
+        converted, backup_path = convert_tag_vocab_file(
+            model_dir,
+            backup_suffix=backup_suffix,
+            dry_run=args.dry_run,
+        )
+    except Exception as exc:  # pragma: no cover - CLI surface
+        parser.error(str(exc))
+
+    if args.dry_run:
+        if converted:
+            parser.exit(0, "Conversion required (dry-run).\n")
+        parser.exit(0, "No conversion required.\n")
+
+    if not converted:
+        parser.exit(0, "tag_vocab.pkl already uses the native format.\n")
+
+    message = "Converted tag_vocab.pkl" if backup_suffix is None else (
+        f"Converted tag_vocab.pkl (backup saved to {backup_path})"
+    )
+    parser.exit(0, message + "\n")
+
+
+if __name__ == "__main__":  # pragma: no cover - CLI entry point
+    raise SystemExit(main())

diff --git a/sinatools/ner/__init__.py b/sinatools/ner/__init__.py
index f701470..1c629dd 100644
--- a/sinatools/ner/__init__.py
+++ b/sinatools/ner/__init__.py
@@ -1,10 +1,8 @@
 from sinatools.DataDownload import downloader
 import os
 from sinatools.ner.helpers import load_object
-import pickle
-import os
+from sinatools.ner.tag_vocab import load_tag_vocab
 import torch
-import pickle
 import json
 from argparse import Namespace
 
@@ -17,9 +15,7 @@
 model_path = os.path.join(path, filename)
 
 _path = os.path.join(model_path, "tag_vocab.pkl")
-
-with open(_path, "rb") as fh:
-    tag_vocab = pickle.load(fh)
+tag_vocab = load_tag_vocab(_path)
 
 train_config = Namespace()
 args_path = os.path.join(model_path, "args.json")

diff --git a/sinatools/ner/helpers.py b/sinatools/ner/helpers.py
index ff625b2..d2aa1e1 100644
--- a/sinatools/ner/helpers.py
+++ b/sinatools/ner/helpers.py
@@ -4,11 +4,11 @@
 import importlib
 import shutil
 import torch
-import pickle
 import json
 import random
 import numpy as np
 from argparse import Namespace
+from sinatools.ner.tag_vocab import load_tag_vocab
 
 
 def logging_config(log_file=None):
@@ -70,8 +70,7 @@ def load_checkpoint(model_path):
         vocab - arabicner.utils.data.Vocab - indexed tags
         train_config - argparse.Namespace - training configurations
     """
-    with open(os.path.join(model_path, "tag_vocab.pkl"), "rb") as fh:
-        tag_vocab = pickle.load(fh)
+    tag_vocab = load_tag_vocab(os.path.join(model_path, "tag_vocab.pkl"))
 
     # Load train configurations from checkpoint
     train_config = Namespace()
@@ -114,4 +113,4 @@ def set_seed(seed):
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
 
-    torch.backends.cudnn.enabled = False
\ No newline at end of file
+    torch.backends.cudnn.enabled = False

diff --git a/sinatools/ner/tag_vocab.py b/sinatools/ner/tag_vocab.py
new file mode 100644
index 0000000..50207dc
--- /dev/null
+++ b/sinatools/ner/tag_vocab.py
@@ -0,0 +1,241 @@
+"""Utilities for loading and converting legacy torchtext tag vocabularies."""
+from __future__ import annotations
+
+import pickle
+import sys
+import types
+from collections import Counter
+from pathlib import Path
+from typing import Iterable, List, Sequence, Tuple, Union
+
+if sys.version_info >= (3, 8):  # pragma: no branch - typing guard
+    from typing import TYPE_CHECKING
+else:  # pragma: no cover - Py<3.8 fallback
+    TYPE_CHECKING = False
+
+if TYPE_CHECKING:  # pragma: no cover - type checking only
+    from sinatools.ner.data_format import Vocab
+
+LegacyVocabSequence = Sequence[object]
+
+
+def load_tag_vocab(
+    path: Union[str, Path],
+    *,
+    convert_legacy: bool = True,
+    return_metadata: bool = False,
+) -> Union[List["Vocab"], Tuple[List["Vocab"], dict]]:
+    """Load ``tag_vocab.pkl`` while remaining robust to legacy torchtext pickles.
+
+    Args:
+        path: Filesystem path to ``tag_vocab.pkl``.
+        convert_legacy: When ``True`` (default), legacy torchtext objects are translated
+            to SinaTools' lightweight :class:`~sinatools.ner.data_format.Vocab`.
+
+    Returns:
+        The vocab list, or ``(vocabs, {"converted": bool})`` when ``return_metadata`` is set.
+    """
+
+    path = Path(path)
+
+    converted = False
+
+    try:
+        with path.open("rb") as fh:
+            tag_vocab = pickle.load(fh)
+    except (ModuleNotFoundError, AttributeError) as exc:
+        if not _mentions_torchtext(exc):
+            raise
+        legacy_vocab = _load_with_shims(path)
+        if convert_legacy:
+            tag_vocab = _convert_legacy_vocab(legacy_vocab)
+            converted = True
+        else:
+            tag_vocab = legacy_vocab
+    else:
+        if convert_legacy and _looks_like_legacy(tag_vocab):
+            tag_vocab = _convert_legacy_vocab(tag_vocab)
+            converted = True
+
+    if return_metadata:
+        return tag_vocab, {"converted": converted}
+
+    return tag_vocab
+
+
+def convert_tag_vocab_file(
+    model_dir: Union[str, Path],
+    *,
+    backup_suffix: str | None = ".legacy",
+    dry_run: bool = False,
+) -> Tuple[bool, Path | None]:
+    """Convert ``tag_vocab.pkl`` in place to the internal :class:`Vocab` format.
+
+    Args:
+        model_dir: Directory containing ``tag_vocab.pkl``.
+        backup_suffix: Suffix used when stashing the original pickle. Set to ``None``
+            to skip creating a backup copy.
+        dry_run: When ``True``, only report whether conversion is required.
+
+    Returns:
+        Tuple of ``(converted, backup_path)`` where ``converted`` indicates whether a
+        rewrite happened (or would happen for ``dry_run``). ``backup_path`` is the path
+        to the backup file when one is created, otherwise ``None``.
+ """ + + model_dir = Path(model_dir) + pickle_path = model_dir / "tag_vocab.pkl" + + tag_vocab, metadata = load_tag_vocab( + pickle_path, convert_legacy=True, return_metadata=True + ) + + needs_conversion = metadata["converted"] + + if dry_run or not needs_conversion: + return needs_conversion, None + + backup_path: Path | None = None + + if backup_suffix is not None: + backup_path = pickle_path.with_suffix(pickle_path.suffix + backup_suffix) + if backup_path.exists(): + raise FileExistsError(f"Backup file already exists: {backup_path}") + pickle_path.replace(backup_path) + + with pickle_path.open("wb") as fh: + pickle.dump(tag_vocab, fh) + + return True, backup_path + + +def _mentions_torchtext(exc: BaseException) -> bool: + message = str(exc) + return "torchtext" in message + + +def _load_with_shims(path: Path) -> LegacyVocabSequence: + created_modules = _install_shims() + + try: + with path.open("rb") as fh: + return pickle.load(fh) + finally: + _remove_shims(created_modules) + + +def _install_shims() -> Tuple[str, ...]: + created = [] + + if "torchtext" not in sys.modules: + torchtext_module = types.ModuleType("torchtext") + sys.modules["torchtext"] = torchtext_module + created.append("torchtext") + else: + torchtext_module = sys.modules["torchtext"] + + if not hasattr(torchtext_module, "vocab"): + vocab_module = types.ModuleType("torchtext.vocab") + torchtext_module.vocab = vocab_module + sys.modules["torchtext.vocab"] = vocab_module + created.append("torchtext.vocab") + else: + vocab_module = torchtext_module.vocab + + if not hasattr(vocab_module, "vocab"): + vocab_vocab_module = types.ModuleType("torchtext.vocab.vocab") + vocab_module.vocab = vocab_vocab_module + sys.modules["torchtext.vocab.vocab"] = vocab_vocab_module + created.append("torchtext.vocab.vocab") + else: + vocab_vocab_module = vocab_module.vocab + + if "torchtext._torchtext" not in sys.modules: + backend_module = types.ModuleType("torchtext._torchtext") + sys.modules["torchtext._torchtext"] = backend_module + created.append("torchtext._torchtext") + else: + backend_module = sys.modules["torchtext._torchtext"] + + vocab_class = _TorchtextVocabShim + vocab_vocab_module.Vocab = vocab_class + backend_module.Vocab = vocab_class + + return tuple(created) + + +def _remove_shims(created_modules: Tuple[str, ...]) -> None: + for name in created_modules: + sys.modules.pop(name, None) + + +def _looks_like_legacy(obj: object) -> bool: + if isinstance(obj, (list, tuple)): + return any(_looks_like_legacy(item) for item in obj) + module = getattr(obj.__class__, "__module__", "") + if module.startswith("torchtext"): + return True + if hasattr(obj, "vocab"): + return _looks_like_legacy(getattr(obj, "vocab")) + return False + + +def _convert_legacy_vocab(raw_vocab: LegacyVocabSequence) -> List[Vocab]: + VocabCls = _get_vocab_class() + converted: List["Vocab"] = [] + for entry in _ensure_sequence(raw_vocab): + tokens = _extract_tokens(entry) + counter = Counter(tokens) + converted.append(VocabCls(counter)) + return converted + + +def _extract_tokens(entry: object) -> List[str]: + if hasattr(entry, "itos") and isinstance(getattr(entry, "itos"), Iterable): + return list(getattr(entry, "itos")) + + nested = getattr(entry, "vocab", None) + if nested is not None: + return _extract_tokens(nested) + + if hasattr(entry, "stoi") and isinstance(entry.stoi, dict): + # Sort by index to recover original ordering. 
+        return [token for token, _ in sorted(entry.stoi.items(), key=lambda item: item[1])]
+
+    raise TypeError("Unsupported legacy vocab structure")
+
+
+def _ensure_sequence(obj: object) -> Sequence[object]:
+    if isinstance(obj, (list, tuple)):
+        return obj
+    raise TypeError("Expected a sequence of vocabulary entries")
+
+
+class _TorchtextVocabShim:
+    """Minimal object to satisfy torchtext pickle references."""
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __setstate__(self, state):
+        # Keep the raw state for debugging; the useful attributes are read back
+        # through ``itos``/``vocab``/``stoi`` in ``_extract_tokens``.
+        if isinstance(state, tuple) and len(state) == 4:
+            version, unk_tokens, itos, specials = state
+            self.version = version
+            self.unk_tokens = unk_tokens
+            self.itos = list(itos)
+            self.specials = specials
+        elif isinstance(state, dict):
+            self.__dict__.update(state)
+        else:
+            self.state = state
+
+    def __getstate__(self):  # pragma: no cover - compatibility only
+        return getattr(self, "state", {})
+
+
+def _get_vocab_class():
+    from sinatools.ner.data_format import Vocab
+
+    return Vocab
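
Note for reviewers (outside the patch): a minimal smoke test for the new
loader, assuming a checkpoint directory laid out as in the --help example;
the path below is illustrative.

    from pathlib import Path

    from sinatools.ner.tag_vocab import convert_tag_vocab_file, load_tag_vocab

    model_dir = Path("models/sinatools/Wj27012000.tar")  # hypothetical checkpoint dir

    # Dry run: report whether tag_vocab.pkl still needs converting; writes nothing.
    needs_conversion, _ = convert_tag_vocab_file(model_dir, dry_run=True)
    print("legacy pickle:", needs_conversion)

    # Loading works either way; legacy pickles are converted on the fly.
    vocabs, meta = load_tag_vocab(model_dir / "tag_vocab.pkl", return_metadata=True)
    print("converted on load:", meta["converted"], "| vocab count:", len(vocabs))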