diff --git a/ace/ingest.py b/ace/ingest.py
index 9f1cf7b..84a2483 100644
--- a/ace/ingest.py
+++ b/ace/ingest.py
@@ -1,5 +1,6 @@
 from os import path
 import logging
+import re
 from . import sources, config
 from .scrape import _validate_scrape
 import multiprocessing as mp
@@ -23,8 +24,59 @@ def _process_file(f):
     return f, html
 
 
+def _process_file_with_source(args):
+    """Helper function to read, validate, and identify source for a single file."""
+    f, source_configs = args
+    logger.info("Processing article %s..." % f)
+    try:
+        html = open(f).read()
+    except Exception as e:
+        logger.warning("Failed to read file %s: %s" % (f, str(e)))
+        return f, None, None
+
+    if not _validate_scrape(html):
+        logger.warning("Invalid HTML for %s" % f)
+        return f, None, None
+
+    # Identify source from HTML using regex patterns
+    source_name = None
+    for name, identifiers in source_configs.items():
+        for patt in identifiers:
+            if re.search(patt, html):
+                logger.debug('Matched article to Source: %s' % name)
+                source_name = name
+                break
+        if source_name:
+            break
+
+    return f, html, source_name
+
+
+def _parse_article(args):
+    """Helper function to parse an article from HTML content."""
+    # Unpack arguments
+    f, html, source_name, pmid, manager, metadata_dir, force_ingest, kwargs = args
+
+    try:
+        # Get the actual source object
+        if source_name:
+            source = manager.sources[source_name]
+        else:
+            # Fallback to original source identification
+            source = manager.identify_source(html)
+            if source is None:
+                logger.warning("Could not identify source for %s" % f)
+                return f, None
+
+        article = source.parse_article(html, pmid, metadata_dir=metadata_dir, **kwargs)
+        return f, article
+    except Exception as e:
+        logger.warning("Error parsing article %s: %s" % (f, str(e)))
+        return f, None
+
+
 def add_articles(db, files, commit=True, table_dir=None, limit=None,
-                 pmid_filenames=False, metadata_dir=None, force_ingest=True, parallel=True, num_workers=None, **kwargs):
+                 pmid_filenames=False, metadata_dir=None, force_ingest=True, num_workers=None, **kwargs):
     ''' Process articles and add their data to the DB.
     Args:
         files: The path to the article(s) to process. Can be a single
@@ -46,13 +98,15 @@ def add_articles(db, files, commit=True, table_dir=None, limit=None,
             and will save the result of the query if it doesn't already
             exist.
         force_ingest: Ingest even if no source is identified.
-        parallel: Whether to process articles in parallel (default: True).
         num_workers: Number of worker processes to use when processing in
            parallel. If None (default), uses the number of CPUs available
            on the system.
         kwargs: Additional keyword arguments to pass to parse_article.
    '''
-    manager = sources.SourceManager(db, table_dir)
+    manager = sources.SourceManager(table_dir)
+
+    # Prepare source configurations for parallel processing
+    source_configs = {name: source.identifiers for name, source in manager.sources.items()}
 
     if isinstance(files, str):
         from glob import glob
@@ -64,38 +118,74 @@ def add_articles(db, files, commit=True, table_dir=None, limit=None,
 
     missing_sources = []
 
-    if parallel:
-        # Process files in parallel to extract HTML content
+    # Step 1: Process files in parallel to extract HTML content and identify sources
+    if num_workers is not None and num_workers != 1:
+        # Process files in parallel to extract HTML content and identify sources
+        process_args = [(f, source_configs) for f in files]
         with mp.Pool(processes=num_workers) as pool:
-            file_html_pairs = pool.map(_process_file, files)
+            file_html_source_tuples = pool.map(_process_file_with_source, process_args)
     else:
         # Process files sequentially
-        file_html_pairs = []
+        file_html_source_tuples = []
         for f in files:
-            file_html_pairs.append(_process_file(f))
+            result = _process_file_with_source((f, source_configs))
+            file_html_source_tuples.append(result)
+
+    # Step 2: Back in the main (serial) process, use the db object to skip articles that have already been added
+    # Filter out files with reading/validation errors
+    valid_files = []
+    for f, html, source_name in file_html_source_tuples:
+        if html is not None:
+            valid_files.append((f, html, source_name))
+        # We'll handle missing sources later when we actually parse the articles
+
+    # Filter out articles that already exist in the database
+    files_to_process = []
+    missing_sources = []
 
-    # Process each file's HTML content
-    for i, (f, html) in enumerate(file_html_pairs):
-        if html is None:
-            # File reading or validation failed
-            missing_sources.append(f)
+    for f, html, source_name in valid_files:
+        pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None
+
+        # Check if article already exists
+        if pmid is not None and db.article_exists(pmid) and not config.OVERWRITE_EXISTING_ROWS:
             continue
+
+        files_to_process.append((f, html, source_name, pmid))
 
-        source = manager.identify_source(html)
-        if source is None:
-            logger.warning("Could not identify source for %s" % f)
-            missing_sources.append(f)
-            if not force_ingest:
-                continue
-            else:
-                source = sources.DefaultSource(db)
+    # Step 3: Process remaining articles in parallel
+    # Prepare arguments for _parse_article
+    parse_args = [(f, html, source_name, pmid, manager, metadata_dir, force_ingest, kwargs)
+                  for f, html, source_name, pmid in files_to_process]
 
-        pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None
-        article = source.parse_article(html, pmid, metadata_dir=metadata_dir, **kwargs)
-        if article and (config.SAVE_ARTICLES_WITHOUT_ACTIVATIONS or article.tables):
+    if num_workers is not None and num_workers != 1 and parse_args:
+        # Parse articles in parallel
+        with mp.Pool(processes=num_workers) as pool:
+            parsed_articles = pool.map(_parse_article, parse_args)
+    else:
+        # Parse articles sequentially
+        parsed_articles = []
+        for args in parse_args:
+            parsed_articles.append(_parse_article(args))
+
+    # Add successfully parsed articles to database
+    for i, (f, article) in enumerate(parsed_articles):
+        if article is None:
+            missing_sources.append(f)
+            continue
+
+        if config.SAVE_ARTICLES_WITHOUT_ACTIVATIONS or article.tables:
+            # Check again if article exists and handle overwrite
+            pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None
+            if pmid is not None and db.article_exists(pmid):
+                if config.OVERWRITE_EXISTING_ROWS:
+                    db.delete_article(pmid)
+                else:
+                    continue
             db.add(article)
-            if commit and (i % 100 == 0 or i == len(file_html_pairs) - 1):
+            if commit and (i % 100 == 0 or i == len(parsed_articles) - 1):
                 db.save()
+    db.save()
 
     return missing_sources
diff --git a/ace/sources.py b/ace/sources.py
index 9155c6d..8edd79e 100644
--- a/ace/sources.py
+++ b/ace/sources.py
@@ -34,10 +34,9 @@ class SourceManager:
     associated directory of JSON config files and uses them to determine
     which parser to call when a new HTML file is passed. '''
 
-    def __init__(self, database, table_dir=None):
+    def __init__(self, table_dir=None):
        ''' SourceManager constructor.
        Args:
-            database: A Database instance to use with all Sources.
            table_dir: An optional directory name to save any downloaded
                tables to. When table_dir is None, nothing will be saved
                (requiring new scraping each time the article is processed).
@@ -47,7 +46,7 @@ def __init__(self, database, table_dir=None):
         source_dir = os.path.join(os.path.dirname(__file__), 'sources')
         for config_file in glob('%s/*json' % source_dir):
             class_name = config_file.split('/')[-1].split('.')[0]
-            cls = getattr(module, class_name + 'Source')(database, config=config_file, table_dir=table_dir)
+            cls = getattr(module, class_name + 'Source')(config=config_file, table_dir=table_dir)
             self.sources[class_name] = cls
 
     def identify_source(self, html):
@@ -161,8 +160,7 @@ def _safe_clean_html(self, html):
                 text_parts.append(text.strip())
         return '\n\n'.join(text_parts) if text_parts else soup.get_text()
 
-    def __init__(self, database, config=None, table_dir=None):
-        self.database = database
+    def __init__(self, config=None, table_dir=None):
         self.table_dir = table_dir
         self.entities = {}
 
@@ -181,16 +179,11 @@ def __init__(self, database, config=None, table_dir=None):
         else:
             self.entities.update(Source.ENTITIES)
 
-    @abc.abstractmethod
     def parse_article(self, html, pmid=None, metadata_dir=None):
-        ''' Takes HTML article as input and returns an Article. PMID Can also be
-        passed, which prevents having to scrape it from the article and/or look it
+        ''' Takes HTML article as input and returns an Article. PMID Can also be
+        passed, which prevents having to scrape it from the article and/or look it
         up in PubMed. '''
-
-        # Skip rest of processing if this record already exists
-        if pmid is not None and self.database.article_exists(pmid) and not config.OVERWRITE_EXISTING_ROWS:
-            return False
-
+
         html = self.decode_html_entities(html)
         soup = BeautifulSoup(html, "lxml")
         if pmid is None:
@@ -208,11 +201,6 @@ def parse_article(self, html, pmid=None, metadata_dir=None):
 
         # Get text using readability
         text = self._clean_html_with_readability(str(soup))
-        if self.database.article_exists(pmid):
-            if config.OVERWRITE_EXISTING_ROWS:
-                self.database.delete_article(pmid)
-            else:
-                return False
         self.article = database.Article(text, pmid=pmid, metadata=metadata)
 
         self.extract_neurovault(soup)
@@ -401,6 +389,9 @@ class DefaultSource(Source):
         3. JavaScript expansion detection: Identifies elements that might trigger
            table expansion via JavaScript (logging only, not implemented)
     """
+    def __init__(self, config=None, table_dir=None):
+        super().__init__(config=config, table_dir=table_dir)
+
     def parse_article(self, html, pmid=None, **kwargs):
         soup = super(DefaultSource, self).parse_article(html, pmid, **kwargs)
         if not soup:
diff --git a/ace/tests/test_ace.py b/ace/tests/test_ace.py
index 2c4719a..add2149 100644
--- a/ace/tests/test_ace.py
+++ b/ace/tests/test_ace.py
@@ -28,7 +28,7 @@ def db():
 
 @pytest.fixture(scope="module")
 def source_manager(db):
-    return sources.SourceManager(db)
+    return sources.SourceManager()
 
 
 @pytest.mark.vcr(record_mode="once")
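Note for reviewers: below is a minimal usage sketch of the reworked entry point after this change. Only the add_articles(..., num_workers=...) signature and the database-free SourceManager() constructor come from the diff above; the db handle, the input glob, and the worker count are illustrative placeholders, not part of the patch.

    from ace.ingest import add_articles

    # `db` is assumed to be an already-constructed ace Database instance;
    # creating it is outside the scope of this diff.
    missing = add_articles(
        db,
        '/data/articles/*.html',   # hypothetical input; a glob string or a list of paths
        pmid_filenames=True,       # treat each filename stem as the article's PMID
        num_workers=4,             # values greater than 1 take the multiprocessing path
    )
    print('%d articles could not be ingested' % len(missing))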