140 changes: 115 additions & 25 deletions ace/ingest.py
@@ -1,5 +1,6 @@
from os import path
import logging
import re
from . import sources, config
from .scrape import _validate_scrape
import multiprocessing as mp
@@ -23,8 +24,59 @@ def _process_file(f):
return f, html


def _process_file_with_source(args):
"""Helper function to read, validate, and identify source for a single file."""
f, source_configs = args
logger.info("Processing article %s..." % f)
try:
html = open(f).read()
except Exception as e:
logger.warning("Failed to read file %s: %s" % (f, str(e)))
return f, None, None

if not _validate_scrape(html):
logger.warning("Invalid HTML for %s" % f)
return f, None, None

# Identify source from HTML using regex patterns
source_name = None
for name, identifiers in source_configs.items():
for patt in identifiers:
if re.search(patt, html):
logger.debug('Matched article to Source: %s' % name)
source_name = name
break
if source_name:
break

return f, html, source_name


def _parse_article(args):
"""Helper function to parse an article from HTML content."""
# Unpack arguments
f, html, source_name, pmid, manager, metadata_dir, force_ingest, kwargs = args

try:
# Get the actual source object
if source_name:
source = manager.sources[source_name]
else:
# Fallback to original source identification
source = manager.identify_source(html)
if source is None:
logger.warning("Could not identify source for %s" % f)
return f, None

article = source.parse_article(html, pmid, metadata_dir=metadata_dir, **kwargs)
return f, article
except Exception as e:
logger.warning("Error parsing article %s: %s" % (f, str(e)))
return f, None


def add_articles(db, files, commit=True, table_dir=None, limit=None,
-                 pmid_filenames=False, metadata_dir=None, force_ingest=True, parallel=True, num_workers=None, **kwargs):
                  pmid_filenames=False, metadata_dir=None, force_ingest=True, num_workers=None, **kwargs):
''' Process articles and add their data to the DB.
Args:
files: The path to the article(s) to process. Can be a single
@@ -46,13 +98,15 @@ def add_articles(db, files, commit=True, table_dir=None, limit=None,
and will save the result of the query if it doesn't already
exist.
force_ingest: Ingest even if no source is identified.
-        parallel: Whether to process articles in parallel (default: True).
num_workers: Number of worker processes to use when processing in parallel.
If None (default), uses the number of CPUs available on the system.
kwargs: Additional keyword arguments to pass to parse_article.
'''

-    manager = sources.SourceManager(db, table_dir)
    manager = sources.SourceManager(table_dir)

# Prepare source configurations for parallel processing
source_configs = {name: source.identifiers for name, source in manager.sources.items()}

if isinstance(files, str):
from glob import glob
@@ -64,38 +118,74 @@

missing_sources = []

-    if parallel:
-        # Process files in parallel to extract HTML content
# Step 1: Process files in parallel to extract HTML content and identify sources
if num_workers is not None and num_workers != 1:
# Process files in parallel to extract HTML content and identify sources
process_args = [(f, source_configs) for f in files]
with mp.Pool(processes=num_workers) as pool:
-            file_html_pairs = pool.map(_process_file, files)
file_html_source_tuples = pool.map(_process_file_with_source, process_args)
else:
# Process files sequentially
-        file_html_pairs = []
file_html_source_tuples = []
for f in files:
-            file_html_pairs.append(_process_file(f))
result = _process_file_with_source((f, source_configs))
file_html_source_tuples.append(result)

    # Step 2: In serial mode, use the db object to skip articles that have already been added
# Filter out files with reading/validation errors
valid_files = []
for f, html, source_name in file_html_source_tuples:
if html is not None:
valid_files.append((f, html, source_name))
# We'll handle missing sources later when we actually parse the articles

# Filter out articles that already exist in the database
files_to_process = []
missing_sources = []

-    # Process each file's HTML content
-    for i, (f, html) in enumerate(file_html_pairs):
-        if html is None:
-            # File reading or validation failed
-            missing_sources.append(f)
for f, html, source_name in valid_files:
pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None

# Check if article already exists
if pmid is not None and db.article_exists(pmid) and not config.OVERWRITE_EXISTING_ROWS:
continue

files_to_process.append((f, html, source_name, pmid))

-        source = manager.identify_source(html)
-        if source is None:
-            logger.warning("Could not identify source for %s" % f)
-            missing_sources.append(f)
-            if not force_ingest:
-                continue
-            else:
-                source = sources.DefaultSource(db)
# Step 3: Process remaining articles in parallel
# Prepare arguments for _parse_article
parse_args = [(f, html, source_name, pmid, manager, metadata_dir, force_ingest, kwargs)
for f, html, source_name, pmid in files_to_process]

-        pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None
-        article = source.parse_article(html, pmid, metadata_dir=metadata_dir, **kwargs)
-        if article and (config.SAVE_ARTICLES_WITHOUT_ACTIVATIONS or article.tables):
if num_workers is not None and num_workers != 1 and parse_args:
# Parse articles in parallel
with mp.Pool(processes=num_workers) as pool:
parsed_articles = pool.map(_parse_article, parse_args)
else:
# Parse articles sequentially
parsed_articles = []
for args in parse_args:
parsed_articles.append(_parse_article(args))

# Add successfully parsed articles to database
for i, (f, article) in enumerate(parsed_articles):
if article is None:
missing_sources.append(f)
continue

if config.SAVE_ARTICLES_WITHOUT_ACTIVATIONS or article.tables:
# Check again if article exists and handle overwrite
pmid = path.splitext(path.basename(f))[0] if pmid_filenames else None
if pmid is not None and db.article_exists(pmid):
if config.OVERWRITE_EXISTING_ROWS:
db.delete_article(pmid)
else:
continue

db.add(article)
-            if commit and (i % 100 == 0 or i == len(file_html_pairs) - 1):
            if commit and (i % 100 == 0 or i == len(parsed_articles) - 1):
db.save()

db.save()

return missing_sources
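Taken together, `add_articles` now runs in three steps: read and validate the HTML and match a source name against plain regex patterns (the `source_configs` dict, which pickles cheaply), filter already-ingested PMIDs serially so the `db` handle never crosses a process boundary, and only then parse in parallel. A minimal usage sketch of the new entry point follows; the glob path and the `Database` constructor arguments are assumptions, not part of this diff:

```python
# Hypothetical driver script for the reworked ingestion pipeline.
from ace import database, ingest

db = database.Database('sqlite:///ace.db')   # assumed constructor signature
missing = ingest.add_articles(
    db,
    '/data/articles/*.html',  # hypothetical corpus location (glob pattern)
    pmid_filenames=True,      # filenames like 12345678.html carry the PMID
    num_workers=4,            # any value other than None/1 enables mp.Pool
)
print('%d articles failed to parse or had no source' % len(missing))
```

Note that step 3 ships the whole `manager` to workers inside `parse_args`, which only works because `SourceManager` no longer holds a database connection (see the `ace/sources.py` changes below); anything passed through `kwargs` must likewise be picklable.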
27 changes: 9 additions & 18 deletions ace/sources.py
@@ -34,10 +34,9 @@ class SourceManager:
associated directory of JSON config files and uses them to determine which parser
to call when a new HTML file is passed. '''

-    def __init__(self, database, table_dir=None):
    def __init__(self, table_dir=None):
''' SourceManager constructor.
Args:
-            database: A Database instance to use with all Sources.
table_dir: An optional directory name to save any downloaded tables to.
When table_dir is None, nothing will be saved (requiring new scraping
each time the article is processed).
@@ -47,7 +46,7 @@ def __init__(self, database, table_dir=None):
source_dir = os.path.join(os.path.dirname(__file__), 'sources')
for config_file in glob('%s/*json' % source_dir):
class_name = config_file.split('/')[-1].split('.')[0]
-            cls = getattr(module, class_name + 'Source')(database, config=config_file, table_dir=table_dir)
            cls = getattr(module, class_name + 'Source')(config=config_file, table_dir=table_dir)
self.sources[class_name] = cls

def identify_source(self, html):
@@ -161,8 +160,7 @@ def _safe_clean_html(self, html):
text_parts.append(text.strip())
return '\n\n'.join(text_parts) if text_parts else soup.get_text()

-    def __init__(self, database, config=None, table_dir=None):
-        self.database = database
    def __init__(self, config=None, table_dir=None):
self.table_dir = table_dir
self.entities = {}

@@ -181,16 +179,11 @@ def __init__(self, database, config=None, table_dir=None):
else:
self.entities.update(Source.ENTITIES)

-    @abc.abstractmethod
def parse_article(self, html, pmid=None, metadata_dir=None):
-        ''' Takes HTML article as input and returns an Article. PMID Can also be
-        passed, which prevents having to scrape it from the article and/or look it
        ''' Takes HTML article as input and returns an Article. PMID can also be
        passed, which prevents having to scrape it from the article and/or look it
        up in PubMed. '''

-        # Skip rest of processing if this record already exists
-        if pmid is not None and self.database.article_exists(pmid) and not config.OVERWRITE_EXISTING_ROWS:
-            return False


html = self.decode_html_entities(html)
soup = BeautifulSoup(html, "lxml")
if pmid is None:
@@ -208,11 +201,6 @@ def parse_article(self, html, pmid=None, metadata_dir=None):

# Get text using readability
text = self._clean_html_with_readability(str(soup))
-        if self.database.article_exists(pmid):
-            if config.OVERWRITE_EXISTING_ROWS:
-                self.database.delete_article(pmid)
-            else:
-                return False

self.article = database.Article(text, pmid=pmid, metadata=metadata)
self.extract_neurovault(soup)
@@ -401,6 +389,9 @@ class DefaultSource(Source):
3. JavaScript expansion detection: Identifies elements that might trigger
table expansion via JavaScript (logging only, not implemented)
"""
def __init__(self, config=None, table_dir=None):
super().__init__(config=config, table_dir=table_dir)

def parse_article(self, html, pmid=None, **kwargs):
soup = super(DefaultSource, self).parse_article(html, pmid, **kwargs)
if not soup:
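The thread running through these `sources.py` edits is that `Source` objects no longer know about the database: existence and overwrite checks moved up into `ingest.add_articles`, and parsing became a pure HTML-to-`Article` transformation, which is what makes the objects safe to send through `pool.map`. A sketch of the decoupled flow, with the input path and PMID as placeholders:

```python
# Sketch only: 'article.html' and the PMID are placeholders; the call
# signatures are taken from this diff.
from ace import sources

manager = sources.SourceManager(table_dir=None)  # no Database argument anymore
html = open('article.html').read()

source = manager.identify_source(html)  # regex match against source identifiers
if source is None:
    source = sources.DefaultSource()     # generic fallback, also DB-free now
article = source.parse_article(html, pmid='12345678')
# Persisting the result (db.add, db.save) is the caller's job, as in add_articles.
```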
2 changes: 1 addition & 1 deletion ace/tests/test_ace.py
Expand Up @@ -28,7 +28,7 @@ def db():

@pytest.fixture(scope="module")
def source_manager(db):
-    return sources.SourceManager(db)
    return sources.SourceManager()


@pytest.mark.vcr(record_mode="once")
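The one-line test change falls out of the same decoupling. Since `SourceManager()` now takes no database, the fixture's `db` parameter is unused and could be dropped as well (this PR leaves it in place); a sketch of that simplification, with a hypothetical test body:

```python
import pytest
from ace import sources

@pytest.fixture(scope="module")
def source_manager():
    # No Database instance is required after this PR.
    return sources.SourceManager()

def test_unknown_html_matches_no_source(source_manager):
    # Hypothetical check: markup matching no identifier should yield None,
    # mirroring the `source is None` branch handled in ingest._parse_article.
    assert source_manager.identify_source('<html><body>n/a</body></html>') is None
```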