Semi-supervised learning for ADC property prediction - PyTorch Implementation
This repository contains our replication of the ADCNet model designed by its original authors. We converted the entire architecture from TensorFlow to PyTorch, achieving better test accuracy than the original implementation.
- Framework Migration: Successfully converted the complete ADCNet architecture from TensorFlow to PyTorch
- Weight Conversion: Performed complex weight conversion from TensorFlow .h5 format to PyTorch .pth format for the FG-BERT encoder
- Performance Improvement: Achieved 89% test accuracy (+2 percentage points over the 87% reported by the authors)
- Hyperparameter Consistency: Trained using identical hyperparameters as the original authors to ensure fair comparison
FG-BERT (Functional Group-BERT) is a self-supervised deep learning framework designed to enhance molecular representation learning by focusing on functional groups within molecules. Developed by idrugLab, FG-BERT leverages the Transformer architecture to pretrain on approximately 1.45 million unlabeled drug-like molecules, learning meaningful representations by masking and predicting functional groups; the pretrained model was then fine-tuned and evaluated on 44 benchmark datasets.

One of the most significant technical challenges was converting the FG-BERT encoder weights from TensorFlow to PyTorch. The FG-BERT encoder is a critical component of ADCNet, specifically designed for molecular representation learning:
- Original Format: TensorFlow .h5 weights with specific layer naming conventions
- Target Format: PyTorch .pth state dictionaries with different tensor layouts
- Key Challenges (illustrated in the conversion sketch after this list):
- Tensor dimension reordering (TensorFlow uses different conventions than PyTorch)
- Layer naming scheme conversion
- Attention mechanism weight mapping
- Embedding layer parameter alignment
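As an illustration, here is a minimal sketch of such a conversion. The key-mapping rules, the target file name fgbert_encoder.pth, and the traversal via h5py's visititems are our assumptions: the real layer names inside bert_weightsMedium_20.h5 and the target PyTorch module names must be inspected and mapped individually.

```python
# Hypothetical conversion sketch, NOT the exact mapping used for FG-BERT:
# the real TF layer names and PyTorch module names must be inspected first.
import h5py
import torch

def convert_fgbert_weights(h5_path, pth_path):
    state_dict = {}
    with h5py.File(h5_path, "r") as f:
        def visit(name, obj):
            if isinstance(obj, h5py.Dataset):
                tensor = torch.from_numpy(obj[()])
                # TF Dense kernels are stored (in_features, out_features);
                # PyTorch nn.Linear expects (out_features, in_features).
                if name.endswith("kernel:0") and tensor.dim() == 2:
                    tensor = tensor.t().contiguous()
                # Illustrative key translation from TF to PyTorch naming.
                key = (name.replace("/", ".")
                           .replace("kernel:0", "weight")
                           .replace("bias:0", "bias"))
                state_dict[key] = tensor
        f.visititems(visit)
    torch.save(state_dict, pth_path)

convert_fgbert_weights("medium3_weights/bert_weightsMedium_20.h5",
                       "Weights/fgbert_encoder.pth")  # placeholder output name
```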
The FG-BERT encoder serves as the backbone for understanding molecular structures and relationships within the ADC (Antibody-Drug Conjugate) context. Our conversion preserves all pre-trained knowledge from the original model while gaining access to PyTorch's ecosystem and tooling.
We conducted an in-depth analysis of the ADCNet architecture to ensure accurate replication:
- Multi-Head Attention Mechanisms: Converted scaled dot-product attention with proper mask and adjacency matrix handling (sketched after this list)
- Encoder-Decoder Structure: Maintained the original transformer-based architecture
- Custom Activation Functions: Implemented GELU activation functions consistent with the original design
- Regularization Techniques: Preserved dropout patterns and layer normalization strategies
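To make the attention item concrete, here is a minimal PyTorch sketch of scaled dot-product attention with padding-mask and adjacency handling. The tensor shapes and the additive form of the adjacency bias are assumptions for illustration, not the authors' verbatim code:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, pad_mask=None, adj_bias=None):
    """q, k, v: (batch, heads, atoms, d_k). Returns (output, attn weights)."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if adj_bias is not None:
        # Additive graph bias, e.g. large negative values for non-bonded
        # atom pairs, concentrating attention on chemical neighbors.
        scores = scores + adj_bias.unsqueeze(1)  # broadcast over heads
    if pad_mask is not None:
        # pad_mask: (batch, atoms), True at padded positions.
        scores = scores.masked_fill(pad_mask[:, None, None, :], float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, v), attn

# Tiny smoke test: 2 molecules, 4 heads, 10 atoms, d_k = 16.
q = k = v = torch.randn(2, 4, 10, 16)
out, attn = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 10, 16])
```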
The primary limitation of this work is the constrained dataset size of only 435 samples. This small dataset size presents challenges for:
- Model generalization capability
- Statistical significance of improvements
- Comprehensive evaluation across diverse molecular structures
- Robust validation of the converted model
- Training larger model variants
Despite this limitation, our 2-percentage-point improvement demonstrates the effectiveness of the PyTorch implementation and suggests potential for further enhancement with larger datasets.
This work will be further enhanced in the following directions:
- Graph Neural Network Integration: Incorporate advanced GNN architectures to better capture molecular topology
- Attention Mechanism Refinement: Develop more sophisticated attention patterns for molecular interactions
- Multi-Scale Feature Learning: Implement hierarchical feature extraction for different molecular scales
- Ensemble Methods: Combine multiple model variants for improved prediction stability
- Advanced Regularization: Implement modern regularization techniques (e.g., dropout variants, batch normalization alternatives)
- Transfer Learning: Leverage larger pre-trained molecular models for enhanced representation learning
- Data Augmentation: Develop molecular-specific augmentation strategies
- Active Learning: Implement strategies to identify most informative samples for labeling
- Cross-Domain Transfer: Explore knowledge transfer from related molecular prediction tasks
py37.yaml contains the version specifications for the packages in the installed environment. The Embeddings folder contains the antibody heavy-chain, light-chain, and antigen macromolecule embeddings; the Weights folder contains the FG-BERT weights in PyTorch format; and the classification_weights folder contains the trained model weight files.
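Summarized as a directory tree (our sketch of the description above; remaining files omitted):

```
ADCNet/
├── py37.yaml                 # package version specifications
├── Embeddings/               # antibody heavy/light-chain and antigen embeddings
├── Weights/                  # FG-BERT weights converted to PyTorch (.pth)
├── classification_weights/   # trained classification model weights
└── ...                       # scripts such as class.py, inference.py, ESM-2.py
```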
conda create -n ADCNet python=3.10
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install rdkit
pip install numpy
pip install pandas
pip install matplotlib
pip install hyperopt
pip install scikit-learn
pip install openpyxl
pip install fair-esm
Refer to this link to install Open Babel: https://openbabel.org/docs/Installation/install.html#compile-language-bindings
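Optionally, a quick import check (our suggestion, not part of the original setup) confirms the environment resolved:

```python
# Optional sanity check: core packages import and CUDA is visible.
import torch, rdkit, esm

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("rdkit", rdkit.__version__)
```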
- Run ESM-2.py to obtain embeddings for antibody heavy chain, light chain, and antigen (see the sketch after this list)
- Ensure each data entry contains the DAR value
- Create a folder named "medium3_weights" and place "bert_weightsMedium_20.h5" into it
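For reference, here is a minimal sketch of how such embeddings can be produced with the fair-esm API. The model size, the mean pooling, and the output file name are assumptions, and the sequence is a placeholder; the repository's ESM-2.py may differ:

```python
# Hypothetical sketch of ESM-2 embedding extraction (not the repo's ESM-2.py).
import torch
import esm

# esm2_t33_650M_UR50D is one common choice; the repo's model size may differ.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

data = [("heavy_chain", "EVQLVESGGGLVQPGGSLRLSCAAS")]  # placeholder sequence
labels, strs, tokens = batch_converter(data)
with torch.no_grad():
    out = model(tokens, repr_layers=[33], return_contacts=False)

# Mean-pool per-residue representations (positions 1..L, skipping BOS/EOS).
seq_len = len(strs[0])
embedding = out["representations"][33][0, 1:seq_len + 1].mean(dim=0)
torch.save(embedding, "Embeddings/heavy_chain.pt")  # placeholder file name
```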
conda activate ADCNet
python class.py
- Run ESM-2.py to obtain embeddings for antibody heavy chain, light chain, and antigen
- Ensure each data entry contains the DAR value
- Create a folder named "classification_weights" and place "ADC_9.h5" into it
For reproducing results, run class.py directly.
conda activate ADCNet
python inference.py
If you use this PyTorch implementation in your research, please cite both the original ADCNet paper and acknowledge this implementation work.