I built this project to study deep learning models with PyTorch.
All code starts from `train.py`, and a config file selects which model to use. For example: `python train.py -c model_config.json`
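pytorch-template, which this repo is based on, instantiates the model whose class name appears under the config's `arch.type` key and passes it `arch.args`. The standalone sketch below imitates that dispatch, roughly in the spirit of pytorch-template's `ConfigParser.init_obj`. The model classes, their arguments, and the `init_obj` helper here are hypothetical stand-ins, not this repo's code.

```python
import json
from types import SimpleNamespace


# Hypothetical model classes; the repo's real ones live in its model module.
class FactorizationMachine:
    def __init__(self, embed_dim):
        self.embed_dim = embed_dim


class TextCNN:
    def __init__(self, num_filters, kernel_sizes):
        self.num_filters = num_filters
        self.kernel_sizes = kernel_sizes


# Stand-in for importing the module that defines the model classes.
module_arch = SimpleNamespace(FactorizationMachine=FactorizationMachine, TextCNN=TextCNN)


def init_obj(config, key, module):
    """Look up config[key]["type"] in `module` and call it with config[key]["args"]."""
    spec = config[key]
    cls = getattr(module, spec["type"])
    return cls(**spec.get("args", {}))


config = json.loads("""
{
    "name": "fm",
    "arch": {"type": "FactorizationMachine", "args": {"embed_dim": 16}}
}
""")
model = init_obj(config, "arch", module_arch)
print(type(model).__name__, model.embed_dim)  # FactorizationMachine 16
```

With this design, switching models only means editing `arch.type` and `arch.args` in the JSON file; the training script itself stays unchanged.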
For your own usage:
- write your own dataset under `./data_loader` and add it to `./data_loader/data_loaders.py`
- write your own config file under `./configs` to choose the model and set its parameters
- run `python train.py -c ./configs/config.json`
For each new task, build a data loader in `./data_loader` and write a JSON config file in `./configs`.
If you get a tensor type error (for example a mismatch between long, float, and int), fix it by changing the dtypes of the tensors returned by your dataset's `__getitem__`.
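Such errors usually mean `__getitem__` returns a dtype the model or loss does not accept: `nn.Embedding` lookups need integer indices, `nn.BCELoss` and `nn.BCEWithLogitsLoss` need float targets, and `nn.CrossEntropyLoss` needs long class indices. The hypothetical `ToyCTRDataset` below is a minimal sketch (not this repo's dataset) that casts each returned tensor explicitly so those mismatches cannot occur.

```python
import torch
from torch.utils.data import Dataset


class ToyCTRDataset(Dataset):
    """Hypothetical dataset: integer feature indices plus a 0/1 click label."""

    def __init__(self, features, labels):
        self.features = features  # list of lists of ints (feature indices)
        self.labels = labels      # list of 0/1 ints

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Embedding lookups need integer indices, so cast features to long.
        x = torch.tensor(self.features[idx], dtype=torch.long)
        # Binary cross-entropy losses expect float targets, not int.
        y = torch.tensor(self.labels[idx], dtype=torch.float)
        return x, y


dataset = ToyCTRDataset([[0, 3, 7], [1, 4, 8]], [1, 0])
x, y = dataset[0]
print(x.dtype, y.dtype)
```

For a classification task trained with `nn.CrossEntropyLoss`, the label would instead be cast with `dtype=torch.long`.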
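For the Factorization Machine:
- write `criteo_dataset.py` under `./data_loader`
- add it to `./data_loader/data_loaders.py`:

```python
from base import BaseDataLoader
from data_loader.criteo_dataset import CriteoDataset  # assumed module path


class CriteoDataLoader(BaseDataLoader):
    """Wraps CriteoDataset so the training script can build it from the config."""

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1):
        self.data_dir = data_dir
        self.dataset = CriteoDataset(self.data_dir)
        # BaseDataLoader handles shuffling, the validation split, and worker processes.
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
```

- write `config_fm.json` under `./configs`
- run `python train.py -c ./configs/config.json`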
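The model trained by these steps is the Factorization Machine (Rendle, 2010), which scores an input as `y = w0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j`, where each feature `i` has a latent vector `v_i` of length `k`. The paper rewrites the pairwise term as `0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2]`, which costs O(kn) instead of O(kn^2). The plain-Python sketch below illustrates that equation on one dense input and checks it against the naive double sum; it is not this repo's implementation, which presumably works on batched embedding lookups.

```python
def fm_predict(x, w0, w, V):
    """Factorization Machine score for one dense input x.

    x: n feature values; w0: bias; w: n linear weights;
    V: n x k latent factor matrix (list of n lists of length k).
    """
    n, k = len(x), len(V[0])
    linear = w0 + sum(w[i] * x[i] for i in range(n))
    # Pairwise term in O(k*n): 0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2]
    pairwise = 0.0
    for f in range(k):
        s = sum(V[i][f] * x[i] for i in range(n))
        s_sq = sum((V[i][f] * x[i]) ** 2 for i in range(n))
        pairwise += s * s - s_sq
    return linear + 0.5 * pairwise


def fm_predict_naive(x, w0, w, V):
    """Same score via the explicit O(k*n^2) sum over feature pairs i < j."""
    n = len(x)
    linear = w0 + sum(w[i] * x[i] for i in range(n))
    pairwise = sum(
        sum(a * b for a, b in zip(V[i], V[j])) * x[i] * x[j]
        for i in range(n)
        for j in range(i + 1, n)
    )
    return linear + pairwise
```

For example, with `x = [1.0, 0.0, 2.0]`, `w0 = 0.1`, `w = [0.2, 0.3, -0.1]` and `V = [[1, 0], [0.5, 0.5], [1, 1]]`, both functions return approximately 2.1.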
CTR prediction models:

| Model | Reference |
|---|---|
| Factorization Machine | S Rendle, Factorization Machines, 2010. |
| Field-aware Factorization Machine | Y Juan, et al. Field-aware Factorization Machines for CTR Prediction, 2015. |
| DeepFM | H Guo, et al. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, 2017. |
| Wide&Deep | HT Cheng, et al. Wide & Deep Learning for Recommender Systems, 2016. |
| Deep Cross Network | R Wang, et al. Deep & Cross Network for Ad Click Predictions, 2017. |
| xDeepFM | J Lian, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems, 2018. |
NLP classification models:

| Model | Reference |
|---|---|
| fastText | A Joulin, et al. Bag of Tricks for Efficient Text Classification, 2017. |
| TextCNN | Y Kim, Convolutional Neural Networks for Sentence Classification, 2014. |
Datasets:

| Task | Dataset | Source |
|---|---|---|
| CTR prediction | CriteoDataset | Criteo |
| NLP classification | ThucnewsDataset | THUCNews |
Results:

| Model | Accuracy | Loss |
|---|---|---|
| FM | 0.854 | 0.68 |
| FastText | 0.998 | 0.02 |
| TextCNN | 0.954 | 0.18 |
You can also view the training logs in TensorBoard at `localhost:6006` after running `tensorboard --logdir='./saved/log/fm'`.
PyTorch template based on: pytorch-template
Recommendation models based on: pytorch-fm