Simplot is a chart-based question-answering system leveraging pre-trained models for extracting tables from charts and answering questions. The project is divided into multiple phases, including data preparation, model training, and inference.
Download the dataset from the following link:
ChartQA Dataset - HuggingFace Repository
dataset.py: Prepares the dataset and generates positive and negative PNG samples.
preprocess.py: Processes the dataset to create inputs for model training.
main.py: Handles training of Phase 1 (Teacher Model) and Phase 2 (Student Model).
inference.py: Extracts tables from charts and generates predictions (saved as prediction.csv).
QA.py: Performs question answering using the Gemini model (results saved in qa_results.csv).
Follow these steps to train the models:
- Download the Dataset
  Download the full dataset from the ChartQA repository.
- Set Up the Environment
  Create a Python environment using the dependencies listed in requirements.txt.
- Run Preprocessing
  Execute the preprocessing script:
    python preprocess.py
- Phase 1: Teacher Model Training
  Train the teacher model using the following command:
    python main.py --phase 1
- Phase 2: Student Model Training
  Train the student model by loading the best Phase 1 model state:
    python main.py --phase 2 --state_path './state/phase_1_best_model.pth' --lr 1e-5
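The preprocessing and training steps above can also be chained from one script. The following is a minimal sketch, not part of the repository: the helper names `build_training_commands` and `run_training` are hypothetical, and only the script names and the `--phase`, `--state_path`, and `--lr` flags come from the commands listed above.

```python
import subprocess
import sys


def build_training_commands(state_path="./state/phase_1_best_model.pth", lr="1e-5"):
    """Return the preprocessing and training commands, in order, as argv lists.

    Hypothetical helper: it only assembles the commands documented in this README.
    """
    python = sys.executable  # run the scripts with the current interpreter
    return [
        [python, "preprocess.py"],
        [python, "main.py", "--phase", "1"],
        [python, "main.py", "--phase", "2", "--state_path", state_path, "--lr", lr],
    ]


def run_training(dry_run=True):
    """Print each step and, unless dry_run is set, execute it."""
    for cmd in build_training_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError, so a failed step stops the chain
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_training(dry_run=True)  # switch to False to actually launch the scripts
```

Running it with `dry_run=True` only prints the three commands, which is a cheap way to confirm paths before starting a long training run.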
The pre-trained models are available in the state/ folder. To perform inference:
- Set Up the Environment
  Create the environment using requirements.txt.
- Extract Tables from Charts
  Run the following command to generate predictions:
    python inference.py
  Sample output can be found in result/prediction.csv.
- Question Answering
  Use the Gemini model for question answering:
    python QA.py --api_key 'your_api_key' --qa_type 'human'
  Results will be saved in result/qa_results.csv. A sample is already provided in the result/ folder.
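To avoid typing the API key into the shell history, the QA command can be assembled from an environment variable. This is a sketch under assumptions: the variable name `GEMINI_API_KEY` and the helper `build_qa_command` are hypothetical and not defined by the repository; only `QA.py`, `--api_key`, and `--qa_type` come from the command above.

```python
import os
import subprocess
import sys


def build_qa_command(qa_type="human", env_var="GEMINI_API_KEY"):
    """Assemble the QA.py command, reading the key from an environment variable.

    env_var is an assumed name chosen for this sketch, not one QA.py reads itself.
    """
    api_key = os.environ.get(env_var)
    if not api_key:
        raise RuntimeError(f"Set the {env_var} environment variable before running QA.py")
    return [sys.executable, "QA.py", "--api_key", api_key, "--qa_type", qa_type]


def run_qa(qa_type="human"):
    """Run QA.py and raise CalledProcessError if it exits with a non-zero status."""
    subprocess.run(build_qa_command(qa_type), check=True)
```

Calling `run_qa()` from the project root is then equivalent to the command shown above, with the key taken from the environment.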
Simplot/
│
├── dataset.py
├── preprocess.py
├── main.py
├── inference.py
├── QA.py
├── requirements.txt
├── state/                      # Pre-trained model files
│   ├── phase_2_best_model.pth
│   ├── ...
├── result/                     # Output files
│   ├── prediction.csv          # Sample table extraction results
│   ├── qa_results.csv          # Sample QA results
└── data/                       # Dataset folder
- Ensure you have a valid API key for Gemini when running the QA phase.
- Modify paths in the commands if your directory structure differs.
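
Before adjusting any paths, it can help to confirm that a checkout matches the layout this README describes. The sketch below is illustrative only: `missing_entries` is a hypothetical helper, and the file and folder names are taken from the project listing in this README.

```python
from pathlib import Path

# Top-level entries named in this README's project listing.
EXPECTED_FILES = ["dataset.py", "preprocess.py", "main.py", "inference.py", "QA.py", "requirements.txt"]
EXPECTED_DIRS = ["state", "result", "data"]


def missing_entries(root="."):
    """Return the expected files and folders that are absent under root."""
    root = Path(root)
    missing = [name for name in EXPECTED_FILES if not (root / name).is_file()]
    missing += [name + "/" for name in EXPECTED_DIRS if not (root / name).is_dir()]
    return missing


if __name__ == "__main__":
    problems = missing_entries(".")
    if problems:
        print("Missing: " + ", ".join(problems))
    else:
        print("Layout looks complete.")
```

An empty result means the commands in this README should work unchanged from the project root; any listed entry points to a path that needs adjusting.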