diff --git a/.gitignore b/.gitignore
index e58a62b6..48a44234 100644
--- a/.gitignore
+++ b/.gitignore
@@ -109,4 +109,8 @@ venv.bak/
 /site
 
 # mypy
-.mypy_cache/
\ No newline at end of file
+.mypy_cache/
+
+# wandb logs
+wandb/
+**/.watchman-cookie-*
diff --git a/README.md b/README.md
index 1f557a02..9d1c179c 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,7 @@ To install, run `pip install -r requirements.txt`. In case of CUDA problems, con
 
 ### General Text Generation
 
+- **FixemGAN** - [FixemGAN: Continuous Space Text GAN on Fixed Embeddings](https://medium.com/@salaxieb.ildar/text-gan-on-embeddings-debb9a006fff)
 - **SeqGAN** - [SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient](https://arxiv.org/abs/1609.05473)
 - **LeakGAN** - [Long Text Generation via Adversarial Training with Leaked Information](https://arxiv.org/abs/1709.08624)
 - **MaliGAN** - [Maximum-Likelihood Augmented Discrete Generative Adversarial Networks](https://arxiv.org/abs/1702.07983)
@@ -67,9 +68,20 @@ To install, run `pip install -r requirements.txt`. In case of CUDA problems, con
 git clone https://github.com/williamSYSU/TextGAN-PyTorch.git
 cd TextGAN-PyTorch
 ```
+- Download the dataset and pretrained embeddings from Kaggle (or see the manual option below):
+```bash
+kaggle datasets download -d salaxieb/texts-corpus-preprocessed
+kaggle datasets download -d salaxieb/pretrained-embeddings
+
+unzip ../texts-corpus-preprocessed.zip -d dataset
+mkdir dataset/testdata
+mv dataset/*_test.txt dataset/testdata/
 
-- For real data experiments, all datasets (`Image COCO`, `EMNLP NEWs`, `Movie Review`, `Amazon Review`) can be downloaded from [here](https://drive.google.com/drive/folders/1XvT3GqbK1wh3XhTgqBLWUtH_mLzGnKZP?usp=sharing).
-- Run with a specific model
+mkdir pretrain
+unzip ../pretrained-embeddings.zip -d pretrain/real_data
+```
+
+- Alternatively, the datasets (`Image COCO`, `EMNLP NEWs`, `Movie Review`, `Amazon Review`) can be downloaded manually from [here](https://drive.google.com/drive/folders/1XvT3GqbK1wh3XhTgqBLWUtH_mLzGnKZP?usp=sharing).
 
 ```bash
 cd run
@@ -86,13 +98,13 @@ python3 run_seqgan.py 0 0
    For each model, the entire runing process is defined in `instructor/oracle_data/seqgan_instructor.py`. (Take SeqGAN in Synthetic data experiment for example). Some basic functions like `init_model()`and `optimize()` are defined in the base class `BasicInstructor` in `instructor.py`. If you want to add a new GAN-based text generation model, please create a new instructor under `instructor/oracle_data` and define the training process for the model.
 
 2. **Visualization**
-   
+
    Use `utils/visualization.py` to visualize the log file, including model loss and metrics scores. Custom your log files in `log_file_list`, no more than `len(color_list)`. The log filename should exclude `.txt`.
-   
+
 3. **Logging**
 
    The TextGAN-PyTorch use the `logging` module in Python to record the running process, like generator's loss and metric scores. For the convenience of visualization, there would be two same log file saved in `log/log_****_****.txt` and `save/**/log.txt` respectively. Furthermore, The code would automatically save the state dict of models and a batch-size of generator's samples in `./save/**/models` and `./save/**/samples` per log step, where `**` depends on your hyper-parameters.
-   
+
 4. **Running Signal**
 
    You can easily control the training process with the class `Signal` (please refer to `utils/helpers.py`) based on dictionary file `run_signal.txt`.
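The `Signal` mechanism above is easy to picture with a short sketch. Below is a minimal, illustrative version of such a run-signal reader, assuming `run_signal.txt` holds a Python dict literal with `pre_sig`/`adv_sig` keys; the attribute names mirror the `self.sig.update()`, `self.sig.pre_sig`, and `self.sig.adv_sig` calls in the instructors of this diff, but the real implementation lives in `utils/helpers.py` and may differ.

```python
import ast


class RunSignal:
    """Minimal sketch of a run-signal reader (illustrative, not the repo's implementation)."""

    def __init__(self, signal_file="run_signal.txt"):
        self.signal_file = signal_file
        self.pre_sig = True   # keep MLE pre-training while True
        self.adv_sig = True   # keep adversarial training while True

    def update(self):
        # Re-read the signal file; editing it while training runs flips the flags.
        try:
            with open(self.signal_file) as f:
                signal = ast.literal_eval(f.read())
            self.pre_sig = bool(signal.get("pre_sig", True))
            self.adv_sig = bool(signal.get("adv_sig", True))
        except (OSError, ValueError, SyntaxError):
            pass  # keep the previous flags if the file is missing or malformed


# Usage pattern, mirroring the instructors:
# sig = RunSignal()
# for epoch in range(epochs):
#     sig.update()
#     if not sig.pre_sig:
#         break  # stop pre-training and move on to adversarial training
```

Writing, for example, `{'pre_sig': True, 'adv_sig': False}` into `run_signal.txt` during a run would then stop adversarial training at the next check.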
@@ -105,13 +117,25 @@ python3 run_seqgan.py 0 0
 
 ## Implementation Details
 
+### FixemGAN
+
+- run file: [run_fixem.py](run/run_fixem.py)
+
+- Instructors: [oracle_data](instructor/oracle_data/fixem_instructor.py), [real_data](instructor/real_data/fixem_instructor.py)
+
+- Models: [generator](models/generators/FixemGAN_G.py), [discriminator](models/discriminators/FixemGAN_D.py)
+
+- Structure (from [FixemGAN](https://medium.com/@salaxieb.ildar/text-gan-on-embeddings-debb9a006fff))
+
+  ![model_fixem](./assets/model_fixem.png)
+
 ### SeqGAN
 
 - run file: [run_seqgan.py](run/run_seqgan.py)
 
 - Instructors: [oracle_data](instructor/oracle_data/seqgan_instructor.py), [real_data](instructor/real_data/seqgan_instructor.py)
 
-- Models: [generator](models/SeqGAN_G.py), [discriminator](models/SeqGAN_D.py)
+- Models: [generator](models/generators/SeqGAN_G.py), [discriminator](models/discriminators/SeqGAN_D.py)
 
 - Structure (from [SeqGAN](https://arxiv.org/pdf/1609.05473.pdf))
 
@@ -123,7 +147,7 @@ python3 run_seqgan.py 0 0
 
 - Instructors: [oracle_data](instructor/oracle_data/leakgan_instructor.py), [real_data](instructor/real_data/leakgan_instructor.py)
 
-- Models: [generator](models/LeakGAN_G.py), [discriminator](models/LeakGAN_D.py)
+- Models: [generator](models/generators/LeakGAN_G.py), [discriminator](models/discriminators/LeakGAN_D.py)
 
 - Structure (from [LeakGAN](https://arxiv.org/pdf/1709.08624.pdf))
 
@@ -135,7 +159,7 @@ python3 run_seqgan.py 0 0
 
 - Instructors: [oracle_data](instructor/oracle_data/maligan_instructor.py), [real_data](instructor/real_data/maligan_instructor.py)
 
-- Models: [generator](models/MaliGAN_G.py), [discriminator](models/MaliGAN_D.py)
+- Models: [generator](models/generators/MaliGAN_G.py), [discriminator](models/discriminators/MaliGAN_D.py)
 
 - Structure (from my understanding)
 
@@ -147,7 +171,7 @@ python3 run_seqgan.py 0 0
 
 - Instructors: [oracle_data](instructor/oracle_data/jsdgan_instructor.py), [real_data](instructor/real_data/jsdgan_instructor.py)
 
-- Models: [generator](models/JSDGAN_G.py) (No discriminator)
+- Models: [generator](models/generators/JSDGAN_G.py) (No discriminator)
 
 - Structure (from my understanding)
 
@@ -159,31 +183,31 @@ python3 run_seqgan.py 0 0
 
 - Instructors: [oracle_data](instructor/oracle_data/relgan_instructor.py), [real_data](instructor/real_data/relgan_instructor.py)
 
-- Models: [generator](models/RelGAN_G.py), [discriminator](models/RelGAN_D.py)
+- Models: [generator](models/generators/RelGAN_G.py), [discriminator](models/discriminators/RelGAN_D.py)
 
 - Structure (from my understanding)
 
 ![model_relgan](assets/model_relgan.png)
-    
+
 ### DPGAN
 
 - run file: [run_dpgan.py](run/run_dpgan.py)
 
 - Instructors: [oracle_data](instructor/oracle_data/dpgan_instructor.py), [real_data](instructor/real_data/dpgan_instructor.py)
 
-- Models: [generator](models/DPGAN_G.py), [discriminator](models/DPGAN_D.py)
+- Models: [generator](models/generators/DPGAN_G.py), [discriminator](models/discriminators/DPGAN_D.py)
 
 - Structure (from [DPGAN](https://arxiv.org/abs/1802.01345))
 
 ![model_dpgan](assets/model_dpgan.png)
-    
+
 ### DGSAN
 
 - run file: [run_dgsan.py](run/run_dgsan.py)
 
 - Instructors: [oracle_data](instructor/oracle_data/dgsan_instructor.py), [real_data](instructor/real_data/dgsan_instructor.py)
 
-- Models: [generator](models/DGSAN_G.py), [discriminator](models/DGSAN_D.py)
+- Models: [generator](models/generators/DGSAN_G.py), [discriminator](models/discriminators/DGSAN_D.py)
 
 ### CoT
 
@@ -191,7 +215,7 @@ python3 run_seqgan.py 0 0
 
 - Instructors: [oracle_data](instructor/oracle_data/cot_instructor.py), 
[real_data](instructor/real_data/cot_instructor.py) -- Models: [generator](models/CoT_G.py), [discriminator](models/CoT_D.py) +- Models: [generator](models/generators/CoT_G.py), [discriminator](models/discriminators/CoT_D.py) - Structure (from [CoT](https://arxiv.org/abs/1804.03782)) @@ -203,7 +227,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/sentigan_instructor.py), [real_data](instructor/real_data/sentigan_instructor.py) -- Models: [generator](models/SentiGAN_G.py), [discriminator](models/SentiGAN_D.py) +- Models: [generator](models/generators/SentiGAN_G.py), [discriminator](models/discriminators/SentiGAN_D.py) - Structure (from [SentiGAN](https://www.ijcai.org/proceedings/2018/0618.pdf)) @@ -215,15 +239,14 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/catgan_instructor.py), [real_data](instructor/real_data/catgan_instructor.py) -- Models: [generator](models/CatGAN_G.py), [discriminator](models/CatGAN_D.py) +- Models: [generator](models/generators/CatGAN_G.py), [discriminator](models/discriminators/CatGAN_D.py) - Structure (from [CatGAN](https://arxiv.org/abs/1911.06641)) ![model_catgan](assets/model_catgan.png) - + ## Licence **MIT lincense** - diff --git a/assets/model_fixem.png b/assets/model_fixem.png new file mode 100644 index 00000000..9c3b6892 Binary files /dev/null and b/assets/model_fixem.png differ diff --git a/config.py b/config.py index 63b9d035..86052e5c 100644 --- a/config.py +++ b/config.py @@ -24,10 +24,10 @@ dis_pretrain = False clas_pretrain = False -run_model = 'catgan' # seqgan, leakgan, maligan, jsdgan, relgan, evogan, sentigan, catgan, dpgan, dgsan, cot +run_model = "catgan" # seqgan, leakgan, maligan, jsdgan, relgan, evogan, sentigan, catgan, dpgan, dgsan, cot k_label = 2 # num of labels, >=2 -gen_init = 'truncated_normal' # normal, uniform, truncated_normal -dis_init = 'uniform' # normal, uniform, truncated_normal +gen_init = "truncated_normal" # normal, uniform, truncated_normal +dis_init = "uniform" # normal, uniform, truncated_normal # ===CatGAN=== n_parent = 1 @@ -41,34 +41,55 @@ use_all_real_fake = False use_population = False +# ===FixemGAN=== +batches_per_epoch = 200 +noise_size = 1000 +max_epochs = 20 +target_len = 40 +real_fake_coeff = 1.0 +labels_coeff = 1.0 +diversity_coeff = 1.0 + +# ===Embedding=== +w2v_embedding_size = 100 +w2v_window = 5 +w2v_min_count = 1 +w2v_workers = 1 +w2v_samples_num = 5_000_000 + # ===Oracle or Real, type=== if_real_data = False # if use real data -dataset = 'oracle' # oracle, image_coco, emnlp_news, amazon_app_book, amazon_app_movie, mr15 -model_type = 'vanilla' # vanilla, RMC (custom) -loss_type = 'rsgan' # rsgan lsgan ragan vanilla wgan hinge, for Discriminator (CatGAN) -mu_type = 'ragan' # rsgan lsgan ragan vanilla wgan hinge -eval_type = 'Ra' # standard, rsgan, nll, nll-f1, Ra, bleu3, bleu-f1 -d_type = 'Ra' # S (Standard), Ra (Relativistic_average) -vocab_size = 5000 # oracle: 5000, coco: 4683, emnlp: 5256, amazon_app_book: 6418, mr15: 6289 +dataset = ( + "oracle" # oracle, image_coco, emnlp_news, amazon_app_book, amazon_app_movie, mr15 +) +model_type = "vanilla" # vanilla, RMC (custom) +loss_type = "rsgan" # rsgan lsgan ragan vanilla wgan hinge, for Discriminator (CatGAN) +mu_type = "ragan" # rsgan lsgan ragan vanilla wgan hinge +eval_type = "Ra" # standard, rsgan, nll, nll-f1, Ra, bleu3, bleu-f1 +d_type = "Ra" # S (Standard), Ra (Relativistic_average) +vocab_size = ( + 5000 # oracle: 5000, coco: 4683, emnlp: 5256, amazon_app_book: 6418, mr15: 6289 
+) max_seq_len = 20 # oracle: 20, coco: 37, emnlp: 51, amazon_app_book: 40 ADV_train_epoch = 2000 # SeqGAN, LeakGAN-200, RelGAN-3000 extend_vocab_size = 0 # plus test data, only used for Classifier -temp_adpt = 'exp' # no, lin, exp, log, sigmoid, quad, sqrt -mu_temp = 'exp' # lin exp log sigmoid quad sqrt +temp_adpt = "exp" # no, lin, exp, log, sigmoid, quad, sqrt +mu_temp = "exp" # lin exp log sigmoid quad sqrt evo_temp_step = 1 temperature = 1 # ===Basic Train=== -samples_num = 10000 # 10000, mr15: 2000, +samples_num = 1000 # , mr15: 2000, +small_sample_num = 20 # used for self-blue MLE_train_epoch = 150 # SeqGAN-80, LeakGAN-8, RelGAN-150 PRE_clas_epoch = 10 inter_epoch = 15 # LeakGAN-10 batch_size = 64 # 64 start_letter = 1 padding_idx = 0 -start_token = 'BOS' -padding_token = 'EOS' +start_token = "BOS" +padding_token = "EOS" gen_lr = 0.01 # 0.01 gen_adv_lr = 1e-4 # RelGAN-1e-4 dis_lr = 1e-4 # SeqGAN,LeakGAN-1e-2, RelGAN-1e-4 @@ -78,10 +99,10 @@ pre_log_step = 10 adv_log_step = 20 -train_data = 'dataset/' + dataset + '.txt' -test_data = 'dataset/testdata/' + dataset + '_test.txt' -cat_train_data = 'dataset/' + dataset + '_cat{}.txt' -cat_test_data = 'dataset/testdata/' + dataset + '_cat{}_test.txt' +train_data = "dataset/" + dataset + ".txt" +test_data = "dataset/testdata/" + dataset + "_test.txt" +cat_train_data = "dataset/" + dataset + "_cat{}.txt" +cat_test_data = "dataset/testdata/" + dataset + "_cat{}_test.txt" # ===Metrics=== use_nll_oracle = True @@ -89,6 +110,7 @@ use_nll_div = True use_bleu = True use_self_bleu = False +use_ioc = True use_clas_acc = True use_ppl = False @@ -103,6 +125,7 @@ mem_slots = 1 # RelGAN-1 num_heads = 2 # RelGAN-2 head_size = 256 # RelGAN-256 +generator_complexity = 512 # ===Discriminator=== d_step = 5 # SeqGAN-50, LeakGAN-5 @@ -113,25 +136,26 @@ dis_embed_dim = 64 dis_hidden_dim = 64 num_rep = 64 # RelGAN +discriminator_complexity = 512 # ===log=== log_time_str = strftime("%m%d_%H%M_%S", localtime()) log_filename = strftime("log/log_%s" % log_time_str) -if os.path.exists(log_filename + '.txt'): +if os.path.exists(log_filename + ".txt"): i = 2 while True: - if not os.path.exists(log_filename + '_%d' % i + '.txt'): - log_filename = log_filename + '_%d' % i + if not os.path.exists(log_filename + "_%d" % i + ".txt"): + log_filename = log_filename + "_%d" % i break i += 1 -log_filename = log_filename + '.txt' +log_filename = log_filename + ".txt" # Automatically choose GPU or CPU if torch.cuda.is_available() and torch.cuda.device_count() > 0: - os.system('nvidia-smi -q -d Utilization > gpu') - with open('gpu', 'r') as _tmpfile: - util_gpu = list(map(int, re.findall(r'Gpu\s+:\s*(\d+)\s*%', _tmpfile.read()))) - os.remove('gpu') + os.system("nvidia-smi -q -d Utilization > gpu") + with open("gpu", "r") as _tmpfile: + util_gpu = list(map(int, re.findall(r"Gpu\s+:\s*(\d+)\s*%", _tmpfile.read()))) + os.remove("gpu") if len(util_gpu): device = util_gpu.index(min(util_gpu)) else: @@ -142,63 +166,62 @@ # print('device: ', device) if multi_gpu: - devices = '0,1' - devices = list(map(int, devices.split(','))) + devices = "0,1" + devices = list(map(int, devices.split(","))) device = devices[0] torch.cuda.set_device(device) - os.environ['CUDA_VISIBLE_DIVICES'] = ','.join(map(str, devices)) + os.environ["CUDA_VISIBLE_DIVICES"] = ",".join(map(str, devices)) else: devices = str(device) torch.cuda.set_device(device) # ===Save Model and samples=== -save_root = 'save/{}/{}/{}_{}_dt-{}_lt-{}_mt-{}_et-{}_sl{}_temp{}_lfd{}_T{}/'.format(time.strftime("%Y%m%d"), - dataset, 
run_model, model_type, - d_type, - loss_type, - '+'.join( - [m[:2] for m in - mu_type.split()]), - eval_type, max_seq_len, - temperature, lambda_fd, - log_time_str) -save_samples_root = save_root + 'samples/' -save_model_root = save_root + 'models/' - -oracle_state_dict_path = 'pretrain/oracle_data/oracle_lstm.pt' -oracle_samples_path = 'pretrain/oracle_data/oracle_lstm_samples_{}.pt' -multi_oracle_state_dict_path = 'pretrain/oracle_data/oracle{}_lstm.pt' -multi_oracle_samples_path = 'pretrain/oracle_data/oracle{}_lstm_samples_{}.pt' - -pretrain_root = 'pretrain/{}/'.format(dataset if if_real_data else 'oracle_data') -pretrained_gen_path = pretrain_root + 'gen_MLE_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, - samples_num) -pretrained_dis_path = pretrain_root + 'dis_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, - samples_num) -pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, - samples_num) -signal_file = 'run_signal.txt' - -tips = '' - -if samples_num == 5000 or samples_num == 2000: - assert 'c' in run_model, 'warning: samples_num={}, run_model={}'.format(samples_num, run_model) - +save_root = "save/{}/{}/{}_{}_dt-{}_lt-{}_mt-{}_et-{}_sl{}_temp{}_lfd{}_T{}/".format( + time.strftime("%Y%m%d"), + dataset, + run_model, + model_type, + d_type, + loss_type, + "+".join([m[:2] for m in mu_type.split()]), + eval_type, + max_seq_len, + temperature, + lambda_fd, + log_time_str, +) +save_samples_root = save_root + "samples/" +save_model_root = save_root + "models/" + +oracle_state_dict_path = "pretrain/oracle_data/oracle_lstm.pt" +oracle_samples_path = "pretrain/oracle_data/oracle_lstm_samples_{}.pt" +multi_oracle_state_dict_path = "pretrain/oracle_data/oracle{}_lstm.pt" +multi_oracle_samples_path = "pretrain/oracle_data/oracle{}_lstm_samples_{}.pt" + +pretrain_root = "pretrain/{}/".format(dataset if if_real_data else "oracle_data") +pretrained_gen_path = pretrain_root + "gen_MLE_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num +) +pretrained_dis_path = pretrain_root + "dis_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num +) +pretrained_clas_path = pretrain_root + "clas_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num +) + +embedding_root = "pretrain/real_data/" if if_real_data else "pretrain/oracle_data/" +pretrain_embedding_path = embedding_root + "w2v_embedding_size_{}.model".format( + w2v_embedding_size +) +texts_pile = "dataset/" # do not include testdata + +signal_file = "run_signal.txt" + +tips = "" # Init settings according to parser def init_param(opt): - global run_model, model_type, loss_type, CUDA, device, data_shuffle, samples_num, vocab_size, \ - MLE_train_epoch, ADV_train_epoch, inter_epoch, batch_size, max_seq_len, start_letter, padding_idx, \ - gen_lr, gen_adv_lr, dis_lr, clip_norm, pre_log_step, adv_log_step, train_data, test_data, temp_adpt, \ - temperature, oracle_pretrain, gen_pretrain, dis_pretrain, ADV_g_step, rollout_num, gen_embed_dim, \ - gen_hidden_dim, goal_size, step_size, mem_slots, num_heads, head_size, d_step, d_epoch, \ - ADV_d_step, ADV_d_epoch, dis_embed_dim, dis_hidden_dim, num_rep, log_filename, save_root, \ - signal_file, tips, save_samples_root, save_model_root, if_real_data, pretrained_gen_path, \ - pretrained_dis_path, pretrain_root, if_test, dataset, PRE_clas_epoch, oracle_samples_path, \ - pretrained_clas_path, n_parent, 
mu_type, eval_type, d_type, eval_b_num, lambda_fd, d_out_mean, \ - lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, \ - multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ - use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl + global run_model, model_type, loss_type, CUDA, device, data_shuffle, samples_num, vocab_size, MLE_train_epoch, ADV_train_epoch, inter_epoch, batch_size, max_seq_len, start_letter, padding_idx, gen_lr, gen_adv_lr, dis_lr, clip_norm, pre_log_step, adv_log_step, train_data, test_data, temp_adpt, temperature, oracle_pretrain, gen_pretrain, dis_pretrain, ADV_g_step, rollout_num, gen_embed_dim, gen_hidden_dim, goal_size, step_size, mem_slots, num_heads, head_size, d_step, d_epoch, ADV_d_step, ADV_d_epoch, dis_embed_dim, dis_hidden_dim, num_rep, log_filename, save_root, signal_file, tips, save_samples_root, save_model_root, if_real_data, pretrained_gen_path, pretrained_dis_path, pretrain_root, if_test, dataset, PRE_clas_epoch, oracle_samples_path, pretrained_clas_path, n_parent, mu_type, eval_type, d_type, eval_b_num, lambda_fd, d_out_mean, lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num, real_fake_coeff, labels_coeff, diversity_coeff if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -227,6 +250,14 @@ def init_param(opt): use_all_real_fake = opt.use_all_real_fake use_population = opt.use_population + batches_per_epoch = opt.batches_per_epoch + noise_size = opt.noise_size + max_epochs = opt.max_epochs + target_len = opt.target_len + real_fake_coeff = opt.real_fake_coeff + labels_coeff = opt.labels_coeff + diversity_coeff = opt.diversity_coeff + samples_num = opt.samples_num vocab_size = opt.vocab_size MLE_train_epoch = opt.mle_epoch @@ -259,6 +290,7 @@ def init_param(opt): mem_slots = opt.mem_slots num_heads = opt.num_heads head_size = opt.head_size + generator_complexity = opt.generator_complexity d_step = opt.d_step d_epoch = opt.d_epoch @@ -267,6 +299,13 @@ def init_param(opt): dis_embed_dim = opt.dis_embed_dim dis_hidden_dim = opt.dis_hidden_dim num_rep = opt.num_rep + discriminator_complexity = opt.discriminator_complexity + + w2v_embedding_size = opt.w2v_embedding_size + w2v_window = opt.w2v_window + w2v_min_count = opt.w2v_min_count + w2v_workers = opt.w2v_workers + w2v_samples_num = opt.w2v_samples_num use_nll_oracle = True if opt.use_nll_oracle == 1 else False use_nll_gen = True if opt.use_nll_gen == 1 else False @@ -283,54 +322,85 @@ def init_param(opt): # CUDA device if multi_gpu: if type(devices) == str: - devices = list(map(int, devices.split(','))) + devices = list(map(int, devices.split(","))) device = devices[0] torch.cuda.set_device(device) - os.environ['CUDA_VISIBLE_DIVICES'] = ','.join(map(str, devices)) + os.environ["CUDA_VISIBLE_DIVICES"] = ",".join(map(str, devices)) else: devices = str(device) torch.cuda.set_device(device) # Save path - save_root = 'save/{}/{}/{}_{}_dt-{}_lt-{}_mt-{}_et-{}_sl{}_temp{}_lfd{}_T{}/'.format(time.strftime("%Y%m%d"), - dataset, run_model, model_type, - d_type, - 
loss_type, - '+'.join( - [m[:2] for m in - mu_type.split()]), - eval_type, max_seq_len, - temperature, lambda_fd, - log_time_str) - - save_samples_root = save_root + 'samples/' - save_model_root = save_root + 'models/' - - train_data = 'dataset/' + dataset + '.txt' - test_data = 'dataset/testdata/' + dataset + '_test.txt' - cat_train_data = 'dataset/' + dataset + '_cat{}.txt' - cat_test_data = 'dataset/testdata/' + dataset + '_cat{}_test.txt' + save_root = ( + "save/{}/{}/{}_{}_dt-{}_lt-{}_mt-{}_et-{}_sl{}_temp{}_lfd{}_T{}/".format( + time.strftime("%Y%m%d"), + dataset, + run_model, + model_type, + d_type, + loss_type, + "+".join([m[:2] for m in mu_type.split()]), + eval_type, + max_seq_len, + temperature, + lambda_fd, + log_time_str, + ) + ) + + save_samples_root = save_root + "samples/" + save_model_root = save_root + "models/" + + train_data = ( + "dataset/" + dataset + ".txt" + if if_real_data + else "pretrain/oracle_data/" + dataset + ".txt" + ) + test_data = "dataset/testdata/" + dataset + "_test.txt" + cat_train_data = ( + "dataset/" + dataset + "_cat{}.txt" + if if_real_data + else "pretrain/oracle_data/" + dataset + "_cat{}.txt" + ) + cat_test_data = "dataset/testdata/" + dataset + "_cat{}_test.txt" if max_seq_len == 40: - oracle_samples_path = 'pretrain/oracle_data/oracle_lstm_samples_{}_sl40.pt' - multi_oracle_samples_path = 'pretrain/oracle_data/oracle{}_lstm_samples_{}_sl40.pt' - - pretrain_root = 'pretrain/{}/'.format(dataset if if_real_data else 'oracle_data') - pretrained_gen_path = pretrain_root + 'gen_MLE_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, - max_seq_len, samples_num) - pretrained_dis_path = pretrain_root + 'dis_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, - samples_num) - pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, - samples_num) - + oracle_samples_path = "pretrain/oracle_data/oracle_lstm_samples_{}_sl40.pt" + multi_oracle_samples_path = ( + "pretrain/oracle_data/oracle{}_lstm_samples_{}_sl40.pt" + ) + + pretrain_root = "pretrain/{}/".format(dataset if if_real_data else "oracle_data") + pretrained_gen_path = pretrain_root + "gen_MLE_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num + ) + pretrained_dis_path = pretrain_root + "dis_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num + ) + pretrained_clas_path = pretrain_root + "clas_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num + ) + embedding_root = "pretrain/real_data/" if if_real_data else "pretrain/oracle_data/" + pretrain_embedding_path = embedding_root + "w2v_embedding_size_{}.model".format( + w2v_embedding_size + ) # Assertion - assert k_label >= 2, 'Error: k_label = {}, which should be >=2!'.format(k_label) - assert eval_b_num >= n_parent * ADV_d_step, 'Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!'.format( - eval_b_num, n_parent * ADV_d_step) + assert k_label >= 2, "Error: k_label = {}, which should be >=2!".format(k_label) + assert ( + eval_b_num >= n_parent * ADV_d_step + ), "Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!".format( + eval_b_num, n_parent * ADV_d_step + ) # Create Directory - dir_list = ['save', 'savefig', 'log', 'pretrain', 'dataset', - 'pretrain/{}'.format(dataset if if_real_data else 'oracle_data')] + dir_list = [ + "save", + "savefig", + "log", + "pretrain", + "dataset", + "pretrain/{}".format(dataset if 
if_real_data else "oracle_data"), + ] if not if_test: dir_list.extend([save_root, save_samples_root, save_model_root]) for d in dir_list: diff --git a/instructor/oracle_data/catgan_instructor.py b/instructor/oracle_data/catgan_instructor.py index cbdd75e2..2c960429 100644 --- a/instructor/oracle_data/catgan_instructor.py +++ b/instructor/oracle_data/catgan_instructor.py @@ -19,8 +19,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor from metrics.nll import NLL -from models.CatGAN_D import CatGAN_D -from models.CatGAN_G import CatGAN_G +from models.discriminators.CatGAN_D import CatGAN_D +from models.generators.CatGAN_G import CatGAN_G from models.Oracle import Oracle from utils.cat_data_loader import CatGenDataIter from utils.data_loader import GenDataIter @@ -31,22 +31,57 @@ class CatGANInstructor(BasicInstructor): - def __init__(self, opt): super(CatGANInstructor, self).__init__(opt) # generator, discriminator - self.oracle_list = [Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] - - self.gen = CatGAN_G(cfg.k_label, cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, - cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.parents = [CatGAN_G(cfg.k_label, cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, - cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, - gpu=cfg.CUDA).state_dict() - for _ in range(cfg.n_parent)] # list of Generator state_dict - self.dis = CatGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.oracle_list = [ + Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + for _ in range(cfg.k_label) + ] + + self.gen = CatGAN_G( + cfg.k_label, + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.parents = [ + CatGAN_G( + cfg.k_label, + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ).state_dict() + for _ in range(cfg.n_parent) + ] # list of Generator state_dict + self.dis = CatGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -54,17 +89,24 @@ def __init__(self, opt): self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) - self.parent_mle_opts = [copy.deepcopy(self.gen_opt.state_dict()) - for _ in range(cfg.n_parent)] - self.parent_adv_opts = [copy.deepcopy(self.gen_adv_opt.state_dict()) - for _ in range(cfg.n_parent)] # list of optimizer state dict + self.parent_mle_opts = [ + copy.deepcopy(self.gen_opt.state_dict()) for _ in range(cfg.n_parent) + ] + self.parent_adv_opts = [ + copy.deepcopy(self.gen_adv_opt.state_dict()) for _ in range(cfg.n_parent) + ] # list of optimizer state dict # Criterion - self.G_criterion = [GANLoss(loss_mode, 'G', cfg.d_type, CUDA=cfg.CUDA) for loss_mode in cfg.mu_type.split()] - self.D_criterion = GANLoss(cfg.loss_type, 'D', cfg.d_type, CUDA=cfg.CUDA) + self.G_criterion = [ + GANLoss(loss_mode, "G", cfg.d_type, CUDA=cfg.CUDA) + for loss_mode 
in cfg.mu_type.split() + ] + self.D_criterion = GANLoss(cfg.loss_type, "D", cfg.d_type, CUDA=cfg.CUDA) # DataLoader - self.all_oracle_data = CatGenDataIter(self.oracle_samples_list) # Shuffled all oracle data + self.all_oracle_data = CatGenDataIter( + self.oracle_samples_list + ) # Shuffled all oracle data def init_model(self): if cfg.oracle_pretrain: @@ -72,12 +114,20 @@ def init_model(self): oracle_path = cfg.multi_oracle_state_dict_path.format(i) if not os.path.exists(oracle_path): create_multi_oracle(cfg.k_label) - self.oracle_list[i].load_state_dict(torch.load(oracle_path, map_location='cuda:%d' % cfg.device)) + self.oracle_list[i].load_state_dict( + torch.load(oracle_path, map_location="cuda:%d" % cfg.device) + ) if cfg.gen_pretrain: for i in range(cfg.n_parent): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) - self.parents[i] = torch.load(cfg.pretrained_gen_path + '%d' % 0, map_location='cpu') + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) + self.parents[i] = torch.load( + cfg.pretrained_gen_path + "%d" % 0, map_location="cpu" + ) if cfg.CUDA: for i in range(cfg.k_label): @@ -97,14 +147,24 @@ def load_gen(self, parent, parent_opt, mle=False): def _run(self): # ===Pre-train Generator=== if not cfg.gen_pretrain: - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_mle_opts)): - self.log.info('Starting Generator-{} MLE Training...'.format(i)) + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_mle_opts) + ): + self.log.info("Starting Generator-{} MLE Training...".format(i)) self.load_gen(parent, parent_opt, mle=True) # load state dict self.pretrain_generator(cfg.MLE_train_epoch) - self.parents[i] = copy.deepcopy(self.gen.state_dict()) # save state dict + self.parents[i] = copy.deepcopy( + self.gen.state_dict() + ) # save state dict if cfg.if_save and not cfg.if_test: - torch.save(self.gen.state_dict(), cfg.pretrained_gen_path + '%d' % i) - self.log.info('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen.state_dict(), cfg.pretrained_gen_path + "%d" % i + ) + self.log.info( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # ===Adv-train=== progress = tqdm(range(cfg.ADV_train_epoch)) @@ -112,27 +172,45 @@ def _run(self): if cfg.temperature == 1: score, fit_score, select_mu = self.evolve_generator(cfg.ADV_g_step) else: # evolve with temperature - score, fit_score, select_mu = self.evolve_generator_with_temp(adv_epoch, cfg.ADV_g_step) + score, fit_score, select_mu = self.evolve_generator_with_temp( + adv_epoch, cfg.ADV_g_step + ) d_loss = self.evolve_discriminator(cfg.ADV_d_step) best_id = int(np.argmax(score)) - progress.set_description('mu: %s, d_loss = %.4f, temp = %.4f' % ( - ' '.join(select_mu), d_loss, self.parents[best_id]['temperature'].item())) + progress.set_description( + "mu: %s, d_loss = %.4f, temp = %.4f" + % ( + " ".join(select_mu), + d_loss, + self.parents[best_id]["temperature"].item(), + ) + ) # ===Test=== - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): best_id = int(np.argmax(score)) self.load_gen(self.parents[best_id], self.parent_adv_opts[best_id]) - self.log.info('[ADV] epoch %d: temp = %.4f, d_loss: %.4f, %s' % ( - adv_epoch, self.gen.temperature.item(), d_loss, self.comb_metrics(fmt_str=True))) + 
self.log.info( + "[ADV] epoch %d: temp = %.4f, d_loss: %.4f, %s" + % ( + adv_epoch, + self.gen.temperature.item(), + d_loss, + self.comb_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: for label_i in range(cfg.k_label): - self._save('ADV', adv_epoch, label_i) + self._save("ADV", adv_epoch, label_i) def _test(self): - self.log.debug('>>> Begin test...') + self.log.debug(">>> Begin test...") self._run() pass @@ -143,17 +221,20 @@ def pretrain_generator(self, epochs): """ for epoch in range(epochs): # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.all_oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.all_oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.comb_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.comb_metrics(fmt_str=True)) + ) if not cfg.if_test and cfg.if_save: for label_i in range(cfg.k_label): - self._save('MLE', epoch, label_i) + self._save("MLE", epoch, label_i) def evolve_generator(self, evo_g_step): # evaluation real data @@ -169,13 +250,21 @@ def evolve_generator(self, evo_g_step): # all child share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.oracle_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.oracle_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): # Variation self.load_gen(parent, parent_opt) # load state dict to self.gen @@ -200,7 +289,9 @@ def evolve_generator(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation[id_replace] = criterionG.loss_mode count += 1 @@ -224,17 +315,25 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.oracle_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.oracle_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label - for i, (parent, parent_opt) in enumerate(zip(self.parents, 
self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): all_temp = self.get_evo_temp(cur_adv_step) - temp_score = float('-inf') + temp_score = float("-inf") temp_fit = None temp_child = None temp_child_opt = None @@ -250,8 +349,10 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # Evaluation self.prepare_eval_fake_data() # evaluation fake data - _, _, t_score = self.evaluation('Ra') # for temp evolutionary - loss_Fq, loss_Fd, loss_score = self.evaluation(cfg.eval_type) # for loss evolutionary + _, _, t_score = self.evaluation("Ra") # for temp evolutionary + loss_Fq, loss_Fd, loss_score = self.evaluation( + cfg.eval_type + ) # for loss evolutionary if t_score > temp_score: temp_score = loss_score @@ -303,14 +404,22 @@ def evolve_generator_population(self, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.oracle_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.oracle_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label # evaluate all parents - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): self.load_gen(parent, parent_opt) self.prepare_eval_fake_data() Fq, Fd, score = self.evaluation(cfg.eval_type) @@ -324,7 +433,9 @@ def evolve_generator_population(self, evo_g_step): # randomly choose a parent, variation target_idx = random.randint(0, len(self.parents) - 1) for j, criterionG in enumerate(self.G_criterion): - self.load_gen(self.parents[target_idx], self.parent_adv_opts[target_idx]) # load generator + self.load_gen( + self.parents[target_idx], self.parent_adv_opts[target_idx] + ) # load generator # Variation self.variation(evo_g_step, criterionG) @@ -340,7 +451,9 @@ def evolve_generator_population(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation.append(criterionG.loss_mode) @@ -353,15 +466,21 @@ def evolve_discriminator(self, evo_d_step): global dc_loss, dd_loss, d_loss total_loss = [] - all_gen_samples_list = list(map(self.merge, *self.best_fake_samples)) # merge each label of data - self.all_gen_samples_list = self.shuffle_eval_samples(all_gen_samples_list) # shuffle data + all_gen_samples_list = list( + map(self.merge, *self.best_fake_samples) + ) # merge each label of data + self.all_gen_samples_list = self.shuffle_eval_samples( + all_gen_samples_list + ) # shuffle data for step in range(evo_d_step): - dis_real_samples, dis_gen_samples = self.prepare_train_data('D', step) + dis_real_samples, dis_gen_samples = self.prepare_train_data("D", step) d_loss = 0 all_d_out_real = [] all_d_out_fake = [] - for 
(real_samples, fake_samples) in zip(dis_real_samples, dis_gen_samples): # for each label samples + for (real_samples, fake_samples) in zip( + dis_real_samples, dis_gen_samples + ): # for each label samples d_out_real = self.dis(real_samples) d_out_fake = self.dis(fake_samples) d_loss += self.D_criterion(d_out_real, d_out_fake) @@ -386,14 +505,16 @@ def variation(self, g_step, criterionG): """Optimize one child (Generator)""" total_loss = [] for step in range(g_step): - dis_real_samples, dis_gen_samples = self.prepare_train_data('G') + dis_real_samples, dis_gen_samples = self.prepare_train_data("G") # ===Train=== g_loss = 0 all_d_out_real = [] all_d_out_fake = [] # for i, (real_samples, fake_samples) in enumerate(zip(dis_real_samples, dis_gen_samples)): - for i, (d_out_real, fake_samples) in enumerate(zip(self.d_out_real, dis_gen_samples)): # share real + for i, (d_out_real, fake_samples) in enumerate( + zip(self.d_out_real, dis_gen_samples) + ): # share real # d_out_real = self.dis(real_samples) d_out_fake = self.dis(fake_samples) g_loss += criterionG(d_out_real, d_out_fake) @@ -416,50 +537,78 @@ def variation(self, g_step, criterionG): def evaluation(self, eval_type): """Evaluation all children, update child score. Note that the eval data should be the same""" - eval_samples = [self.gen.sample(cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size, label_i=i) - for i in range(cfg.k_label)] + eval_samples = [ + self.gen.sample( + cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size, label_i=i + ) + for i in range(cfg.k_label) + ] # Fd if cfg.lambda_fd != 0: nll_div = [] for label_i in range(cfg.k_label): gen_data = GenDataIter(eval_samples[label_i]) - nll_div.append(NLL.cal_nll_with_label(self.gen, gen_data.loader, label_i, self.mle_criterion)) - if 'f1' in eval_type: + nll_div.append( + NLL.cal_nll_with_label( + self.gen, gen_data.loader, label_i, self.mle_criterion + ) + ) + if "f1" in eval_type: if cfg.k_label == 1: Fd = nll_div[0] if len(nll_div) > 0 else 0 elif cfg.k_label == 2: - Fd = nll_div[0] * nll_div[1] / (nll_div[0] + nll_div[1]) if len(nll_div) > 0 else 0 + Fd = ( + nll_div[0] * nll_div[1] / (nll_div[0] + nll_div[1]) + if len(nll_div) > 0 + else 0 + ) else: - raise NotImplementedError("k_label = %d is not supported" % cfg.k_label) + raise NotImplementedError( + "k_label = %d is not supported" % cfg.k_label + ) else: Fd = sum(nll_div) else: Fd = 0 # Fq - if 'nll' in eval_type: + if "nll" in eval_type: nll_oracle = [] for label_i in range(cfg.k_label): gen_data = GenDataIter(eval_samples[label_i]) if cfg.lambda_fq != 0: - nll_oracle.append(-NLL.cal_nll_with_label(self.oracle_list[label_i], gen_data.loader, label_i, - self.mle_criterion)) - - if 'f1' in eval_type: + nll_oracle.append( + -NLL.cal_nll_with_label( + self.oracle_list[label_i], + gen_data.loader, + label_i, + self.mle_criterion, + ) + ) + + if "f1" in eval_type: if cfg.k_label == 1: Fq = nll_oracle[0] if len(nll_oracle) > 0 else 0 elif cfg.k_label == 2: - Fq = nll_oracle[0] * nll_oracle[1] / (nll_oracle[0] + nll_oracle[1]) if len(nll_oracle) > 0 else 0 + Fq = ( + nll_oracle[0] * nll_oracle[1] / (nll_oracle[0] + nll_oracle[1]) + if len(nll_oracle) > 0 + else 0 + ) else: - raise NotImplementedError("k_label = %d is not supported" % cfg.k_label) + raise NotImplementedError( + "k_label = %d is not supported" % cfg.k_label + ) else: # sum Fq = sum(nll_oracle) - elif eval_type == 'Ra': + elif eval_type == "Ra": g_loss = 0 for i in range(cfg.k_label): - g_loss += torch.sigmoid(self.eval_d_out_fake[i] - 
torch.mean(self.eval_d_out_real[i])).sum() + g_loss += torch.sigmoid( + self.eval_d_out_fake[i] - torch.mean(self.eval_d_out_real[i]) + ).sum() Fq = g_loss.item() else: raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type) @@ -470,7 +619,7 @@ def evaluation(self, eval_type): def train_gen_epoch(self, model, data_loader, criterion, optimizer): total_loss = 0 for i, data in enumerate(data_loader): - inp, target, label = data['input'], data['target'], data['label'] + inp, target, label = data["input"], data["target"], data["label"] if cfg.CUDA: inp, target, label = inp.cuda(), target.cuda(), label.cuda() @@ -483,8 +632,13 @@ def train_gen_epoch(self, model, data_loader, criterion, optimizer): def _save(self, phase, epoch, label_i=None): assert type(label_i) == int - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_c{}_{}_{:05d}.txt'.format(label_i, phase, epoch) + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_c{}_{}_{:05d}.txt".format( + label_i, phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size, label_i=label_i) write_tensor(save_sample_path, samples) @@ -495,50 +649,86 @@ def merge(*args): def shuffle_eval_samples(self, all_eval_samples): temp = [] for i in range(cfg.k_label): - temp.append(all_eval_samples[i][torch.randperm(all_eval_samples[i].size(0))]) + temp.append( + all_eval_samples[i][torch.randperm(all_eval_samples[i].size(0))] + ) return temp def prepare_train_data(self, which, step=None): """Prepare train data for both Generator and Discriminator, each samples_list contains k_label batches of data""" - assert which == 'D' or which == 'G', 'only support for D and G!!' + assert which == "D" or which == "G", "only support for D and G!!" 
real_samples_list = [ - F.one_hot(self.oracle_data_list[i].random_batch()['target'][:cfg.batch_size], - cfg.vocab_size).float().cuda() - for i in range(cfg.k_label)] - if which == 'D': - assert step is not None, 'missing step' - gen_samples_list = [self.all_gen_samples_list[i][step * cfg.batch_size:(step + 1) * cfg.batch_size] - for i in range(cfg.k_label)] # get a batch from each label + F.one_hot( + self.oracle_data_list[i].random_batch()["target"][: cfg.batch_size], + cfg.vocab_size, + ) + .float() + .cuda() + for i in range(cfg.k_label) + ] + if which == "D": + assert step is not None, "missing step" + gen_samples_list = [ + self.all_gen_samples_list[i][ + step * cfg.batch_size : (step + 1) * cfg.batch_size + ] + for i in range(cfg.k_label) + ] # get a batch from each label else: # 'G' gen_samples_list = [ self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True, label_i=i) - for i in range(cfg.k_label)] + for i in range(cfg.k_label) + ] return real_samples_list, gen_samples_list def prepare_eval_real_data(self): """Prepare evaluation real data, contains k_label batches of data""" with torch.no_grad(): - self.eval_real_samples = [torch.cat( - [F.one_hot(self.oracle_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for _ in range(cfg.eval_b_num)], dim=0) for i in range(cfg.k_label)] + self.eval_real_samples = [ + torch.cat( + [ + F.one_hot( + self.oracle_data_list[i].random_batch()["target"], + cfg.vocab_size, + ).float() + for _ in range(cfg.eval_b_num) + ], + dim=0, + ) + for i in range(cfg.k_label) + ] if cfg.CUDA: - self.eval_real_samples = [self.eval_real_samples[i].cuda() for i in range(cfg.k_label)] + self.eval_real_samples = [ + self.eval_real_samples[i].cuda() for i in range(cfg.k_label) + ] - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': - self.eval_d_out_real = [self.dis(self.eval_real_samples[i]) for i in range(cfg.k_label)] + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": + self.eval_d_out_real = [ + self.dis(self.eval_real_samples[i]) for i in range(cfg.k_label) + ] def prepare_eval_fake_data(self): """Prepare evaluation fake data, contains k_label batches of data""" with torch.no_grad(): - self.eval_fake_samples = [self.gen.sample(cfg.eval_b_num * cfg.batch_size, - cfg.eval_b_num * cfg.batch_size, one_hot=True, label_i=i) - for i in range(cfg.k_label)] + self.eval_fake_samples = [ + self.gen.sample( + cfg.eval_b_num * cfg.batch_size, + cfg.eval_b_num * cfg.batch_size, + one_hot=True, + label_i=i, + ) + for i in range(cfg.k_label) + ] if cfg.CUDA: - self.eval_fake_samples = [self.eval_fake_samples[i].cuda() for i in range(cfg.k_label)] + self.eval_fake_samples = [ + self.eval_fake_samples[i].cuda() for i in range(cfg.k_label) + ] - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': - self.eval_d_out_fake = [self.dis(self.eval_fake_samples[i]) for i in range(cfg.k_label)] + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": + self.eval_d_out_fake = [ + self.dis(self.eval_fake_samples[i]) for i in range(cfg.k_label) + ] @staticmethod def get_evo_temp(cur_step): @@ -547,14 +737,30 @@ def get_evo_temp(cur_step): all_temp = list() # all_temp.append(get_fixed_temperature(1.0, 0, 0, 'no')) # temp=1.0 - all_temp.append(get_fixed_temperature(cfg.temperature, cur_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) # current step all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step + cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step, + 
cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) # current step + all_temp.append( + get_fixed_temperature( + cfg.temperature, + cur_step + cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) if cur_step > cfg.evo_temp_step: all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step - cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step - cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) return torch.Tensor(all_temp) diff --git a/instructor/oracle_data/cot_instructor.py b/instructor/oracle_data/cot_instructor.py index b435ce22..4cc7958e 100644 --- a/instructor/oracle_data/cot_instructor.py +++ b/instructor/oracle_data/cot_instructor.py @@ -4,7 +4,7 @@ # @FileName : cot_instructor.py # @Time : Created at 2020/4/20 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import numpy as np @@ -14,8 +14,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.CoT_D import Cot_D -from models.CoT_G import CoT_G +from models.discriminators.CoT_D import Cot_D +from models.generators.CoT_G import CoT_G from utils.data_loader import GenDataIter @@ -24,10 +24,22 @@ def __init__(self, opt): super(CoTInstructor, self).__init__(opt) # generator, discriminator - self.gen = CoT_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = Cot_D(cfg.gen_embed_dim * 2, cfg.gen_hidden_dim * 2, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) # embed_dim and hidden_dim is larger + self.gen = CoT_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = Cot_D( + cfg.gen_embed_dim * 2, + cfg.gen_hidden_dim * 2, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) # embed_dim and hidden_dim is larger self.init_model() # Optimizer @@ -39,30 +51,32 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for epoch in progress: g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator d_loss = self.train_mediator(epoch, cfg.ADV_d_step) # Discriminator - progress.set_description('g_loss: %.4f, d_loss: %.4f' % (g_loss, d_loss)) + progress.set_description("g_loss: %.4f, d_loss: %.4f" % (g_loss, d_loss)) if epoch % cfg.adv_log_step == 0 or epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV]: epoch = %d, %s' % (epoch, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV]: epoch = %d, %s" % (epoch, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', epoch) + self._save("ADV", epoch) torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -74,16 +88,20 @@ def 
pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -92,7 +110,9 @@ def adv_train_generator(self, g_step): """ g_loss = [] for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = self.dis(inp, self.dis.init_hidden(cfg.batch_size)) @@ -109,9 +129,13 @@ def train_mediator(self, cur_epoch, d_step): d_loss = [] for step in range(d_step): # prepare loader for training - real = list(self.oracle_data.loader)[cur_epoch % len(self.oracle_data.loader)] # traverse all real data - real_inp, real_tar = real['input'], real['target'] - fake_inp, fake_tar = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + real = list(self.oracle_data.loader)[ + cur_epoch % len(self.oracle_data.loader) + ] # traverse all real data + real_inp, real_tar = real["input"], real["target"] + fake_inp, fake_tar = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) if cfg.CUDA: real_inp, real_tar = real_inp.cuda(), real_tar.cuda() diff --git a/instructor/oracle_data/dgsan_instructor.py b/instructor/oracle_data/dgsan_instructor.py index 4a9bbdbe..b23bc9ac 100644 --- a/instructor/oracle_data/dgsan_instructor.py +++ b/instructor/oracle_data/dgsan_instructor.py @@ -4,7 +4,7 @@ # @FileName : dgsan_instructor.py # @Time : Created at 2020/4/12 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import copy import numpy as np @@ -16,7 +16,7 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.DGSAN_G import DGSAN_G +from models.generators.DGSAN_G import DGSAN_G from utils.data_loader import GenDataIter from utils.helpers import create_oracle @@ -26,10 +26,22 @@ def __init__(self, opt): super(DGSANInstructor, self).__init__(opt) # generator - self.gen = DGSAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.old_gen = DGSAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = DGSAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.old_gen = DGSAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -41,11 +53,21 @@ def init_model(self): if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() self.oracle.load_state_dict( - torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.oracle_state_dict_path, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.oracle = self.oracle.cuda() @@ -55,14 +77,14 @@ def init_model(self): def _run(self): # ===PRE-TRAINING=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") self.old_gen.load_state_dict(copy.deepcopy(self.gen.state_dict())) progress = tqdm(range(cfg.ADV_train_epoch)) @@ -70,16 +92,21 @@ def _run(self): g_loss = self.adv_train_generator() self.old_gen.load_state_dict(copy.deepcopy(self.gen.state_dict())) - progress.set_description('g_loss: %.4f' % g_loss) + progress.set_description("g_loss: %.4f" % g_loss) - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): self.log.info( - '[ADV]: epoch: %d, g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True))) + "[ADV]: epoch: %d, g_loss = %.4f, %s" + % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -91,26 +118,35 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + 
pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self): g_loss = [] gen_data = GenDataIter(self.old_gen.sample(cfg.samples_num, cfg.batch_size)) for (real, fake) in zip(self.oracle_data.loader, gen_data.loader): - real_inp, real_tar = real['input'], real['target'] - fake_inp, fake_tar = fake['input'], fake['target'] + real_inp, real_tar = real["input"], real["target"] + fake_inp, fake_tar = fake["input"], fake["target"] if cfg.CUDA: - real_inp, real_tar, fake_inp, fake_tar = real_inp.cuda(), real_tar.cuda(), fake_inp.cuda(), fake_tar.cuda() + real_inp, real_tar, fake_inp, fake_tar = ( + real_inp.cuda(), + real_tar.cuda(), + fake_inp.cuda(), + fake_tar.cuda(), + ) # ===Train=== real_new_pred = self.cal_pred(self.gen, real_inp, real_tar) @@ -119,8 +155,12 @@ def adv_train_generator(self): fake_old_pred = self.cal_pred(self.old_gen, fake_inp, fake_tar) eps = 0 - real_loss = -torch.sum(torch.log(1 / (1 + real_old_pred / (real_new_pred + eps) + eps) + eps)) - fake_loss = -torch.sum(torch.log(1 / (1 + fake_new_pred / (fake_old_pred + eps) + eps) + eps)) + real_loss = -torch.sum( + torch.log(1 / (1 + real_old_pred / (real_new_pred + eps) + eps) + eps) + ) + fake_loss = -torch.sum( + torch.log(1 / (1 + fake_new_pred / (fake_old_pred + eps) + eps) + eps) + ) adv_loss = real_loss + fake_loss self.optimize(self.gen_adv_opt, adv_loss) diff --git a/instructor/oracle_data/dpgan_instructor.py b/instructor/oracle_data/dpgan_instructor.py index b0fe2597..8ef2fc1a 100644 --- a/instructor/oracle_data/dpgan_instructor.py +++ b/instructor/oracle_data/dpgan_instructor.py @@ -12,8 +12,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.DPGAN_D import DPGAN_D -from models.DPGAN_G import DPGAN_G +from models.discriminators.DPGAN_D import DPGAN_D +from models.generators.DPGAN_G import DPGAN_G class DPGANInstructor(BasicInstructor): @@ -21,10 +21,22 @@ def __init__(self, opt): super(DPGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = DPGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = DPGAN_D(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = DPGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = DPGAN_D( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -36,40 +48,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save 
pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') - self.train_discriminator(cfg.d_step, cfg.d_epoch, 'MLE') + self.log.info("Starting Discriminator Training...") + self.train_discriminator(cfg.d_step, cfg.d_epoch, "MLE") if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -81,16 +102,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -100,13 +125,15 @@ def adv_train_generator(self, g_step): """ discount_rate = 1 total_g_loss = 0 - dis_count_list = [discount_rate ** i for i in range(cfg.max_seq_len)] - dis_count_matrix = torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) + dis_count_list = [discount_rate**i for i in range(cfg.max_seq_len)] + dis_count_matrix = ( + torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) + ) if cfg.CUDA: dis_count_matrix = dis_count_matrix.cuda() for step in range(g_step): - inp = self.oracle_data.random_batch()['input'] + inp = self.oracle_data.random_batch()["input"] if cfg.CUDA: inp = inp.cuda() @@ -124,9 +151,11 @@ def adv_train_generator(self, g_step): # ===Test=== self.log.info( - '[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True))) + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
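Aside (not part of the diff): DPGAN's `adv_train_generator` above weights word-level rewards with a per-position discount matrix before summing them into the generator loss. A minimal standalone sketch of that weighting, with illustrative shapes and a non-unit discount rate (the diff itself uses `discount_rate = 1`):

```python
import torch

def discounted_token_weights(batch_size, max_seq_len, discount_rate=1.0):
    # Mirrors dis_count_matrix above: one discount factor per token position,
    # repeated for every sequence in the batch -> shape (batch_size, max_seq_len).
    weights = torch.tensor([discount_rate ** i for i in range(max_seq_len)])
    return weights.unsqueeze(0).repeat(batch_size, 1)

# Illustrative usage: scale per-token rewards from the discriminator, then
# reduce them to a scalar generator loss (the reward tensor is stand-in data).
rewards = torch.rand(4, 20)
g_loss = -(rewards * discounted_token_weights(4, 20, discount_rate=0.95)).sum()
```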
@@ -147,8 +176,10 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): # ===Test=== pos_reward, neg_reward = self.eval_dis(self.dis, pos_val, neg_val) - self.log.info('[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f,' % ( - phase, step, pos_reward.item(), neg_reward.item())) + self.log.info( + "[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f," + % (phase, step, pos_reward.item(), neg_reward.item()) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) @@ -162,8 +193,8 @@ def train_dis_epoch(self, model, pos_samples, neg_samples, optimizer): num_samples = pos_samples.size(0) num_batch = num_samples // cfg.batch_size for i in range(num_batch): - pos_sample = pos_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] - neg_sample = neg_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] + pos_sample = pos_samples[i * cfg.batch_size : (i + 1) * cfg.batch_size] + neg_sample = neg_samples[i * cfg.batch_size : (i + 1) * cfg.batch_size] _, pos_reward = model.getReward(pos_sample) _, neg_reward = model.getReward(neg_sample) diff --git a/instructor/oracle_data/evogan_instructor.py b/instructor/oracle_data/evogan_instructor.py index 32793a0b..fa07c569 100644 --- a/instructor/oracle_data/evogan_instructor.py +++ b/instructor/oracle_data/evogan_instructor.py @@ -18,8 +18,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor from metrics.nll import NLL -from models.EvoGAN_D import EvoGAN_D -from models.EvoGAN_G import EvoGAN_G +from models.discriminators.EvoGAN_D import EvoGAN_D +from models.generators.EvoGAN_G import EvoGAN_G from utils.data_loader import GenDataIter from utils.gan_loss import GANLoss from utils.helpers import get_fixed_temperature, get_losses, create_oracle @@ -30,13 +30,39 @@ def __init__(self, opt): super(EvoGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = EvoGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.parents = [EvoGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA).state_dict() - for _ in range(cfg.n_parent)] # list of Generator state_dict - self.dis = EvoGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = EvoGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.parents = [ + EvoGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ).state_dict() + for _ in range(cfg.n_parent) + ] # list of Generator state_dict + self.dis = EvoGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -44,37 +70,59 @@ def __init__(self, opt): self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) - self.parent_mle_opts = [copy.deepcopy(self.gen_opt.state_dict()) - for _ in range(cfg.n_parent)] - self.parent_adv_opts = [copy.deepcopy(self.gen_adv_opt.state_dict()) - for _ in range(cfg.n_parent)] # list of optimizer 
state dict + self.parent_mle_opts = [ + copy.deepcopy(self.gen_opt.state_dict()) for _ in range(cfg.n_parent) + ] + self.parent_adv_opts = [ + copy.deepcopy(self.gen_adv_opt.state_dict()) for _ in range(cfg.n_parent) + ] # list of optimizer state dict # Criterion - self.G_criterion = [GANLoss(loss_mode, 'G', cfg.d_type, CUDA=cfg.CUDA) for loss_mode in cfg.mu_type.split()] - self.D_criterion = GANLoss(cfg.loss_type, 'D', cfg.d_type, CUDA=cfg.CUDA) + self.G_criterion = [ + GANLoss(loss_mode, "G", cfg.d_type, CUDA=cfg.CUDA) + for loss_mode in cfg.mu_type.split() + ] + self.D_criterion = GANLoss(cfg.loss_type, "D", cfg.d_type, CUDA=cfg.CUDA) def init_model(self): if cfg.oracle_pretrain: if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() - self.oracle.load_state_dict(torch.load(cfg.oracle_state_dict_path, map_location='cuda:%d' % cfg.device)) + self.oracle.load_state_dict( + torch.load( + cfg.oracle_state_dict_path, map_location="cuda:%d" % cfg.device + ) + ) if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: for i in range(cfg.n_parent): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) - self.parents[i] = torch.load(cfg.pretrained_gen_path + '%d' % 0, map_location='cpu') + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) + self.parents[i] = torch.load( + cfg.pretrained_gen_path + "%d" % 0, map_location="cpu" + ) if cfg.CUDA: self.oracle = self.oracle.cuda() self.gen = self.gen.cuda() if cfg.multi_gpu: - self.dis = torch.nn.parallel.DataParallel(self.dis, device_ids=cfg.devices) + self.dis = torch.nn.parallel.DataParallel( + self.dis, device_ids=cfg.devices + ) self.dis = self.dis.cuda() def load_gen(self, parent, parent_opt, mle=False): @@ -89,44 +137,72 @@ def load_gen(self, parent, parent_opt, mle=False): def _run(self): # ===PRE-TRAINING (GENERATOR)=== if not cfg.gen_pretrain: - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_mle_opts)): - self.log.info('Starting Generator-{} MLE Training...'.format(i)) + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_mle_opts) + ): + self.log.info("Starting Generator-{} MLE Training...".format(i)) self.load_gen(parent, parent_opt, mle=True) # load state dict self.pretrain_generator(cfg.MLE_train_epoch) - self.parents[i] = copy.deepcopy(self.gen.state_dict()) # save state dict + self.parents[i] = copy.deepcopy( + self.gen.state_dict() + ) # save state dict if cfg.if_save and not cfg.if_test: - torch.save(self.gen.state_dict(), cfg.pretrained_gen_path + '%d' % i) - self.log.info('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen.state_dict(), cfg.pretrained_gen_path + "%d" % i + ) + self.log.info( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for adv_epoch in progress: if cfg.temperature == 1: score, fit_score, select_mu = 
self.evolve_generator(cfg.ADV_g_step) else: # evolve with temperature - score, fit_score, select_mu = self.evolve_generator_with_temp(adv_epoch, cfg.ADV_g_step) + score, fit_score, select_mu = self.evolve_generator_with_temp( + adv_epoch, cfg.ADV_g_step + ) d_loss = self.evolve_discriminator(cfg.ADV_d_step) best_id = int(np.argmax(score)) - progress.set_description('mu: %s, d_loss = %.4f, temp = %.4f' % ( - ' '.join(select_mu), d_loss, self.parents[best_id]['temperature'].item())) + progress.set_description( + "mu: %s, d_loss = %.4f, temp = %.4f" + % ( + " ".join(select_mu), + d_loss, + self.parents[best_id]["temperature"].item(), + ) + ) # TEST - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): best_id = int(np.argmax(score)) self.load_gen(self.parents[best_id], self.parent_adv_opts[best_id]) # self.log.info('[ADV] epoch %d: temp = %.4f' % (adv_epoch, self.gen.temperature.item())) # self.log.info(fit_score[best_id]) - self.log.info('[ADV] epoch %d: temp = %.4f, d_loss = %.4f, %s' % ( - adv_epoch, self.gen.temperature.item(), d_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV] epoch %d: temp = %.4f, d_loss = %.4f, %s" + % ( + adv_epoch, + self.gen.temperature.item(), + d_loss, + self.cal_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() @@ -140,17 +216,21 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def evolve_generator(self, evo_g_step): @@ -167,12 +247,16 @@ def evolve_generator(self, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): # Variation self.load_gen(parent, parent_opt) # load state dict to self.gen @@ -206,7 +290,9 @@ def evolve_generator(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) 
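Another aside: the `evolve_generator` logic above follows a variation-evaluation-selection pattern, so each parent generator is trained briefly with every candidate loss mode, each child is scored, and only the best `cfg.n_parent` children survive. A toy sketch of that selection step; all callables here are assumed to be supplied by the caller and are not repository functions:

```python
import copy

def evolve_step(parents, loss_modes, train_child, evaluate, n_keep):
    """Toy evolutionary step (illustrative only): mutate every parent with
    every loss mode, score the children, keep the n_keep highest-scoring."""
    children = []
    for parent in parents:
        for loss_mode in loss_modes:
            child = train_child(copy.deepcopy(parent), loss_mode)  # variation
            children.append((evaluate(child), child))              # evaluation
    children.sort(key=lambda pair: pair[0], reverse=True)          # selection
    return [child for _, child in children[:n_keep]]
```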
best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation[id_replace] = criterionG.loss_mode count += 1 @@ -230,16 +316,20 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): all_temp = self.get_evo_temp(cur_adv_step) # get evo temp - temp_score = float('-inf') + temp_score = float("-inf") temp_fit = None temp_child = None temp_child_opt = None @@ -255,8 +345,10 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # Evaluation self.prepare_eval_fake_data() # evaluation fake data - _, _, t_score = self.evaluation('Ra') # for temp evolutionary - loss_Fq, loss_Fd, loss_score = self.evaluation(cfg.eval_type) # for loss evolutionary + _, _, t_score = self.evaluation("Ra") # for temp evolutionary + loss_Fq, loss_Fd, loss_score = self.evaluation( + cfg.eval_type + ) # for loss evolutionary if t_score > temp_score: temp_score = loss_score @@ -308,13 +400,17 @@ def evolve_generator_population(self, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) # evaluate all parents - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): self.load_gen(parent, parent_opt) self.prepare_eval_fake_data() Fq, Fd, score = self.evaluation(cfg.eval_type) @@ -328,7 +424,9 @@ def evolve_generator_population(self, evo_g_step): # randomly choose a parent, variation target_idx = random.randint(0, len(self.parents) - 1) for j, criterionG in enumerate(self.G_criterion): - self.load_gen(self.parents[target_idx], self.parent_adv_opts[target_idx]) # load generator + self.load_gen( + self.parents[target_idx], self.parent_adv_opts[target_idx] + ) # load generator # Variation self.variation(evo_g_step, criterionG) @@ -344,7 +442,9 @@ def evolve_generator_population(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation.append(criterionG.loss_mode) @@ -356,8 +456,12 @@ def evolve_generator_population(self, evo_g_step): def evolve_discriminator(self, evo_d_step): total_loss = 0 for step in range(evo_d_step): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() - gen_samples = self.best_fake_samples[step * cfg.batch_size:(step + 1) * cfg.batch_size] + real_samples = F.one_hot( + 
self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() + gen_samples = self.best_fake_samples[ + step * cfg.batch_size : (step + 1) * cfg.batch_size + ] if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -388,7 +492,9 @@ def variation(self, g_step, criterionG): # mixture variation: double loss rand_w = torch.rand(1).cuda() cri_1, cri_2 = criterionG - g_loss = rand_w * cri_1(self.d_out_real, d_out_fake) + (1 - rand_w) * cri_2(self.d_out_real, d_out_fake) + g_loss = rand_w * cri_1(self.d_out_real, d_out_fake) + ( + 1 - rand_w + ) * cri_2(self.d_out_real, d_out_fake) # all loss # rand_w = F.softmax(torch.rand(len(criterionG)).cuda(), dim=0) @@ -407,7 +513,9 @@ def variation(self, g_step, criterionG): def evaluation(self, eval_type): """Evaluation all children, update child score. Note that the eval data should be the same""" - eval_samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size) + eval_samples = self.gen.sample( + cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size + ) gen_data = GenDataIter(eval_samples) # Fd @@ -416,18 +524,24 @@ def evaluation(self, eval_type): else: Fd = 0 - if eval_type == 'standard': + if eval_type == "standard": Fq = self.eval_d_out_fake.mean().cpu().item() - elif eval_type == 'rsgan': - g_loss, d_loss = get_losses(self.eval_d_out_real, self.eval_d_out_fake, 'rsgan') + elif eval_type == "rsgan": + g_loss, d_loss = get_losses( + self.eval_d_out_real, self.eval_d_out_fake, "rsgan" + ) Fq = d_loss.item() - elif eval_type == 'nll': + elif eval_type == "nll": if cfg.lambda_fq != 0: - Fq = -NLL.cal_nll(self.oracle, gen_data.loader, self.mle_criterion) # NLL_Oracle + Fq = -NLL.cal_nll( + self.oracle, gen_data.loader, self.mle_criterion + ) # NLL_Oracle else: Fq = 0 - elif eval_type == 'Ra': - g_loss = torch.sigmoid(self.eval_d_out_fake - torch.mean(self.eval_d_out_real)).sum() + elif eval_type == "Ra": + g_loss = torch.sigmoid( + self.eval_d_out_fake - torch.mean(self.eval_d_out_real) + ).sum() Fq = g_loss.item() else: raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type) @@ -438,22 +552,31 @@ def evaluation(self, eval_type): def prepare_eval_real_data(self): with torch.no_grad(): self.eval_real_samples = torch.cat( - [F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() - for _ in range(cfg.eval_b_num)], dim=0) + [ + F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() + for _ in range(cfg.eval_b_num) + ], + dim=0, + ) if cfg.CUDA: self.eval_real_samples = self.eval_real_samples.cuda() - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": self.eval_d_out_real = self.dis(self.eval_real_samples) def prepare_eval_fake_data(self): with torch.no_grad(): - self.eval_fake_samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size, - cfg.eval_b_num * cfg.batch_size, one_hot=True) + self.eval_fake_samples = self.gen.sample( + cfg.eval_b_num * cfg.batch_size, + cfg.eval_b_num * cfg.batch_size, + one_hot=True, + ) if cfg.CUDA: self.eval_fake_samples = self.eval_fake_samples.cuda() - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": self.eval_d_out_fake = self.dis(self.eval_fake_samples) @staticmethod @@ -463,14 +586,30 @@ def get_evo_temp(cur_step): all_temp = list() # all_temp.append(get_fixed_temperature(1.0, 0, 0, 'no')) # temp=1.0 - all_temp.append(get_fixed_temperature(cfg.temperature, cur_step, 
cfg.ADV_train_epoch, - random.choice(mu_temp_type))) # current step all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step + cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) # current step + all_temp.append( + get_fixed_temperature( + cfg.temperature, + cur_step + cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) if cur_step > cfg.evo_temp_step: all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step - cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step - cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) return torch.Tensor(all_temp) # three temp diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py new file mode 100644 index 00000000..b56ba76a --- /dev/null +++ b/instructor/oracle_data/fixem_instructor.py @@ -0,0 +1,58 @@ +import os +import random +from itertools import chain + +from pathlib import Path +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm, trange + + +import config as cfg +from instructor.oracle_data.instructor import BasicInstructor +from instructor.real_data.fixem_instructor import ( + FixemGANInstructor as RealDataFixemGANInstructor, +) +from utils.gan_loss import GANLoss +from utils.text_process import text_file_iterator +from utils.data_loader import DataSupplier, GANDataset +from utils.nn_helpers import create_noise, number_of_parameters +from utils.helpers import create_oracle +from metrics.nll import NLL +from utils.create_embeddings import EmbeddingsTrainer, load_embedding +from models.Oracle import Oracle +from models.generators.FixemGAN_G import Generator +from models.discriminators.FixemGAN_D import Discriminator + +# afterwards: +# check target real/fake to be right (Uniform or const) +# random data portion generator - data supplier sample from randomint + + +class FixemGANInstructor(BasicInstructor, RealDataFixemGANInstructor): + def __init__(self, opt): + self.oracle = Oracle( + 32, 32, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA + ) + if cfg.oracle_pretrain: + if not os.path.exists(cfg.oracle_state_dict_path): + create_oracle() + self.oracle.load_state_dict( + torch.load( + cfg.oracle_state_dict_path, + map_location="cuda:{}".format(cfg.device), + ) + ) + + if cfg.CUDA: + self.oracle = self.oracle.cuda() + + super().__init__(opt) + + def build_embedding(self): + # train embedding on available dataset or oracle + self.log.info(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") + self.log.info("Will train new one, it may take a while...") + sources = [cfg.oracle_samples_path.format(cfg.w2v_samples_num)] + EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index a0939c11..63af96e7 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -4,36 +4,53 @@ # @FileName : instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
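Side note on the new file above: the oracle-data `FixemGANInstructor` only overrides `build_embedding`, so fixed embeddings are trained from oracle samples when no pretrained file is available. A sketch of the intended guard; the helper name and the existence check are assumptions, only `build_embedding()` itself comes from the diff:

```python
import os

def build_embedding_if_missing(instructor, embedding_path):
    # Reuse pretrained embeddings (e.g. the kaggle archive from the README)
    # when they are already on disk; otherwise train new ones from the
    # oracle samples via the instructor's build_embedding() shown above.
    if os.path.exists(embedding_path):
        return
    instructor.build_embedding()
```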
import numpy as np import os import torch import torch.nn as nn +import wandb import config as cfg +from metrics.bleu import BLEU +from metrics.clas_acc import ACC +from metrics.ioc import IOC +from metrics.gpt_nll import GPTNLL from metrics.nll import NLL +from metrics.ppl import PPL +from metrics.dummy import Dummy from models.Oracle import Oracle from utils.data_loader import GenDataIter from utils.data_utils import create_multi_oracle from utils.helpers import Signal, create_logger, create_oracle, get_fixed_temperature -from utils.text_process import write_tensor +from utils.text_process import write_tensor, tensor_to_tokens class BasicInstructor: def __init__(self, opt): - self.log = create_logger(__name__, silent=False, to_disk=True, - log_file=cfg.log_filename if cfg.if_test - else [cfg.log_filename, cfg.save_root + 'log.txt']) + self.log = create_logger( + __name__, + silent=False, + to_disk=True, + log_file=cfg.log_filename + if cfg.if_test + else [cfg.log_filename, cfg.save_root + "log.txt"], + ) self.sig = Signal(cfg.signal_file) self.opt = opt # oracle, generator, discriminator - self.oracle = Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.oracle_list = [Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] + self.oracle = Oracle( + 32, 32, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA + ) + self.oracle_list = [ + Oracle( + 32, 32, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA + ) + for _ in range(cfg.k_label) + ] self.dis = None self.clas = None @@ -41,25 +58,49 @@ def __init__(self, opt): self.show_config() self.check_oracle() # Create Oracle models if not exist # DataLoader - self.oracle_samples = torch.load(cfg.oracle_samples_path.format(cfg.samples_num)) - self.oracle_samples_list = [torch.load(cfg.multi_oracle_samples_path.format(i, cfg.samples_num)) - for i in range(cfg.k_label)] + self.oracle_samples = torch.load( + cfg.oracle_samples_path.format(cfg.samples_num) + ) + self.oracle_samples_list = [ + torch.load(cfg.multi_oracle_samples_path.format(i, cfg.samples_num)) + for i in range(cfg.k_label) + ] self.oracle_data = GenDataIter(self.oracle_samples) - self.oracle_data_list = [GenDataIter(self.oracle_samples_list[i]) for i in range(cfg.k_label)] + self.oracle_data_list = [ + GenDataIter(self.oracle_samples_list[i]) for i in range(cfg.k_label) + ] # Criterion self.mle_criterion = nn.NLLLoss() self.dis_criterion = nn.CrossEntropyLoss() # Metrics - self.nll_oracle = NLL('NLL_oracle', if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) - self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA) - self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA) - self.all_metrics = [self.nll_oracle, self.nll_gen, self.nll_div] + # nll_oracle, less-better, changes in range -0.1 - 0.6, moderate weight + self.nll_oracle = NLL( + "NLL_oracle", weight=-3, if_use=cfg.use_nll_oracle, gpu=cfg.CUDA + ) + # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight (not in use) + self.nll_gen = NLL("NLL_gen", weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) + # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight (not in use) + self.nll_div = NLL("NLL_div", weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) + # self-bleu, less-better, changes in range 0.7 - 0.9, will have relatively high weight + self.self_bleu = BLEU("Self-BLEU", weight=-3, gram=3, if_use=cfg.use_self_bleu) + # IOC, less-better, changes in range 0.8 - 2.0, 
smaller weight + self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.oracle_data) + # dummy, add constant value to overall score + self.dummy = Dummy(weight=1, value=5, if_use=True) + self.all_metrics = [ + self.nll_oracle, + self.nll_gen, + self.nll_div, + self.self_bleu, + self.ioc, + self.dummy, + ] def _run(self): - print('Nothing to run in Basic Instructor!') + print("Nothing to run in Basic Instructor!") pass def _test(self): @@ -70,15 +111,30 @@ def init_model(self): if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() self.oracle.load_state_dict( - torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.oracle_state_dict_path, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.oracle = self.oracle.cuda() @@ -88,7 +144,7 @@ def init_model(self): def train_gen_epoch(self, model, data_loader, criterion, optimizer): total_loss = 0 for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -104,7 +160,7 @@ def train_dis_epoch(self, model, data_loader, criterion, optimizer): total_acc = 0 total_num = 0 for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -127,7 +183,7 @@ def eval_dis(model, data_loader, criterion): total_num = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -157,11 +213,31 @@ def optimize(opt, loss, model=None, retain_graph=False): def show_config(self): """Show parser parameters settings""" - self.log.info(100 * '=') - self.log.info('> training arguments:') + self.log.info(100 * "=") + self.log.info("> training arguments:") for arg in vars(self.opt): - self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg))) - self.log.info(100 * '=') + self.log.info(">>> {0}: {1}".format(arg, getattr(self.opt, arg))) + self.log.info(100 * "=") + + def sample_for_metrics(self): + eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_data = GenDataIter(eval_samples) + gen_tokens = tensor_to_tokens(eval_samples) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(cfg.small_sample_num, 4 * cfg.batch_size) + ) + return gen_data, gen_tokens, gen_tokens_s + + def sample_for_metrics_with_label(self, label_i): + eval_samples = self.gen.sample( + cfg.samples_num, 4 * cfg.batch_size, label_i=label_i + ) + gen_data = 
GenDataIter(eval_samples) + gen_tokens = tensor_to_tokens(eval_samples) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i) + ) + return gen_data, gen_tokens, gen_tokens_s def cal_metrics(self, fmt_str=False): """ @@ -170,51 +246,95 @@ def cal_metrics(self, fmt_str=False): """ with torch.no_grad(): # Prepare data for evaluation - gen_data = GenDataIter(self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)) + gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics() # Reset metrics self.nll_oracle.reset(self.oracle, gen_data.loader) self.nll_gen.reset(self.gen, self.oracle_data.loader) self.nll_div.reset(self.gen, gen_data.loader) + self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) + self.ioc.reset(test_text=gen_tokens) + + metrics = {metric.name: metric.get_score() for metric in self.all_metrics} + metrics.update( + { + "Overal_score": sum( + metric.weight * metric.get_score() for metric in self.all_metrics + ) + } + ) + wandb.log(metrics) if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) - else: - return [metric.get_score() for metric in self.all_metrics] + return "\n" + "\n".join( + [f"{name} = {score}" for name, score in metrics.items()] + ) + return [metric.get_score() for metric in self.all_metrics] - def cal_metrics_with_label(self, label_i): - assert type(label_i) == int, 'missing label' + def cal_metrics_with_label(self, label_i, fmt_str=False): + assert type(label_i) == int, "missing label" with torch.no_grad(): # Prepare data for evaluation - eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) - gen_data = GenDataIter(eval_samples) + gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics_with_label() # Reset metrics self.nll_oracle.reset(self.oracle_list[label_i], gen_data.loader, label_i) self.nll_gen.reset(self.gen, self.oracle_data_list[label_i].loader, label_i) self.nll_div.reset(self.gen, gen_data.loader, label_i) + self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.nll_oracle.reset(test_text=gen_tokens) + + metrics = { + f"label {label_i}_{metric.name}": metric.get_score() + for metric in self.all_metrics + } + metrics.update( + { + f"label {label_i} Overal_score": sum( + metric.weight * metric.get_score() for metric in self.all_metrics + ) + } + ) + wandb.log(metrics) - return [metric.get_score() for metric in self.all_metrics] + if fmt_str: + return "\n" + "\n".join( + [f"{name} = {score}" for name, score in metrics.items()] + ) + return metrics def comb_metrics(self, fmt_str=False): - all_scores = [self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label)] - all_scores = np.array(all_scores).T.tolist() # each row for each metric + all_scores = [ + self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label) + ] if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), score) - for (metric, score) in zip(self.all_metrics, all_scores)]) - return all_scores + return ", ".join( + [ + f"{name} = {[scores[name] for scores in all_scores]}" + for name in all_scores[0] + ] + ) + return [scores.values() for scores in all_scores] def _save(self, phase, epoch): """Save model state dict and generator's samples""" - if phase != 'ADV': - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 
'samples_{}_{:05d}.txt'.format(phase, epoch) + if phase != "ADV": + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_{}_{:05d}.txt".format( + phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size) write_tensor(save_sample_path, samples) def update_temperature(self, i, N): - self.gen.temperature.data = torch.Tensor([get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)]) + self.gen.temperature.data = torch.Tensor( + [get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)] + ) if cfg.CUDA: self.gen.temperature.data = self.gen.temperature.data.cuda() @@ -224,17 +344,28 @@ def check_oracle(self): create_multi_oracle(cfg.k_label) # General text generation Oracle model - if not os.path.exists(cfg.oracle_samples_path.format(cfg.samples_num)) or not cfg.oracle_pretrain: + if ( + not os.path.exists(cfg.oracle_samples_path.format(cfg.samples_num)) + or not cfg.oracle_pretrain + ): create_oracle() # Category text generation Oracle models for i in range(cfg.k_label): - if not os.path.exists(cfg.multi_oracle_samples_path.format(i, cfg.samples_num)): + if not os.path.exists( + cfg.multi_oracle_samples_path.format(i, cfg.samples_num) + ): create_multi_oracle(cfg.k_label) break # Load Oracle state dict - self.oracle.load_state_dict(torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + self.oracle.load_state_dict( + torch.load( + cfg.oracle_state_dict_path, map_location="cuda:{}".format(cfg.device) + ) + ) for i in range(cfg.k_label): oracle_path = cfg.multi_oracle_state_dict_path.format(i) - self.oracle_list[i].load_state_dict(torch.load(oracle_path, map_location='cuda:{}'.format(cfg.device))) + self.oracle_list[i].load_state_dict( + torch.load(oracle_path, map_location="cuda:{}".format(cfg.device)) + ) diff --git a/instructor/oracle_data/jsdgan_instructor.py b/instructor/oracle_data/jsdgan_instructor.py index e2264d54..ca4324d9 100644 --- a/instructor/oracle_data/jsdgan_instructor.py +++ b/instructor/oracle_data/jsdgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : JSDGAN_instructor.py # @Time : Created at 2019/11/16 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
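Side note: `cal_metrics` above now folds the individual metrics into a single weighted score before logging to wandb; negative weights mark less-is-better metrics and the `Dummy` metric adds a constant offset. A standalone sketch with made-up numbers (`Overal_score` is the key name used in the diff):

```python
# Weights mirror the ones set in BasicInstructor above; scores are illustrative.
scores  = {"NLL_oracle": 0.45, "Self-BLEU": 0.82, "IOC": 1.10, "Dummy": 5.0}
weights = {"NLL_oracle": -3.0, "Self-BLEU": -3.0, "IOC": -0.3, "Dummy": 1.0}

metrics = dict(scores)
metrics["Overal_score"] = sum(weights[name] * scores[name] for name in scores)
# wandb.log(metrics)  # as in cal_metrics(); commented out to stay standalone
print(metrics)
```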
import os import torch @@ -12,7 +12,7 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.JSDGAN_G import JSDGAN_G +from models.generators.JSDGAN_G import JSDGAN_G from utils.helpers import create_oracle @@ -21,8 +21,17 @@ def __init__(self, opt): super(JSDGANInstructor, self).__init__(opt) # generator - self.gen = JSDGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = JSDGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -33,11 +42,21 @@ def init_model(self): if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() self.oracle.load_state_dict( - torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.oracle_state_dict_path, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.oracle = self.oracle.cuda() @@ -45,23 +64,29 @@ def init_model(self): def _run(self): # ===PRE-TRAINING=== - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") for adv_epoch in range(cfg.ADV_train_epoch): g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV] epoch %d: g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True))) + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): + self.log.info( + "[ADV] epoch %d: g_loss = %.4f, %s" + % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -73,16 +98,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -94,7 +123,7 @@ def 
adv_train_generator(self, g_step): total_loss = 0 for step in range(g_step): for i, data in enumerate(self.oracle_data.loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() diff --git a/instructor/oracle_data/leakgan_instructor.py b/instructor/oracle_data/leakgan_instructor.py index 45903f84..91d0fcf3 100644 --- a/instructor/oracle_data/leakgan_instructor.py +++ b/instructor/oracle_data/leakgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : leakgan_instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.LeakGAN_D import LeakGAN_D -from models.LeakGAN_G import LeakGAN_G +from models.discriminators.LeakGAN_D import LeakGAN_D +from models.generators.LeakGAN_G import LeakGAN_G from utils import rollout from utils.data_loader import GenDataIter, DisDataIter from utils.text_process import write_tensor @@ -24,9 +24,19 @@ def __init__(self, opt): super(LeakGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = LeakGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, cfg.goal_size, cfg.step_size, cfg.CUDA) - self.dis = LeakGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = LeakGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + cfg.goal_size, + cfg.step_size, + cfg.CUDA, + ) + self.dis = LeakGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # optimizer @@ -39,48 +49,63 @@ def __init__(self, opt): def _run(self): for inter_num in range(cfg.inter_epoch): - self.log.info('>>> Interleaved Round %d...' % inter_num) + self.log.info(">>> Interleaved Round %d..." % inter_num) self.sig.update() # update signal if self.sig.pre_sig: # ===DISCRIMINATOR PRE-TRAINING=== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format( + cfg.pretrained_dis_path + ) + ) # ===GENERATOR MLE TRAINING=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + ) + ) else: - self.log.info('>>> Stop by pre_signal! Skip to adversarial training...') + self.log.info(">>> Stop by pre_signal! 
Skip to adversarial training...") break # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (str(self.cal_metrics(fmt_str=True)))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (str(self.cal_metrics(fmt_str=True)))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -98,7 +123,7 @@ def pretrain_generator(self, epochs): # ===Train=== for i, data in enumerate(self.oracle_data.loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -111,13 +136,20 @@ def pretrain_generator(self, epochs): # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: - self.log.info('[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s' % ( - epoch, pre_mana_loss, pre_work_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s" + % ( + epoch, + pre_mana_loss, + pre_work_loss, + self.cal_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step, current_k=0): @@ -131,13 +163,15 @@ def adv_train_generator(self, g_step, current_k=0): adv_work_loss = 0 for step in range(g_step): with torch.no_grad(): - gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis, - train=True) # !!! train=True, the only place + gen_samples = self.gen.sample( + cfg.batch_size, cfg.batch_size, self.dis, train=True + ) # !!! 
train=True, the only place inp, target = GenDataIter.prepare(gen_samples, gpu=cfg.CUDA) # ===Train=== - rewards = rollout_func.get_reward_leakgan(target, cfg.rollout_num, self.dis, - current_k).cpu() # reward with MC search + rewards = rollout_func.get_reward_leakgan( + target, cfg.rollout_num, self.dis, current_k + ).cpu() # reward with MC search mana_loss, work_loss = self.gen.adversarial_loss(target, rewards, self.dis) # update parameters @@ -145,10 +179,16 @@ def adv_train_generator(self, g_step, current_k=0): adv_mana_loss += mana_loss.data.item() adv_work_loss += work_loss.data.item() # ===Test=== - self.log.info('[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' % ( - adv_mana_loss / g_step, adv_work_loss / g_step, self.cal_metrics(fmt_str=True))) - - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + self.log.info( + "[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s" + % ( + adv_mana_loss / g_step, + adv_work_loss / g_step, + self.cal_metrics(fmt_str=True), + ) + ) + + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. @@ -161,23 +201,32 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for step in range(d_step): # prepare loader for training - pos_samples = self.oracle.sample(cfg.samples_num, cfg.batch_size) # re-sample the Oracle Data + pos_samples = self.oracle.sample( + cfg.samples_num, cfg.batch_size + ) # re-sample the Oracle Data neg_samples = self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis) dis_data = DisDataIter(pos_samples, neg_samples) for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - _, eval_acc = self.eval_dis(self.dis, dis_eval_data.loader, self.dis_criterion) - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f,' % ( - phase, step, d_loss, train_acc, eval_acc)) + _, eval_acc = self.eval_dis( + self.dis, dis_eval_data.loader, self.dis_criterion + ) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f," + % (phase, step, d_loss, train_acc, eval_acc) + ) def cal_metrics(self, fmt_str=False): # Prepare data for evaluation - gen_data = GenDataIter(self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis)) + gen_data = GenDataIter( + self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis) + ) # Reset metrics self.nll_oracle.reset(self.oracle, gen_data.loader) @@ -185,12 +234,22 @@ def cal_metrics(self, fmt_str=False): self.nll_div.reset(self.gen, gen_data.loader, leak_dis=self.dis) if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + return ", ".join( + [ + "%s = %s" % (metric.name, metric.get_score()) + for metric in self.all_metrics + ] + ) else: return [metric.get_score() for metric in self.all_metrics] def _save(self, phase, epoch): - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(phase, epoch) + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = 
cfg.save_samples_root + "samples_{}_{:05d}.txt".format( + phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis) write_tensor(save_sample_path, samples) diff --git a/instructor/oracle_data/maligan_instructor.py b/instructor/oracle_data/maligan_instructor.py index c3bd5b39..a7eef286 100644 --- a/instructor/oracle_data/maligan_instructor.py +++ b/instructor/oracle_data/maligan_instructor.py @@ -4,7 +4,7 @@ # @FileName : maligan_instructor.py # @Time : Created at 2019/10/17 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. @@ -14,8 +14,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.MaliGAN_D import MaliGAN_D -from models.MaliGAN_G import MaliGAN_G +from models.discriminators.MaliGAN_D import MaliGAN_D +from models.generators.MaliGAN_G import MaliGAN_G from utils.data_loader import GenDataIter, DisDataIter @@ -24,9 +24,17 @@ def __init__(self, opt): super(MaliGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = MaliGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = MaliGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = MaliGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = MaliGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # Optimizer @@ -38,40 +46,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! 
Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -83,16 +100,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -101,7 +122,9 @@ def adv_train_generator(self, g_step): """ total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = self.get_mali_reward(target) @@ -110,9 +133,12 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss, self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. @@ -131,13 +157,18 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - _, eval_acc = self.eval_dis(self.dis, dis_eval_data.loader, self.dis_criterion) - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f,' % ( - phase, step, d_loss, train_acc, eval_acc)) + _, eval_acc = self.eval_dis( + self.dis, dis_eval_data.loader, self.dis_criterion + ) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f," + % (phase, step, d_loss, train_acc, eval_acc) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) diff --git a/instructor/oracle_data/relgan_instructor.py b/instructor/oracle_data/relgan_instructor.py index 8c7b610b..1b5bf5df 100644 --- a/instructor/oracle_data/relgan_instructor.py +++ b/instructor/oracle_data/relgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : relgan_instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch import torch.nn.functional as F @@ -13,8 +13,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.RelGAN_D import RelGAN_D -from models.RelGAN_G import RelGAN_G +from models.discriminators.RelGAN_D import RelGAN_D +from models.generators.RelGAN_G import RelGAN_G from utils.helpers import get_fixed_temperature, get_losses @@ -23,10 +23,25 @@ def __init__(self, opt): super(RelGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx, - gpu=cfg.CUDA) + self.gen = RelGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = RelGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -38,39 +53,50 @@ def __init__(self, opt): def _run(self): # ===PRE-TRAINING (GENERATOR)=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for adv_epoch in progress: self.sig.update() if self.sig.adv_sig: g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator d_loss = self.adv_train_discriminator(cfg.ADV_d_step) # Discriminator - self.update_temperature(adv_epoch, cfg.ADV_train_epoch) # update temperature + self.update_temperature( + adv_epoch, cfg.ADV_train_epoch + ) # update temperature progress.set_description( - 'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature)) + "g_loss: %.4f, d_loss: %.4f, temperature: %.4f" + % (g_loss, d_loss, self.gen.temperature) + ) # TEST - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s' % ( - adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True))) + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): + self.log.info( + "[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s" + % (adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) progress.close() break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() @@ -84,23 +110,29 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): total_loss = 0 for step in range(g_step): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True) if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -118,7 +150,9 @@ def adv_train_generator(self, g_step): def adv_train_discriminator(self, d_step): total_loss = 0 for step in range(d_step): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True) if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -134,7 +168,9 @@ def adv_train_discriminator(self, d_step): return total_loss / d_step if d_step != 0 else 0 def update_temperature(self, i, N): - self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt) + self.gen.temperature = get_fixed_temperature( + cfg.temperature, i, N, cfg.temp_adpt + ) @staticmethod def optimize(opt, loss, model=None, retain_graph=False): diff --git a/instructor/oracle_data/sentigan_instructor.py b/instructor/oracle_data/sentigan_instructor.py index 17dd6631..1c72b697 100644 --- a/instructor/oracle_data/sentigan_instructor.py +++ b/instructor/oracle_data/sentigan_instructor.py @@ -4,7 +4,7 @@ # @FileName : sentigan_instructor.py # @Time : Created at 2019-07-26 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import os @@ -14,8 +14,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor from models.Oracle import Oracle -from models.SentiGAN_D import SentiGAN_D -from models.SentiGAN_G import SentiGAN_G +from models.discriminators.SentiGAN_D import SentiGAN_D +from models.generators.SentiGAN_G import SentiGAN_G from utils import rollout from utils.cat_data_loader import CatClasDataIter from utils.data_loader import GenDataIter @@ -28,16 +28,42 @@ def __init__(self, opt): super(SentiGANInstructor, self).__init__(opt) # generator, discriminator - self.oracle_list = [Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] - - self.gen_list = [SentiGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] - self.dis = SentiGAN_D(cfg.k_label, cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.oracle_list = [ + Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + for _ in range(cfg.k_label) + ] + + self.gen_list = [ + SentiGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + for _ in range(cfg.k_label) + ] + self.dis = SentiGAN_D( + cfg.k_label, + cfg.dis_embed_dim, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer - self.gen_opt_list = [optim.Adam(gen.parameters(), lr=cfg.gen_lr) for gen in self.gen_list] + self.gen_opt_list = [ + optim.Adam(gen.parameters(), lr=cfg.gen_lr) for gen in self.gen_list + ] self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) def init_model(self): @@ -46,17 +72,32 @@ def init_model(self): oracle_path = cfg.multi_oracle_state_dict_path.format(i) if not os.path.exists(oracle_path): create_multi_oracle(cfg.k_label) - self.oracle_list[i].load_state_dict(torch.load(oracle_path, map_location='cuda:{}'.format(cfg.device))) + self.oracle_list[i].load_state_dict( + torch.load(oracle_path, map_location="cuda:{}".format(cfg.device)) + ) if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: for i in range(cfg.k_label): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) self.gen_list[i].load_state_dict( - torch.load(cfg.pretrained_gen_path + '%d' % i, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.pretrained_gen_path + "%d" % i, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.CUDA: for i in range(cfg.k_label): @@ -67,41 +108,57 @@ def init_model(self): def _run(self): # ===PRE-TRAIN GENERATOR=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: for i in range(cfg.k_label): - torch.save(self.gen_list[i].state_dict(), cfg.pretrained_gen_path + '%d' % i) - 
print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen_list[i].state_dict(), + cfg.pretrained_gen_path + "%d" % i, + ) + print( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s', self.comb_metrics(fmt_str=True)) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s", self.comb_metrics(fmt_str=True)) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -114,18 +171,24 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: for i in range(cfg.k_label): - pre_loss = self.train_gen_epoch(self.gen_list[i], self.oracle_data_list[i].loader, - self.mle_criterion, self.gen_opt_list[i]) + pre_loss = self.train_gen_epoch( + self.gen_list[i], + self.oracle_data_list[i].loader, + self.mle_criterion, + self.gen_opt_list[i], + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: if i == cfg.k_label - 1: - self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.comb_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.comb_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -137,7 +200,10 @@ def adv_train_generator(self, g_step): rollout_func = rollout.ROLLOUT(self.gen_list[i], cfg.CUDA) total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen_list[i].sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen_list[i].sample(cfg.batch_size, cfg.batch_size), + gpu=cfg.CUDA, + ) # ===Train=== rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis) @@ -146,9 +212,9 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: %s', self.comb_metrics(fmt_str=True)) + self.log.info("[ADV-GEN]: %s", self.comb_metrics(fmt_str=True)) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
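The next hunk assembles the SentiGAN discriminator batch as `[torch.cat(fake_samples, dim=0)] + real_samples` and hands that list to `CatClasDataIter`, i.e. one shared "fake" class plus one class per sentiment label. A minimal, hypothetical sketch of that k+1-class labelling, assuming the data iterator labels each list entry by its index; the helper name and shapes below are illustrative, not the repository's API:

```python
import torch


def build_k_plus_one_class_batch(fake_samples, real_samples):
    """fake_samples / real_samples: lists of (N_i, seq_len) LongTensors, one per label.

    All fake samples share class 0; real samples of label i get class i + 1,
    mirroring SentiGAN's k+1-class discriminator objective.
    """
    all_fake = torch.cat(fake_samples, dim=0)
    data = torch.cat([all_fake] + real_samples, dim=0)
    labels = torch.cat(
        [torch.zeros(all_fake.size(0), dtype=torch.long)]
        + [
            torch.full((real.size(0),), i + 1, dtype=torch.long)
            for i, real in enumerate(real_samples)
        ]
    )
    return data, labels
```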
@@ -162,32 +228,43 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): fake_samples = [] for i in range(cfg.k_label): real_samples.append(self.oracle_samples_list[i]) - fake_samples.append(self.gen_list[i].sample(cfg.samples_num // cfg.k_label, 8 * cfg.batch_size)) + fake_samples.append( + self.gen_list[i].sample( + cfg.samples_num // cfg.k_label, 8 * cfg.batch_size + ) + ) dis_samples_list = [torch.cat(fake_samples, dim=0)] + real_samples dis_data = CatClasDataIter(dis_samples_list) for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f" + % (phase, step, d_loss, train_acc) + ) - if cfg.if_save and not cfg.if_test and phase == 'MLE': + if cfg.if_save and not cfg.if_test and phase == "MLE": torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def cal_metrics_with_label(self, label_i): - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" # Prepare data for evaluation - eval_samples = self.gen_list[label_i].sample(cfg.samples_num, 8 * cfg.batch_size) + eval_samples = self.gen_list[label_i].sample( + cfg.samples_num, 8 * cfg.batch_size + ) gen_data = GenDataIter(eval_samples) # Reset metrics self.nll_oracle.reset(self.oracle_list[label_i], gen_data.loader) - self.nll_gen.reset(self.gen_list[label_i], self.oracle_data_list[label_i].loader) + self.nll_gen.reset( + self.gen_list[label_i], self.oracle_data_list[label_i].loader + ) self.nll_div.reset(self.gen_list[label_i], gen_data.loader) return [metric.get_score() for metric in self.all_metrics] @@ -195,8 +272,13 @@ def cal_metrics_with_label(self, label_i): def _save(self, phase, epoch): """Save model state dict and generator's samples""" for i in range(cfg.k_label): - torch.save(self.gen_list[i].state_dict(), - cfg.save_model_root + 'gen{}_{}_{:05d}.pt'.format(i, phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_d{}_{}_{:05d}.txt'.format(i, phase, epoch) + torch.save( + self.gen_list[i].state_dict(), + cfg.save_model_root + "gen{}_{}_{:05d}.pt".format(i, phase, epoch), + ) + save_sample_path = ( + cfg.save_samples_root + + "samples_d{}_{}_{:05d}.txt".format(i, phase, epoch) + ) samples = self.gen_list[i].sample(cfg.batch_size, cfg.batch_size) write_tensor(save_sample_path, samples) diff --git a/instructor/oracle_data/seqgan_instructor.py b/instructor/oracle_data/seqgan_instructor.py index 00046b80..1a5026a5 100644 --- a/instructor/oracle_data/seqgan_instructor.py +++ b/instructor/oracle_data/seqgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : seqgan_instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.SeqGAN_D import SeqGAN_D -from models.SeqGAN_G import SeqGAN_G +from models.discriminators.SeqGAN_D import SeqGAN_D +from models.generators.SeqGAN_G import SeqGAN_G from utils import rollout from utils.data_loader import GenDataIter, DisDataIter @@ -23,9 +23,17 @@ def __init__(self, opt): super(SeqGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = SeqGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = SeqGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = SeqGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = SeqGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # Optimizer @@ -37,40 +45,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -82,16 +99,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -102,7 +123,9 @@ def adv_train_generator(self, g_step): rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA) total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis) @@ -111,9 +134,12 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss, self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
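The SeqGAN hunks above prepare `inp`/`target` from the generator's own samples and score `target` with Monte Carlo rollout rewards (`rollout_func.get_reward`); the loss computation itself is not visible in this diff. A minimal sketch of the underlying REINFORCE-style objective, with shapes and the function name assumed for illustration rather than taken from `SeqGAN_G`:

```python
import torch
import torch.nn.functional as F


def policy_gradient_loss(logits, target, rewards):
    """logits: (batch, seq_len, vocab) generator outputs for its own samples;
    target: (batch, seq_len) sampled token ids;
    rewards: (batch, seq_len) per-token rewards from the discriminator via rollout.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Log-probability the generator assigned to each sampled token.
    token_log_probs = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    # Raise the likelihood of tokens the discriminator rewarded.
    return -(token_log_probs * rewards).sum()
```

As the hunk above shows, the instructor then accumulates `adv_loss.item()` into `total_g_loss` for the `[ADV-GEN]` log line.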
@@ -132,13 +158,18 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - _, eval_acc = self.eval_dis(self.dis, dis_eval_data.loader, self.dis_criterion) - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f,' % ( - phase, step, d_loss, train_acc, eval_acc)) + _, eval_acc = self.eval_dis( + self.dis, dis_eval_data.loader, self.dis_criterion + ) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f," + % (phase, step, d_loss, train_acc, eval_acc) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) diff --git a/instructor/real_data/catgan_instructor.py b/instructor/real_data/catgan_instructor.py index 094c0c9c..0afa4a67 100644 --- a/instructor/real_data/catgan_instructor.py +++ b/instructor/real_data/catgan_instructor.py @@ -17,8 +17,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor from metrics.nll import NLL -from models.CatGAN_D import CatGAN_D, CatGAN_C -from models.CatGAN_G import CatGAN_G +from models.discriminators.CatGAN_D import CatGAN_D, CatGAN_C +from models.generators.CatGAN_G import CatGAN_G from utils.cat_data_loader import CatGenDataIter from utils.data_loader import GenDataIter from utils.gan_loss import GANLoss @@ -27,21 +27,54 @@ class CatGANInstructor(BasicInstructor): - def __init__(self, opt): super(CatGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = CatGAN_G(cfg.k_label, cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, - cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.parents = [CatGAN_G(cfg.k_label, cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, - cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, - gpu=cfg.CUDA).state_dict() - for _ in range(cfg.n_parent)] # list of Generator state_dict - self.dis = CatGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) - self.clas = CatGAN_C(cfg.k_label, cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.extend_vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = CatGAN_G( + cfg.k_label, + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.parents = [ + CatGAN_G( + cfg.k_label, + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ).state_dict() + for _ in range(cfg.n_parent) + ] # list of Generator state_dict + self.dis = CatGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.clas = CatGAN_C( + cfg.k_label, + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.extend_vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -50,14 +83,19 @@ def __init__(self, opt): self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) self.clas_opt = optim.Adam(self.clas.parameters(), lr=cfg.clas_lr) - self.parent_mle_opts = 
[copy.deepcopy(self.gen_opt.state_dict()) - for _ in range(cfg.n_parent)] - self.parent_adv_opts = [copy.deepcopy(self.gen_adv_opt.state_dict()) - for _ in range(cfg.n_parent)] # list of optimizer state dict + self.parent_mle_opts = [ + copy.deepcopy(self.gen_opt.state_dict()) for _ in range(cfg.n_parent) + ] + self.parent_adv_opts = [ + copy.deepcopy(self.gen_adv_opt.state_dict()) for _ in range(cfg.n_parent) + ] # list of optimizer state dict # Criterion - self.G_criterion = [GANLoss(loss_mode, 'G', cfg.d_type, CUDA=cfg.CUDA) for loss_mode in cfg.mu_type.split()] - self.D_criterion = GANLoss(cfg.loss_type, 'D', cfg.d_type, CUDA=cfg.CUDA) + self.G_criterion = [ + GANLoss(loss_mode, "G", cfg.d_type, CUDA=cfg.CUDA) + for loss_mode in cfg.mu_type.split() + ] + self.D_criterion = GANLoss(cfg.loss_type, "D", cfg.d_type, CUDA=cfg.CUDA) # DataLoader self.all_train_data = CatGenDataIter(self.train_samples_list) @@ -68,13 +106,21 @@ def __init__(self, opt): def init_model(self): if cfg.gen_pretrain: for i in range(cfg.n_parent): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) - self.parents[i] = torch.load(cfg.pretrained_gen_path + '%d' % 0, map_location='cpu') + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) + self.parents[i] = torch.load( + cfg.pretrained_gen_path + "%d" % 0, map_location="cpu" + ) if cfg.CUDA: self.gen = self.gen.cuda() if cfg.multi_gpu: - self.dis = torch.nn.parallel.DataParallel(self.dis, device_ids=cfg.devices) + self.dis = torch.nn.parallel.DataParallel( + self.dis, device_ids=cfg.devices + ) self.dis = self.dis.cuda() self.clas = self.clas.cuda() @@ -90,19 +136,29 @@ def load_gen(self, parent, parent_opt, mle=False): def _run(self): # ===Pre-train Classifier with real data=== if cfg.use_clas_acc: - self.log.info('Start training Classifier...') + self.log.info("Start training Classifier...") self.train_classifier(cfg.PRE_clas_epoch) # ===Pre-train Generator=== if not cfg.gen_pretrain: - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_mle_opts)): - self.log.info('Starting Generator-{} MLE Training...'.format(i)) + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_mle_opts) + ): + self.log.info("Starting Generator-{} MLE Training...".format(i)) self.load_gen(parent, parent_opt, mle=True) # load state dict self.pretrain_generator(cfg.MLE_train_epoch) - self.parents[i] = copy.deepcopy(self.gen.state_dict()) # save state dict + self.parents[i] = copy.deepcopy( + self.gen.state_dict() + ) # save state dict if cfg.if_save and not cfg.if_test: - torch.save(self.gen.state_dict(), cfg.pretrained_gen_path + '%d' % i) - self.log.info('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen.state_dict(), cfg.pretrained_gen_path + "%d" % i + ) + self.log.info( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # ===Adv-train=== progress = tqdm(range(cfg.ADV_train_epoch)) @@ -110,27 +166,45 @@ def _run(self): if cfg.temperature == 1: score, fit_score, select_mu = self.evolve_generator(cfg.ADV_g_step) else: # evolve with temperature - score, fit_score, select_mu = self.evolve_generator_with_temp(adv_epoch, cfg.ADV_g_step) + score, fit_score, select_mu = self.evolve_generator_with_temp( + adv_epoch, cfg.ADV_g_step + ) d_loss = self.evolve_discriminator(cfg.ADV_d_step) best_id = int(np.argmax(score)) - progress.set_description('mu: %s, d_loss = %.4f, 
temp = %.4f' % ( - ' '.join(select_mu), d_loss, self.parents[best_id]['temperature'].item())) + progress.set_description( + "mu: %s, d_loss = %.4f, temp = %.4f" + % ( + " ".join(select_mu), + d_loss, + self.parents[best_id]["temperature"].item(), + ) + ) # ===Test=== - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): best_id = int(np.argmax(score)) self.load_gen(self.parents[best_id], self.parent_adv_opts[best_id]) - self.log.info('[ADV] epoch %d: temp = %.4f, d_loss: %.4f, %s' % ( - adv_epoch, self.gen.temperature.item(), d_loss, self.comb_metrics(fmt_str=True))) + self.log.info( + "[ADV] epoch %d: temp = %.4f, d_loss: %.4f, %s" + % ( + adv_epoch, + self.gen.temperature.item(), + d_loss, + self.comb_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: for label_i in range(cfg.k_label): - self._save('ADV', adv_epoch, label_i) + self._save("ADV", adv_epoch, label_i) def _test(self): - self.log.debug('>>> Begin test...') + self.log.debug(">>> Begin test...") self._run() pass @@ -141,17 +215,20 @@ def pretrain_generator(self, epochs): """ for epoch in range(epochs): # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.all_train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.all_train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.comb_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.comb_metrics(fmt_str=True)) + ) if not cfg.if_test and cfg.if_save: for label_i in range(cfg.k_label): - self._save('MLE', epoch, label_i) + self._save("MLE", epoch, label_i) def evolve_generator(self, evo_g_step): # evaluation real data @@ -167,13 +244,21 @@ def evolve_generator(self, evo_g_step): # all child share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.train_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.train_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): # Variation self.load_gen(parent, parent_opt) # load state dict to self.gen @@ -198,7 +283,9 @@ def evolve_generator(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation[id_replace] = criterionG.loss_mode count += 1 @@ -222,17 +309,25 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # all children share the same real 
data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.train_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.train_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): all_temp = self.get_evo_temp(cur_adv_step) - temp_score = float('-inf') + temp_score = float("-inf") temp_fit = None temp_child = None temp_child_opt = None @@ -248,8 +343,10 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # Evaluation self.prepare_eval_fake_data() # evaluation fake data - _, _, t_score = self.evaluation('Ra') # for temp evolutionary - loss_Fq, loss_Fd, loss_score = self.evaluation(cfg.eval_type) # for loss evolutionary + _, _, t_score = self.evaluation("Ra") # for temp evolutionary + loss_Fq, loss_Fd, loss_score = self.evaluation( + cfg.eval_type + ) # for loss evolutionary if t_score > temp_score: temp_score = loss_score @@ -287,15 +384,21 @@ def evolve_discriminator(self, evo_d_step): global dc_loss, dd_loss, d_loss total_loss = [] - all_gen_samples_list = list(map(self.merge, *self.best_fake_samples)) # merge each label of data - self.all_gen_samples_list = self.shuffle_eval_samples(all_gen_samples_list) # shuffle data + all_gen_samples_list = list( + map(self.merge, *self.best_fake_samples) + ) # merge each label of data + self.all_gen_samples_list = self.shuffle_eval_samples( + all_gen_samples_list + ) # shuffle data for step in range(evo_d_step): - dis_real_samples, dis_gen_samples = self.prepare_train_data('D', step) + dis_real_samples, dis_gen_samples = self.prepare_train_data("D", step) d_loss = 0 all_d_out_real = [] all_d_out_fake = [] - for (real_samples, fake_samples) in zip(dis_real_samples, dis_gen_samples): # for each label samples + for (real_samples, fake_samples) in zip( + dis_real_samples, dis_gen_samples + ): # for each label samples d_out_real = self.dis(real_samples) d_out_fake = self.dis(fake_samples) d_loss += self.D_criterion(d_out_real, d_out_fake) @@ -320,14 +423,16 @@ def variation(self, g_step, criterionG): """Optimize one child (Generator)""" total_loss = [] for step in range(g_step): - dis_real_samples, dis_gen_samples = self.prepare_train_data('G') + dis_real_samples, dis_gen_samples = self.prepare_train_data("G") # ===Train=== g_loss = 0 all_d_out_real = [] all_d_out_fake = [] # for i, (real_samples, fake_samples) in enumerate(zip(dis_real_samples, dis_gen_samples)): - for i, (d_out_real, fake_samples) in enumerate(zip(self.d_out_real, dis_gen_samples)): # share real + for i, (d_out_real, fake_samples) in enumerate( + zip(self.d_out_real, dis_gen_samples) + ): # share real # d_out_real = self.dis(real_samples) d_out_fake = self.dis(fake_samples) g_loss += criterionG(d_out_real, d_out_fake) @@ -350,30 +455,40 @@ def variation(self, g_step, criterionG): def evaluation(self, eval_type): """Evaluation all children, update child score. 
Note that the eval data should be the same""" - eval_samples = [self.gen.sample(cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size, label_i=i) for i - in range(cfg.k_label)] + eval_samples = [ + self.gen.sample( + cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size, label_i=i + ) + for i in range(cfg.k_label) + ] # Fd if cfg.lambda_fd != 0: nll_div = [] for label_i in range(cfg.k_label): gen_data = GenDataIter(eval_samples[label_i]) - nll_div.append(NLL.cal_nll_with_label(self.gen, gen_data.loader, label_i, self.mle_criterion)) + nll_div.append( + NLL.cal_nll_with_label( + self.gen, gen_data.loader, label_i, self.mle_criterion + ) + ) Fd = sum(nll_div) else: Fd = 0 # Fq - if 'bleu' in eval_type: + if "bleu" in eval_type: bleu_score = [] for i in range(cfg.k_label): bleu_score.append(self.bleu[i].get_score(given_gram=int(eval_type[-1]))) Fq = sum(bleu_score) - elif 'Ra' in eval_type: + elif "Ra" in eval_type: g_loss = 0 for i in range(cfg.k_label): - g_loss += torch.sigmoid(self.eval_d_out_fake[i] - torch.mean(self.eval_d_out_real[i])).sum() + g_loss += torch.sigmoid( + self.eval_d_out_fake[i] - torch.mean(self.eval_d_out_real[i]) + ).sum() Fq = g_loss.item() else: raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type) @@ -384,7 +499,7 @@ def evaluation(self, eval_type): def train_gen_epoch(self, model, data_loader, criterion, optimizer): total_loss = 0 for i, data in enumerate(data_loader): - inp, target, label = data['input'], data['target'], data['label'] + inp, target, label = data["input"], data["target"], data["label"] if cfg.CUDA: inp, target, label = inp.cuda(), target.cuda(), label.cuda() @@ -397,8 +512,13 @@ def train_gen_epoch(self, model, data_loader, criterion, optimizer): def _save(self, phase, epoch, label_i=None): assert type(label_i) == int - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_c{}_{}_{:05d}.txt'.format(label_i, phase, epoch) + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_c{}_{}_{:05d}.txt".format( + label_i, phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size, label_i=label_i) write_tokens(save_sample_path, tensor_to_tokens(samples, self.idx2word_dict)) @@ -409,50 +529,86 @@ def merge(*args): def shuffle_eval_samples(self, all_eval_samples): temp = [] for i in range(cfg.k_label): - temp.append(all_eval_samples[i][torch.randperm(all_eval_samples[i].size(0))]) + temp.append( + all_eval_samples[i][torch.randperm(all_eval_samples[i].size(0))] + ) return temp def prepare_train_data(self, which, step=None): """Prepare train data for both Generator and Discriminator, each samples_list contains k_label batches of data""" - assert which == 'D' or which == 'G', 'only support for D and G!!' + assert which == "D" or which == "G", "only support for D and G!!" 
real_samples_list = [ - F.one_hot(self.train_data_list[i].random_batch()['target'][:cfg.batch_size], - cfg.vocab_size).float().cuda() - for i in range(cfg.k_label)] - if which == 'D': - assert step is not None, 'missing step' - gen_samples_list = [self.all_gen_samples_list[i][step * cfg.batch_size:(step + 1) * cfg.batch_size] - for i in range(cfg.k_label)] # get a batch from each label + F.one_hot( + self.train_data_list[i].random_batch()["target"][: cfg.batch_size], + cfg.vocab_size, + ) + .float() + .cuda() + for i in range(cfg.k_label) + ] + if which == "D": + assert step is not None, "missing step" + gen_samples_list = [ + self.all_gen_samples_list[i][ + step * cfg.batch_size : (step + 1) * cfg.batch_size + ] + for i in range(cfg.k_label) + ] # get a batch from each label else: # 'G' gen_samples_list = [ self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True, label_i=i) - for i in range(cfg.k_label)] + for i in range(cfg.k_label) + ] return real_samples_list, gen_samples_list def prepare_eval_real_data(self): """Prepare evaluation real data, contains k_label batches of data""" with torch.no_grad(): - self.eval_real_samples = [torch.cat( - [F.one_hot(self.train_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for _ in range(cfg.eval_b_num)], dim=0) for i in range(cfg.k_label)] + self.eval_real_samples = [ + torch.cat( + [ + F.one_hot( + self.train_data_list[i].random_batch()["target"], + cfg.vocab_size, + ).float() + for _ in range(cfg.eval_b_num) + ], + dim=0, + ) + for i in range(cfg.k_label) + ] if cfg.CUDA: - self.eval_real_samples = [self.eval_real_samples[i].cuda() for i in range(cfg.k_label)] + self.eval_real_samples = [ + self.eval_real_samples[i].cuda() for i in range(cfg.k_label) + ] - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': - self.eval_d_out_real = [self.dis(self.eval_real_samples[i]) for i in range(cfg.k_label)] + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": + self.eval_d_out_real = [ + self.dis(self.eval_real_samples[i]) for i in range(cfg.k_label) + ] def prepare_eval_fake_data(self): """Prepare evaluation fake data, contains k_label batches of data""" with torch.no_grad(): - self.eval_fake_samples = [self.gen.sample(cfg.eval_b_num * cfg.batch_size, - cfg.eval_b_num * cfg.batch_size, one_hot=True, label_i=i) - for i in range(cfg.k_label)] + self.eval_fake_samples = [ + self.gen.sample( + cfg.eval_b_num * cfg.batch_size, + cfg.eval_b_num * cfg.batch_size, + one_hot=True, + label_i=i, + ) + for i in range(cfg.k_label) + ] if cfg.CUDA: - self.eval_fake_samples = [self.eval_fake_samples[i].cuda() for i in range(cfg.k_label)] + self.eval_fake_samples = [ + self.eval_fake_samples[i].cuda() for i in range(cfg.k_label) + ] - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': - self.eval_d_out_fake = [self.dis(self.eval_fake_samples[i]) for i in range(cfg.k_label)] + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": + self.eval_d_out_fake = [ + self.dis(self.eval_fake_samples[i]) for i in range(cfg.k_label) + ] @staticmethod def get_evo_temp(cur_step): @@ -461,14 +617,30 @@ def get_evo_temp(cur_step): all_temp = list() # all_temp.append(get_fixed_temperature(1.0, 0, 0, 'no')) # temp=1.0 - all_temp.append(get_fixed_temperature(cfg.temperature, cur_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) # current step all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step + cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step, + 
cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) # current step + all_temp.append( + get_fixed_temperature( + cfg.temperature, + cur_step + cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) if cur_step > cfg.evo_temp_step: all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step - cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step - cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) return torch.Tensor(all_temp) diff --git a/instructor/real_data/cot_instructor.py b/instructor/real_data/cot_instructor.py index b7ad3d24..01ba17e7 100644 --- a/instructor/real_data/cot_instructor.py +++ b/instructor/real_data/cot_instructor.py @@ -4,7 +4,7 @@ # @FileName : cot_instructor.py # @Time : Created at 2020/4/21 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. @@ -15,8 +15,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.CoT_D import Cot_D -from models.CoT_G import CoT_G +from models.discriminators.CoT_D import Cot_D +from models.generators.CoT_G import CoT_G from utils.data_loader import GenDataIter @@ -25,10 +25,22 @@ def __init__(self, opt): super(CoTInstructor, self).__init__(opt) # generator, discriminator - self.gen = CoT_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = Cot_D(cfg.gen_embed_dim * 2, cfg.gen_hidden_dim * 2, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) # embed_dim and hidden_dim is larger + self.gen = CoT_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = Cot_D( + cfg.gen_embed_dim * 2, + cfg.gen_hidden_dim * 2, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) # embed_dim and hidden_dim is larger self.init_model() # Optimizer @@ -40,30 +52,32 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for epoch in progress: g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator d_loss = self.train_mediator(epoch, cfg.ADV_d_step) # Discriminator - progress.set_description('g_loss: %.4f, d_loss: %.4f' % (g_loss, d_loss)) + progress.set_description("g_loss: %.4f, d_loss: %.4f" % (g_loss, d_loss)) if epoch % cfg.adv_log_step == 0 or epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV]: epoch = %d, %s' % (epoch, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV]: epoch = %d, %s" % (epoch, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', epoch) + self._save("ADV", epoch) torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -75,16 +89,20 @@ def pretrain_generator(self, epochs): for 
epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -93,7 +111,9 @@ def adv_train_generator(self, g_step): """ g_loss = [] for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = self.dis(inp, self.dis.init_hidden(cfg.batch_size)) @@ -110,9 +130,13 @@ def train_mediator(self, cur_epoch, d_step): d_loss = [] for step in range(d_step): # prepare loader for training - real = list(self.train_data.loader)[cur_epoch % len(self.train_data.loader)] # traverse all real data - real_inp, real_tar = real['input'], real['target'] - fake_inp, fake_tar = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + real = list(self.train_data.loader)[ + cur_epoch % len(self.train_data.loader) + ] # traverse all real data + real_inp, real_tar = real["input"], real["target"] + fake_inp, fake_tar = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) if cfg.CUDA: real_inp, real_tar = real_inp.cuda(), real_tar.cuda() diff --git a/instructor/real_data/dgsan_instructor.py b/instructor/real_data/dgsan_instructor.py index 5b018c87..9c250d1b 100644 --- a/instructor/real_data/dgsan_instructor.py +++ b/instructor/real_data/dgsan_instructor.py @@ -4,7 +4,7 @@ # @FileName : dgsan_instructor.py # @Time : Created at 2020/4/16 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import copy @@ -16,7 +16,7 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.DGSAN_G import DGSAN_G +from models.generators.DGSAN_G import DGSAN_G from utils.data_loader import GenDataIter @@ -25,10 +25,22 @@ def __init__(self, opt): super(DGSANInstructor, self).__init__(opt) # generator, discriminator - self.gen = DGSAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.old_gen = DGSAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = DGSAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.old_gen = DGSAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -37,8 +49,14 @@ def __init__(self, opt): def init_model(self): if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.gen = self.gen.cuda() @@ -47,14 +65,14 @@ def init_model(self): def _run(self): # ===PRE-TRAINING=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") self.old_gen.load_state_dict(copy.deepcopy(self.gen.state_dict())) progress = tqdm(range(cfg.ADV_train_epoch)) @@ -62,16 +80,21 @@ def _run(self): g_loss = self.adv_train_generator() self.old_gen.load_state_dict(copy.deepcopy(self.gen.state_dict())) - progress.set_description('g_loss: %.4f' % g_loss) + progress.set_description("g_loss: %.4f" % g_loss) - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): self.log.info( - '[ADV]: epoch: %d, g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True))) + "[ADV]: epoch: %d, g_loss = %.4f, %s" + % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -83,26 +106,35 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch 
%d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self): g_loss = [] gen_data = GenDataIter(self.old_gen.sample(cfg.samples_num, cfg.batch_size)) for (real, fake) in zip(self.train_data.loader, gen_data.loader): - real_inp, real_tar = real['input'], real['target'] - fake_inp, fake_tar = fake['input'], fake['target'] + real_inp, real_tar = real["input"], real["target"] + fake_inp, fake_tar = fake["input"], fake["target"] if cfg.CUDA: - real_inp, real_tar, fake_inp, fake_tar = real_inp.cuda(), real_tar.cuda(), fake_inp.cuda(), fake_tar.cuda() + real_inp, real_tar, fake_inp, fake_tar = ( + real_inp.cuda(), + real_tar.cuda(), + fake_inp.cuda(), + fake_tar.cuda(), + ) # ===Train=== real_new_pred = self.cal_pred(self.gen, real_inp, real_tar) @@ -111,8 +143,12 @@ def adv_train_generator(self): fake_old_pred = self.cal_pred(self.old_gen, fake_inp, fake_tar) eps = 0 - real_loss = -torch.sum(torch.log(1 / (1 + real_old_pred / (real_new_pred + eps) + eps) + eps)) - fake_loss = -torch.sum(torch.log(1 / (1 + fake_new_pred / (fake_old_pred + eps) + eps) + eps)) + real_loss = -torch.sum( + torch.log(1 / (1 + real_old_pred / (real_new_pred + eps) + eps) + eps) + ) + fake_loss = -torch.sum( + torch.log(1 / (1 + fake_new_pred / (fake_old_pred + eps) + eps) + eps) + ) adv_loss = real_loss + fake_loss self.optimize(self.gen_adv_opt, adv_loss) diff --git a/instructor/real_data/dpgan_instructor.py b/instructor/real_data/dpgan_instructor.py index 40c95e15..d4ad3293 100644 --- a/instructor/real_data/dpgan_instructor.py +++ b/instructor/real_data/dpgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : dpgan_instructor.py # @Time : Created at 2019/12/21 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.DPGAN_D import DPGAN_D -from models.DPGAN_G import DPGAN_G +from models.discriminators.DPGAN_D import DPGAN_D +from models.generators.DPGAN_G import DPGAN_G class DPGANInstructor(BasicInstructor): @@ -21,10 +21,22 @@ def __init__(self, opt): super(DPGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = DPGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = DPGAN_D(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = DPGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = DPGAN_D( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -36,40 +48,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') - self.train_discriminator(cfg.d_step, cfg.d_epoch, 'MLE') + self.log.info("Starting Discriminator Training...") + self.train_discriminator(cfg.d_step, cfg.d_epoch, "MLE") if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -81,16 +102,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -100,13 +125,15 @@ def adv_train_generator(self, g_step): """ discount_rate = 1 total_g_loss = 0 - dis_count_list = [discount_rate ** i for i in range(cfg.max_seq_len)] - dis_count_matrix = torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) + dis_count_list = [discount_rate**i for i in range(cfg.max_seq_len)] + dis_count_matrix = ( + torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) + ) if cfg.CUDA: dis_count_matrix = dis_count_matrix.cuda() for step in range(g_step): - inp = self.train_data.random_batch()['input'] + inp = self.train_data.random_batch()["input"] if cfg.CUDA: inp = inp.cuda() @@ -124,9 +151,11 @@ def adv_train_generator(self, g_step): # ===Test=== self.log.info( - '[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True))) + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
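The DPGAN hunk above builds a per-position discount matrix that later scales token-level rewards. With `discount_rate = 1`, as in the patch, the matrix is all ones; the sketch below uses 0.9 purely to make the shape and the geometric decay visible (the batch size and sequence length are placeholders for `cfg.batch_size` and `cfg.max_seq_len`):

```python
import torch

batch_size, max_seq_len, discount_rate = 4, 6, 0.9  # illustrative values only

# One row of per-position discounts, repeated for every sequence in the batch;
# the reward for the token at position i is later weighted by discount_rate**i.
dis_count_list = [discount_rate ** i for i in range(max_seq_len)]
dis_count_matrix = torch.Tensor(dis_count_list).unsqueeze(0).repeat(batch_size, 1)

print(dis_count_matrix.shape)  # torch.Size([4, 6])
print(dis_count_matrix[0])     # tensor([1.0000, 0.9000, 0.8100, 0.7290, 0.6561, 0.5905])
```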
@@ -140,11 +169,15 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): pos_reward, neg_reward = 0, 0 for epoch in range(d_epoch): # ===Train=== - pos_reward, neg_reward = self.train_dis_epoch(self.dis, pos_samples, neg_samples, self.dis_opt) + pos_reward, neg_reward = self.train_dis_epoch( + self.dis, pos_samples, neg_samples, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f,' % ( - phase, step, pos_reward, neg_reward)) + self.log.info( + "[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f," + % (phase, step, pos_reward, neg_reward) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) @@ -159,8 +192,8 @@ def train_dis_epoch(self, model, pos_samples, neg_samples, optimizer): num_samples = pos_samples.size(0) num_batch = num_samples // cfg.batch_size for i in range(num_batch): - pos_sample = pos_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] - neg_sample = neg_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] + pos_sample = pos_samples[i * cfg.batch_size : (i + 1) * cfg.batch_size] + neg_sample = neg_samples[i * cfg.batch_size : (i + 1) * cfg.batch_size] _, pos_reward = model.getReward(pos_sample) _, neg_reward = model.getReward(neg_sample) diff --git a/instructor/real_data/evogan_instructor.py b/instructor/real_data/evogan_instructor.py index 5fcee9e9..faf2cd13 100644 --- a/instructor/real_data/evogan_instructor.py +++ b/instructor/real_data/evogan_instructor.py @@ -12,13 +12,13 @@ import torch import torch.nn.functional as F import torch.optim as optim -from tqdm import tqdm +from tqdm import tqdm, trange import config as cfg from instructor.real_data.instructor import BasicInstructor from metrics.nll import NLL -from models.EvoGAN_D import EvoGAN_D -from models.EvoGAN_G import EvoGAN_G +from models.discriminators.EvoGAN_D import EvoGAN_D +from models.generators.EvoGAN_G import EvoGAN_G from utils.data_loader import GenDataIter from utils.gan_loss import GANLoss from utils.helpers import get_fixed_temperature, get_losses @@ -30,13 +30,39 @@ def __init__(self, opt): super(EvoGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = EvoGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.parents = [EvoGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA).state_dict() - for _ in range(cfg.n_parent)] # list of Generator state_dict - self.dis = EvoGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = EvoGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.parents = [ + EvoGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ).state_dict() + for _ in range(cfg.n_parent) + ] # list of Generator state_dict + self.dis = EvoGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -44,30 +70,48 @@ def __init__(self, opt): self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr) 
self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) - self.parent_mle_opts = [copy.deepcopy(self.gen_opt.state_dict()) - for _ in range(cfg.n_parent)] - self.parent_adv_opts = [copy.deepcopy(self.gen_adv_opt.state_dict()) - for _ in range(cfg.n_parent)] # list of optimizer state dict + self.parent_mle_opts = [ + copy.deepcopy(self.gen_opt.state_dict()) for _ in range(cfg.n_parent) + ] + self.parent_adv_opts = [ + copy.deepcopy(self.gen_adv_opt.state_dict()) for _ in range(cfg.n_parent) + ] # list of optimizer state dict # Criterion - self.G_criterion = [GANLoss(loss_mode, 'G', cfg.d_type, CUDA=cfg.CUDA) for loss_mode in cfg.mu_type.split()] - self.D_criterion = GANLoss(cfg.loss_type, 'D', cfg.d_type, CUDA=cfg.CUDA) + self.G_criterion = [ + GANLoss(loss_mode, "G", cfg.d_type, CUDA=cfg.CUDA) + for loss_mode in cfg.mu_type.split() + ] + self.D_criterion = GANLoss(cfg.loss_type, "D", cfg.d_type, CUDA=cfg.CUDA) def init_model(self): if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: for i in range(cfg.n_parent): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) - self.parents[i] = torch.load(cfg.pretrained_gen_path + '%d' % 0, map_location='cpu') + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) + self.parents[i] = torch.load( + cfg.pretrained_gen_path + "%d" % 0, map_location="cpu" + ) if cfg.CUDA: self.gen = self.gen.cuda() if cfg.multi_gpu: - self.dis = torch.nn.parallel.DataParallel(self.dis, device_ids=cfg.devices) + self.dis = torch.nn.parallel.DataParallel( + self.dis, device_ids=cfg.devices + ) self.dis = self.dis.cuda() def load_gen(self, parent, parent_opt, mle=False): @@ -82,42 +126,70 @@ def load_gen(self, parent, parent_opt, mle=False): def _run(self): # ===PRE-TRAINING (GENERATOR)=== if not cfg.gen_pretrain: - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_mle_opts)): - self.log.info('Starting Generator-{} MLE Training...'.format(i)) + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_mle_opts) + ): + self.log.info("Starting Generator-{} MLE Training...".format(i)) self.load_gen(parent, parent_opt, mle=True) # load state dict self.pretrain_generator(cfg.MLE_train_epoch) - self.parents[i] = copy.deepcopy(self.gen.state_dict()) # save state dict + self.parents[i] = copy.deepcopy( + self.gen.state_dict() + ) # save state dict if cfg.if_save and not cfg.if_test: - torch.save(self.gen.state_dict(), cfg.pretrained_gen_path + '%d' % i) - self.log.info('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen.state_dict(), cfg.pretrained_gen_path + "%d" % i + ) + self.log.info( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - progress = tqdm(range(cfg.ADV_train_epoch)) + self.log.info("Starting Adversarial Training...") + progress = trange(cfg.ADV_train_epoch) for adv_epoch in progress: if cfg.temperature == 1: score, fit_score, select_mu = self.evolve_generator(cfg.ADV_g_step) else: # 
evolve with temperature - score, fit_score, select_mu = self.evolve_generator_with_temp(adv_epoch, cfg.ADV_g_step) + score, fit_score, select_mu = self.evolve_generator_with_temp( + adv_epoch, cfg.ADV_g_step + ) d_loss = self.evolve_discriminator(cfg.ADV_d_step) best_id = int(np.argmax(score)) - progress.set_description('mu: %s, d_loss = %.4f, temp = %.4f' % ( - ' '.join(select_mu), d_loss, self.parents[best_id]['temperature'].item())) + progress.set_description( + "mu: %s, d_loss = %.4f, temp = %.4f" + % ( + " ".join(select_mu), + d_loss, + self.parents[best_id]["temperature"].item(), + ) + ) # TEST - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): best_id = int(np.argmax(score)) self.load_gen(self.parents[best_id], self.parent_adv_opts[best_id]) - self.log.info('[ADV] epoch %d: temp = %.4f, d_loss = %.4f, %s' % ( - adv_epoch, self.gen.temperature.item(), d_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV] epoch %d: temp = %.4f, d_loss = %.4f, %s" + % ( + adv_epoch, + self.gen.temperature.item(), + d_loss, + self.cal_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -126,21 +198,25 @@ def pretrain_generator(self, epochs): """ Max Likelihood Pre-training for the generator """ - for epoch in range(epochs): + for epoch in trange(epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def evolve_generator(self, evo_g_step): @@ -157,12 +233,16 @@ def evolve_generator(self, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.train_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.train_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): # Variation self.load_gen(parent, parent_opt) # load state dict to self.gen @@ -188,7 +268,9 @@ def evolve_generator(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples 
selected_mutation[id_replace] = criterionG.loss_mode count += 1 @@ -212,16 +294,20 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.train_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.train_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): all_temp = self.get_evo_temp(cur_adv_step) # get evo temp - temp_score = float('-inf') + temp_score = float("-inf") temp_fit = None temp_child = None temp_child_opt = None @@ -237,8 +323,10 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # Evaluation self.prepare_eval_fake_data() # evaluation fake data - _, _, t_score = self.evaluation('Ra') # for temp evolutionary - loss_Fq, loss_Fd, loss_score = self.evaluation(cfg.eval_type) # for loss evolutionary + _, _, t_score = self.evaluation("Ra") # for temp evolutionary + loss_Fq, loss_Fd, loss_score = self.evaluation( + cfg.eval_type + ) # for loss evolutionary if t_score > temp_score: temp_score = loss_score @@ -275,8 +363,12 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): def evolve_discriminator(self, evo_d_step): total_loss = 0 for step in range(evo_d_step): - real_samples = F.one_hot(self.train_data.random_batch()['target'], cfg.vocab_size).float() - gen_samples = self.best_fake_samples[step * cfg.batch_size:(step + 1) * cfg.batch_size] + real_samples = F.one_hot( + self.train_data.random_batch()["target"], cfg.vocab_size + ).float() + gen_samples = self.best_fake_samples[ + step * cfg.batch_size : (step + 1) * cfg.batch_size + ] if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -312,7 +404,9 @@ def variation(self, g_step, criterionG): def evaluation(self, eval_type): """Evaluation all children, update child score. 
Note that the eval data should be the same""" - eval_samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size) + eval_samples = self.gen.sample( + cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size + ) gen_data = GenDataIter(eval_samples) # Fd @@ -322,20 +416,26 @@ def evaluation(self, eval_type): Fd = 0 # Fq - if eval_type == 'standard': + if eval_type == "standard": Fq = self.eval_d_out_fake.mean().cpu().item() - elif eval_type == 'rsgan': - g_loss, d_loss = get_losses(self.eval_d_out_real, self.eval_d_out_fake, 'rsgan') + elif eval_type == "rsgan": + g_loss, d_loss = get_losses( + self.eval_d_out_real, self.eval_d_out_fake, "rsgan" + ) Fq = d_loss.item() - elif 'bleu' in eval_type: - self.bleu.reset(test_text=tensor_to_tokens(eval_samples, self.idx2word_dict)) + elif "bleu" in eval_type: + self.bleu.reset( + test_text=tensor_to_tokens(eval_samples, self.idx2word_dict) + ) if cfg.lambda_fq != 0: Fq = self.bleu.get_score(given_gram=int(eval_type[-1])) else: Fq = 0 - elif 'Ra' in eval_type: - g_loss = torch.sigmoid(self.eval_d_out_fake - torch.mean(self.eval_d_out_real)).sum() + elif "Ra" in eval_type: + g_loss = torch.sigmoid( + self.eval_d_out_fake - torch.mean(self.eval_d_out_real) + ).sum() Fq = g_loss.item() else: raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type) @@ -346,22 +446,31 @@ def evaluation(self, eval_type): def prepare_eval_real_data(self): with torch.no_grad(): self.eval_real_samples = torch.cat( - [F.one_hot(self.train_data.random_batch()['target'], cfg.vocab_size).float() - for _ in range(cfg.eval_b_num)], dim=0) + [ + F.one_hot( + self.train_data.random_batch()["target"], cfg.vocab_size + ).float() + for _ in range(cfg.eval_b_num) + ], + dim=0, + ) if cfg.CUDA: self.eval_real_samples = self.eval_real_samples.cuda() - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": self.eval_d_out_real = self.dis(self.eval_real_samples) def prepare_eval_fake_data(self): with torch.no_grad(): - self.eval_fake_samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size, - cfg.eval_b_num * cfg.batch_size, one_hot=True) + self.eval_fake_samples = self.gen.sample( + cfg.eval_b_num * cfg.batch_size, + cfg.eval_b_num * cfg.batch_size, + one_hot=True, + ) if cfg.CUDA: self.eval_fake_samples = self.eval_fake_samples.cuda() - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": self.eval_d_out_fake = self.dis(self.eval_fake_samples) @staticmethod @@ -371,14 +480,30 @@ def get_evo_temp(cur_step): all_temp = list() # all_temp.append(get_fixed_temperature(1.0, 0, 0, 'no')) # temp=1.0 - all_temp.append(get_fixed_temperature(cfg.temperature, cur_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) # current step all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step + cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) # current step + all_temp.append( + get_fixed_temperature( + cfg.temperature, + cur_step + cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) if cur_step > cfg.evo_temp_step: all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step - cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step - cfg.evo_temp_step, + cfg.ADV_train_epoch, + 
random.choice(mu_temp_type), + ) + ) return torch.Tensor(all_temp) # three temp diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py new file mode 100644 index 00000000..60a0b681 --- /dev/null +++ b/instructor/real_data/fixem_instructor.py @@ -0,0 +1,198 @@ +import os +import random +from itertools import chain + +from pathlib import Path +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader +from tqdm import trange, tqdm + + +import config as cfg +from instructor.real_data.instructor import BasicInstructor +from utils.gan_loss import GANLoss +from utils.text_process import text_file_iterator +from utils.data_loader import DataSupplier, GenDataIter, GANDataset +from utils.cat_data_loader import CatClasDataIter +from utils.nn_helpers import create_noise, number_of_parameters +from utils.create_embeddings import EmbeddingsTrainer, load_embedding +from models.generators.FixemGAN_G import Generator +from models.discriminators.FixemGAN_D import Discriminator + + +class FixemGANInstructor(BasicInstructor): + def __init__(self, opt): + super(FixemGANInstructor, self).__init__(opt) + # check if embeddings already exist + if not os.path.exists(cfg.pretrain_embedding_path): + # train embedding on available datasets + self.build_embedding() + + w2v = load_embedding(cfg.pretrain_embedding_path) + + if cfg.run_model == "fixemgan": + labels, train_data = zip( + *[(0, line) for line in text_file_iterator(cfg.train_data)] + ) + + if cfg.run_model == "cat_fixemgan": + labels, train_data = zip( + *chain( + *[ + [ + (i, line) + for line in text_file_iterator(cfg.cat_train_data.format(i)) + ] + for i in range(cfg.k_label) + ] + ) + ) + + self.train_data_supplier = DataSupplier( + train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch + ) + + self.dis = Discriminator(cfg.discriminator_complexity) + self.log.info( + f"discriminator total tranable parameters: {number_of_parameters(self.dis.parameters())}" + ) + self.gen = Generator( + cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size + ) + self.log.info( + f"generator total tranable parameters: {number_of_parameters(self.gen.parameters())}" + ) + + if cfg.CUDA: + self.dis = self.dis.cuda() + self.gen = self.gen.cuda() + + self.G_criterion = GANLoss( + cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA + ) + self.D_criterion = GANLoss( + cfg.loss_type, + which_net=None, + which_D=None, + target_real_label=0.8, + target_fake_label=0.2, + CUDA=cfg.CUDA, + ) + + def build_embedding(self): + self.log.info(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") + self.log.info("Will train new one, it may take a while...") + sources = list(Path(cfg.texts_pile).glob("*.txt")) + EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() + + def generator_train_one_batch(self): + self.gen.optimizer.zero_grad() + noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) + if cfg.CUDA: + noise = tuple(tt.cuda() for tt in noise) + fakes = self.gen(*noise) + + real_fake_predicts, label_predicts = self.dis(fakes) + loss = self.G_criterion.G_loss_fixem( + real_fake_predicts, label_predicts, noise[1], fakes + ) + + loss.backward() + self.gen.optimizer.step() + + generator_acc = float( + np.array(real_fake_predicts.detach().cpu().numpy() > 0.5, dtype=int).mean() + ) + return generator_acc + + def discriminator_train_one_batch(self, real_vector, labels): + # important to have equal batch size for fake and real vectors + this_batch_size = 
real_vector.shape[0] + + # create input + noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) + if cfg.CUDA: + noise = tuple(tt.cuda() for tt in noise) + fake = self.gen(*noise).detach() + text_input_vectors = torch.cat((real_vector, fake)) + + # optmizer step + self.dis.optimizer.zero_grad() + real_fake_predicts, label_predicts = self.dis(text_input_vectors) + loss = self.D_criterion.D_loss_fixem( + real_fake_predicts, label_predicts[:this_batch_size], labels + ) + loss.backward() + self.dis.optimizer.step() + + real_fake_predicts = real_fake_predicts.clone().detach() + real_fake_predicts = real_fake_predicts.chunk( + 2 + ) # splitting to realand fake parks + + discriminator_acc = float( + torch.cat((real_fake_predicts[0] > 0.5, real_fake_predicts[1] < 0.5)).mean( + dtype=float + ) + ) + return discriminator_acc + + def _run(self): + for i in trange(cfg.max_epochs): + for labels, text_vector in tqdm(self.train_data_supplier, leave=False): + if cfg.CUDA: + labels, text_vector = labels.cuda(), text_vector.cuda() + discriminator_acc = self.discriminator_train_one_batch( + text_vector, labels + ) + + generator_acc = 1 - 2 * (discriminator_acc - 0.5) + # run the generator until generator acc not get high enought + while self.one_more_batch_for_generator(generator_acc): + generator_acc = self.generator_train_one_batch() + + if cfg.run_model == "fixemgan": + print("calculating_metrics") + scores = self.cal_metrics(fmt_str=True) + if cfg.run_model == "cat_fixemgan": + scores = "\n\n".join( + [ + self.cal_metrics_with_label(label_i=label_i, fmt_str=True) + for label_i in range(cfg.k_label) + ] + ) + self.log.info(f"epoch: {i}") + self.log.info(f"{scores}") + + def one_more_batch_for_generator( + self, generator_acc, leave_in_generator_min=0.1, leave_in_generator_max=0.9 + ): + generator_acc = min(leave_in_generator_max, generator_acc) + generator_acc = max(leave_in_generator_min, generator_acc) + if random.random() > generator_acc: + return True + return False + + def sample_for_metrics(self): + gen_tokens = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_tokens = [sample.split() for sample in gen_tokens] + gen_tokens_s = self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size) + gen_tokens_s = [sample.split() for sample in gen_tokens_s] + return GenDataIter(gen_tokens), gen_tokens, gen_tokens_s + + def sample_for_metrics_with_label(self, label_i): + gen_tokens = self.gen.sample( + cfg.samples_num, 8 * cfg.batch_size, label_i=label_i + ) + gen_tokens = [sample.split() for sample in gen_tokens] + gen_tokens_s = self.gen.sample( + cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i + ) + gen_tokens_s = [sample.split() for sample in gen_tokens_s] + return ( + GenDataIter(gen_tokens), + gen_tokens, + gen_tokens_s, + CatClasDataIter([gen_tokens], label_i), + ) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 622c382a..25d641dd 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -4,18 +4,22 @@ # @FileName : instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
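The new FixemGAN instructor above replaces fixed generator/discriminator step counts with a stochastic schedule: after every discriminator batch, the generator keeps taking extra batches with a probability that rises as its estimated accuracy against the discriminator falls. A compact sketch of that scheduling rule, using the same clamping bounds as the new file:

```python
import random

def one_more_batch_for_generator(generator_acc, low=0.1, high=0.9):
    """Return True with probability 1 - clamp(generator_acc, low, high)."""
    generator_acc = min(high, max(low, generator_acc))
    return random.random() > generator_acc

# The accuracy fed into the schedule is derived from the discriminator:
# a discriminator that is right 90% of the time implies a weak generator.
discriminator_acc = 0.9
generator_acc = 1 - 2 * (discriminator_acc - 0.5)   # -> 0.2
extra = sum(one_more_batch_for_generator(generator_acc) for _ in range(10_000))
print(extra / 10_000)  # ~0.8: a weak generator gets an extra batch about 80% of the time
```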
import numpy as np import torch import torch.nn as nn +import wandb import config as cfg from metrics.bleu import BLEU from metrics.clas_acc import ACC +from metrics.ioc import IOC +from metrics.gpt_nll import GPTNLL from metrics.nll import NLL from metrics.ppl import PPL +from metrics.dummy import Dummy from utils.cat_data_loader import CatClasDataIter from utils.data_loader import GenDataIter from utils.helpers import Signal, create_logger, get_fixed_temperature @@ -24,9 +28,14 @@ class BasicInstructor: def __init__(self, opt): - self.log = create_logger(__name__, silent=False, to_disk=True, - log_file=cfg.log_filename if cfg.if_test - else [cfg.log_filename, cfg.save_root + 'log.txt']) + self.log = create_logger( + __name__, + silent=False, + to_disk=True, + log_file=cfg.log_filename + if cfg.if_test + else [cfg.log_filename, cfg.save_root + "log.txt"], + ) self.sig = Signal(cfg.signal_file) self.opt = opt self.show_config() @@ -34,9 +43,14 @@ def __init__(self, opt): self.clas = None # load dictionary - self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset) + if cfg.if_real_data: + self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset) + else: + self.word2idx_dict, self.idx2word_dict = {}, {} # Dataloader + self.train_data = None + self.test_data = None try: self.train_data = GenDataIter(cfg.train_data) self.test_data = GenDataIter(cfg.test_data, if_test_data=True) @@ -44,14 +58,24 @@ def __init__(self, opt): pass try: - self.train_data_list = [GenDataIter(cfg.cat_train_data.format(i)) for i in range(cfg.k_label)] - self.test_data_list = [GenDataIter(cfg.cat_test_data.format(i), if_test_data=True) for i in - range(cfg.k_label)] - self.clas_data_list = [GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True) for i in - range(cfg.k_label)] - - self.train_samples_list = [self.train_data_list[i].target for i in range(cfg.k_label)] - self.clas_samples_list = [self.clas_data_list[i].target for i in range(cfg.k_label)] + self.train_data_list = [ + GenDataIter(cfg.cat_train_data.format(i)) for i in range(cfg.k_label) + ] + self.test_data_list = [ + GenDataIter(cfg.cat_test_data.format(i), if_test_data=True) + for i in range(cfg.k_label) + ] + self.clas_data_list = [ + GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True) + for i in range(cfg.k_label) + ] + + self.train_samples_list = [ + self.train_data_list[i].target for i in range(cfg.k_label) + ] + self.clas_samples_list = [ + self.clas_data_list[i].target for i in range(cfg.k_label) + ] except: pass @@ -64,16 +88,41 @@ def __init__(self, opt): self.clas_opt = None # Metrics - self.bleu = BLEU('BLEU', gram=[2, 3, 4, 5], if_use=cfg.use_bleu) - self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA) - self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA) - self.self_bleu = BLEU('Self-BLEU', gram=[2, 3, 4], if_use=cfg.use_self_bleu) - self.clas_acc = ACC(if_use=cfg.use_clas_acc) - self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl) - self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ppl] + # bleu, more-better, changes in range 0.4 - 0.6, will have relatively high weight + self.bleu = BLEU("BLEU", weight=3, gram=3, if_use=cfg.use_bleu) + # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight (not in use) + self.nll_gen = NLL("NLL_gen", weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) + # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight (not in use) + self.nll_div = NLL("NLL_div", 
weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) + # self-bleu, less-better, changes in range 0.7 - 0.9, will have relatively high weight + self.self_bleu = BLEU("Self-BLEU", weight=-3, gram=3, if_use=cfg.use_self_bleu) + # class-acc, more-better, changes in range 0.7 - 1.0, moderate weight + self.clas_acc = ACC(weight=1, if_use=cfg.use_clas_acc) + # IOC, less-better, changes in range 0.8 - 2.0, smaller weight + self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.test_data) + # nll_oracle, less-better, changes in range -0.1 - 0.6, moderate weight + self.nll_oracle = GPTNLL( + weight=-3, if_use=cfg.use_nll_oracle, real_text=self.test_data + ) + # perplexity, less-better, changes in range 3 - 4, moderate weight (not in use) + self.ppl = PPL( + self.train_data, self.test_data, weight=0, n_gram=5, if_use=cfg.use_ppl + ) + # dummy, add constant value to overall score + self.dummy = Dummy(weight=1, value=5, if_use=True) + self.all_metrics = [ + self.bleu, + self.nll_gen, + self.nll_div, + self.self_bleu, + self.ioc, + self.nll_oracle, + self.ppl, + self.dummy, + ] def _run(self): - print('Nothing to run in Basic Instructor!') + print("Nothing to run in Basic Instructor!") pass def _test(self): @@ -82,11 +131,22 @@ def _test(self): def init_model(self): if cfg.dis_pretrain: self.log.info( - 'Load pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: - self.log.info('Load MLE pre-trained generator: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pre-trained generator: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.gen = self.gen.cuda() @@ -95,7 +155,7 @@ def init_model(self): def train_gen_epoch(self, model, data_loader, criterion, optimizer): total_loss = 0 for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -111,7 +171,7 @@ def train_dis_epoch(self, model, data_loader, criterion, optimizer): total_acc = 0 total_num = 0 for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -147,15 +207,28 @@ def train_classifier(self, epochs): max_acc = 0 best_clas = None for epoch in range(epochs): - c_loss, c_acc = self.train_dis_epoch(self.clas, clas_data.loader, self.clas_criterion, - self.clas_opt) - _, eval_acc = self.eval_dis(self.clas, eval_clas_data.loader, self.clas_criterion) + c_loss, c_acc = self.train_dis_epoch( + self.clas, clas_data.loader, self.clas_criterion, self.clas_opt + ) + _, eval_acc = self.eval_dis( + self.clas, eval_clas_data.loader, self.clas_criterion + ) if eval_acc > max_acc: - best_clas = copy.deepcopy(self.clas.state_dict()) # save the best classifier + best_clas = copy.deepcopy( + self.clas.state_dict() + ) # save the best classifier max_acc = eval_acc - self.log.info('[PRE-CLAS] epoch %d: c_loss = %.4f, c_acc = %.4f, eval_acc = 
%.4f, max_eval_acc = %.4f', - epoch, c_loss, c_acc, eval_acc, max_acc) - self.clas.load_state_dict(copy.deepcopy(best_clas)) # Reload the best classifier + self.log.info( + "[PRE-CLAS] epoch %d: c_loss = %.4f, c_acc = %.4f, eval_acc = %.4f, max_eval_acc = %.4f", + epoch, + c_loss, + c_acc, + eval_acc, + max_acc, + ) + self.clas.load_state_dict( + copy.deepcopy(best_clas) + ) # Reload the best classifier @staticmethod def eval_dis(model, data_loader, criterion): @@ -164,7 +237,7 @@ def eval_dis(model, data_loader, criterion): total_num = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -193,11 +266,34 @@ def optimize(opt, loss, model=None, retain_graph=False): opt.step() def show_config(self): - self.log.info(100 * '=') - self.log.info('> training arguments:') + self.log.info(100 * "=") + self.log.info("> training arguments:") for arg in vars(self.opt): - self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg))) - self.log.info(100 * '=') + self.log.info(">>> {0}: {1}".format(arg, getattr(self.opt, arg))) + self.log.info(100 * "=") + + def sample_for_metrics(self): + eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_data = GenDataIter(eval_samples) + gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size), + self.idx2word_dict, + ) + return gen_data, gen_tokens, gen_tokens_s + + def sample_for_metrics_with_label(self, label_i): + eval_samples = self.gen.sample( + cfg.samples_num, 8 * cfg.batch_size, label_i=label_i + ) + gen_data = GenDataIter(eval_samples) + gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i), + self.idx2word_dict, + ) + clas_data = CatClasDataIter([eval_samples], label_i) + return gen_data, gen_tokens, gen_tokens_s, clas_data def cal_metrics(self, fmt_str=False): """ @@ -206,62 +302,107 @@ def cal_metrics(self, fmt_str=False): """ with torch.no_grad(): # Prepare data for evaluation - eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) - gen_data = GenDataIter(eval_samples) - gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200), self.idx2word_dict) - + gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics() + print("sampled") # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) self.nll_gen.reset(self.gen, self.train_data.loader) self.nll_div.reset(self.gen, gen_data.loader) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) - self.ppl.reset(gen_tokens) + self.ppl.reset(gen_tokens=gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.nll_oracle.reset(test_text=gen_tokens) + + print("all reset") + metrics = {metric.name: metric.get_score() for metric in self.all_metrics} + print("get_score called") + metrics.update( + { + "Overal_score": sum( + metric.weight * metric.get_score() for metric in self.all_metrics + ) + } + ) + wandb.log(metrics) if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) - else: - return [metric.get_score() for metric in self.all_metrics] + return "\n" + "\n".join( + [f"{name} = {score}" for name, score in metrics.items()] + ) + 
return [metric.get_score() for metric in self.all_metrics] - def cal_metrics_with_label(self, label_i): - assert type(label_i) == int, 'missing label' + def cal_metrics_with_label(self, label_i, fmt_str=False): + assert type(label_i) == int, "missing label" with torch.no_grad(): # Prepare data for evaluation - eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) - gen_data = GenDataIter(eval_samples) - gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) - clas_data = CatClasDataIter([eval_samples], label_i) - + ( + gen_data, + gen_tokens, + gen_tokens_s, + clas_data, + ) = self.sample_for_metrics_with_label(label_i) # Reset metrics - self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) + self.bleu.reset( + test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens + ) self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i) self.nll_div.reset(self.gen, gen_data.loader, label_i) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) self.clas_acc.reset(self.clas, clas_data.loader) self.ppl.reset(gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.nll_oracle.reset(test_text=gen_tokens) + + metrics = { + f"label {label_i}_{metric.name}": metric.get_score() + for metric in self.all_metrics + } + metrics.update( + { + f"label {label_i} Overal_score": sum( + metric.weight * metric.get_score() for metric in self.all_metrics + ) + } + ) + wandb.log(metrics) + if fmt_str: + return "\n" + "\n".join( + [f"{name} = {score}" for name, score in metrics.items()] + ) return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): - all_scores = [self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label)] - all_scores = np.array(all_scores).T.tolist() # each row for each metric + all_scores = [ + self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label) + ] if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), score) - for (metric, score) in zip(self.all_metrics, all_scores)]) - return all_scores + return ", ".join( + [ + f"{name} = {[scores[name] for scores in all_scores]}" + for name in all_scores[0] + ] + ) + return [scores.values() for scores in all_scores] def _save(self, phase, epoch): """Save model state dict and generator's samples""" - if phase != 'ADV': - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(phase, epoch) + if phase != "ADV": + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_{}_{:05d}.txt".format( + phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size) write_tokens(save_sample_path, tensor_to_tokens(samples, self.idx2word_dict)) def update_temperature(self, i, N): - self.gen.temperature.data = torch.Tensor([get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)]) + self.gen.temperature.data = torch.Tensor( + [get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)] + ) if cfg.CUDA: self.gen.temperature.data = self.gen.temperature.data.cuda() diff --git a/instructor/real_data/jsdgan_instructor.py b/instructor/real_data/jsdgan_instructor.py index 6c3e58b0..8144d6a3 100644 --- a/instructor/real_data/jsdgan_instructor.py +++ 
b/instructor/real_data/jsdgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : JSDGAN_instructor.py # @Time : Created at 2019/11/25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,7 +12,7 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.JSDGAN_G import JSDGAN_G +from models.generators.JSDGAN_G import JSDGAN_G class JSDGANInstructor(BasicInstructor): @@ -20,8 +20,17 @@ def __init__(self, opt): super(JSDGANInstructor, self).__init__(opt) # generator - self.gen = JSDGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = JSDGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -29,8 +38,14 @@ def __init__(self, opt): def init_model(self): if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.gen = self.gen.cuda() @@ -38,23 +53,29 @@ def init_model(self): def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") for adv_epoch in range(cfg.ADV_train_epoch): g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV] epoch %d: g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True))) + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): + self.log.info( + "[ADV] epoch %d: g_loss = %.4f, %s" + % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -66,16 +87,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -87,7 +112,7 @@ def adv_train_generator(self, g_step): 
total_loss = 0 for step in range(g_step): for i, data in enumerate(self.train_data.loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() diff --git a/instructor/real_data/leakgan_instructor.py b/instructor/real_data/leakgan_instructor.py index 922a203f..cecf59eb 100644 --- a/instructor/real_data/leakgan_instructor.py +++ b/instructor/real_data/leakgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : leakgan_instructor.py # @Time : Created at 2019-06-05 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.LeakGAN_D import LeakGAN_D -from models.LeakGAN_G import LeakGAN_G +from models.discriminators.LeakGAN_D import LeakGAN_D +from models.generators.LeakGAN_G import LeakGAN_G from utils import rollout from utils.data_loader import GenDataIter, DisDataIter from utils.text_process import tensor_to_tokens, write_tokens @@ -24,9 +24,19 @@ def __init__(self, opt): super(LeakGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = LeakGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, cfg.goal_size, cfg.step_size, cfg.CUDA) - self.dis = LeakGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = LeakGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + cfg.goal_size, + cfg.step_size, + cfg.CUDA, + ) + self.dis = LeakGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # optimizer @@ -39,48 +49,63 @@ def __init__(self, opt): def _run(self): for inter_num in range(cfg.inter_epoch): - self.log.info('>>> Interleaved Round %d...' % inter_num) + self.log.info(">>> Interleaved Round %d..." % inter_num) self.sig.update() # update signal if self.sig.pre_sig: # ===DISCRIMINATOR PRE-TRAINING=== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format( + cfg.pretrained_dis_path + ) + ) # ===GENERATOR MLE TRAINING=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + ) + ) else: - self.log.info('>>> Stop by pre_signal! Skip to adversarial training...') + self.log.info(">>> Stop by pre_signal! 
Skip to adversarial training...") break # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (str(self.cal_metrics(fmt_str=True)))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (str(self.cal_metrics(fmt_str=True)))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -98,7 +123,7 @@ def pretrain_generator(self, epochs): # ===Train=== for i, data in enumerate(self.train_data.loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -111,13 +136,20 @@ def pretrain_generator(self, epochs): # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: - self.log.info('[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s' % ( - epoch, pre_mana_loss, pre_work_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s" + % ( + epoch, + pre_mana_loss, + pre_work_loss, + self.cal_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step, current_k=0): @@ -131,13 +163,15 @@ def adv_train_generator(self, g_step, current_k=0): adv_work_loss = 0 for step in range(g_step): with torch.no_grad(): - gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis, - train=True) # !!! train=True, the only place + gen_samples = self.gen.sample( + cfg.batch_size, cfg.batch_size, self.dis, train=True + ) # !!! 
train=True, the only place inp, target = GenDataIter.prepare(gen_samples, gpu=cfg.CUDA) # ===Train=== - rewards = rollout_func.get_reward_leakgan(target, cfg.rollout_num, self.dis, - current_k).cpu() # reward with MC search + rewards = rollout_func.get_reward_leakgan( + target, cfg.rollout_num, self.dis, current_k + ).cpu() # reward with MC search mana_loss, work_loss = self.gen.adversarial_loss(target, rewards, self.dis) # update parameters @@ -145,10 +179,16 @@ def adv_train_generator(self, g_step, current_k=0): adv_mana_loss += mana_loss.data.item() adv_work_loss += work_loss.data.item() # ===Test=== - self.log.info('[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' % ( - adv_mana_loss / g_step, adv_work_loss / g_step, self.cal_metrics(fmt_str=True))) - - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + self.log.info( + "[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s" + % ( + adv_mana_loss / g_step, + adv_work_loss / g_step, + self.cal_metrics(fmt_str=True), + ) + ) + + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. @@ -162,12 +202,15 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f," + % (phase, step, d_loss, train_acc) + ) def cal_metrics(self, fmt_str=False): with torch.no_grad(): @@ -175,7 +218,9 @@ def cal_metrics(self, fmt_str=False): eval_samples = self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(200, cfg.batch_size, self.dis), self.idx2word_dict) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(200, cfg.batch_size, self.dis), self.idx2word_dict + ) # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) @@ -185,12 +230,22 @@ def cal_metrics(self, fmt_str=False): self.ppl.reset(gen_tokens) if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + return ", ".join( + [ + "%s = %s" % (metric.name, metric.get_score()) + for metric in self.all_metrics + ] + ) else: return [metric.get_score() for metric in self.all_metrics] def _save(self, phase, epoch): - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(phase, epoch) + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_{}_{:05d}.txt".format( + phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis) write_tokens(save_sample_path, tensor_to_tokens(samples, self.idx2word_dict)) diff --git a/instructor/real_data/maligan_instructor.py b/instructor/real_data/maligan_instructor.py index 65258201..a0b0eeb2 100644 --- 
a/instructor/real_data/maligan_instructor.py +++ b/instructor/real_data/maligan_instructor.py @@ -4,7 +4,7 @@ # @FileName : maligan_instructor.py # @Time : Created at 2019/11/29 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. @@ -14,8 +14,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.MaliGAN_D import MaliGAN_D -from models.MaliGAN_G import MaliGAN_G +from models.discriminators.MaliGAN_D import MaliGAN_D +from models.generators.MaliGAN_G import MaliGAN_G from utils.data_loader import GenDataIter, DisDataIter @@ -25,9 +25,17 @@ def __init__(self, opt): super(MaliGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = MaliGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = MaliGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = MaliGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = MaliGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # Optimizer @@ -39,40 +47,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -84,16 +101,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -102,7 +123,9 @@ def adv_train_generator(self, g_step): """ total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = self.get_mali_reward(target) @@ -111,9 +134,12 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss, self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. @@ -129,12 +155,15 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f," + % (phase, step, d_loss, train_acc) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) diff --git a/instructor/real_data/relgan_instructor.py b/instructor/real_data/relgan_instructor.py index 17df9ecc..22f338d7 100644 --- a/instructor/real_data/relgan_instructor.py +++ b/instructor/real_data/relgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : relgan_instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch @@ -14,8 +14,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.RelGAN_D import RelGAN_D -from models.RelGAN_G import RelGAN_G +from models.discriminators.RelGAN_D import RelGAN_D +from models.generators.RelGAN_G import RelGAN_G from utils.helpers import get_fixed_temperature, get_losses @@ -24,10 +24,25 @@ def __init__(self, opt): super(RelGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx, - gpu=cfg.CUDA) + self.gen = RelGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = RelGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -38,39 +53,50 @@ def __init__(self, opt): def _run(self): # ===PRE-TRAINING (GENERATOR)=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pretrain_generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pretrain_generator: {}".format(cfg.pretrained_gen_path)) # # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for adv_epoch in progress: self.sig.update() if self.sig.adv_sig: g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator d_loss = self.adv_train_discriminator(cfg.ADV_d_step) # Discriminator - self.update_temperature(adv_epoch, cfg.ADV_train_epoch) # update temperature + self.update_temperature( + adv_epoch, cfg.ADV_train_epoch + ) # update temperature progress.set_description( - 'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature)) + "g_loss: %.4f, d_loss: %.4f, temperature: %.4f" + % (g_loss, d_loss, self.gen.temperature) + ) # TEST - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s' % ( - adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True))) + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): + self.log.info( + "[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s" + % (adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) progress.close() break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -83,23 +109,27 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: - self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): total_loss = 0 for step in range(g_step): - real_samples = self.train_data.random_batch()['target'] + real_samples = self.train_data.random_batch()["target"] gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True) if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -118,7 +148,7 @@ def adv_train_generator(self, g_step): def adv_train_discriminator(self, d_step): total_loss = 0 for step in range(d_step): - real_samples = self.train_data.random_batch()['target'] + real_samples = self.train_data.random_batch()["target"] gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True) if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -135,7 +165,9 @@ def adv_train_discriminator(self, d_step): return total_loss / d_step if d_step != 0 else 0 def update_temperature(self, i, N): - self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt) + self.gen.temperature = get_fixed_temperature( + cfg.temperature, i, N, cfg.temp_adpt + ) @staticmethod def optimize(opt, loss, model=None, retain_graph=False): diff --git a/instructor/real_data/sentigan_instructor.py b/instructor/real_data/sentigan_instructor.py index 3d7ecd50..052b8bc0 100644 --- a/instructor/real_data/sentigan_instructor.py +++ b/instructor/real_data/sentigan_instructor.py @@ -12,8 +12,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.SentiGAN_D import SentiGAN_D, SentiGAN_C -from models.SentiGAN_G import SentiGAN_G +from models.discriminators.SentiGAN_D import SentiGAN_D, SentiGAN_C +from models.generators.SentiGAN_G import SentiGAN_G from utils import rollout from utils.cat_data_loader import CatClasDataIter from utils.data_loader import GenDataIter @@ -25,15 +25,39 @@ def __init__(self, opt): super(SentiGANInstructor, self).__init__(opt) # generator, discriminator - self.gen_list = [SentiGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] - self.dis = SentiGAN_D(cfg.k_label, cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) - self.clas = SentiGAN_C(cfg.k_label, cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.extend_vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen_list = [ + SentiGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + for _ in range(cfg.k_label) + ] + 
self.dis = SentiGAN_D( + cfg.k_label, + cfg.dis_embed_dim, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.clas = SentiGAN_C( + cfg.k_label, + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.extend_vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer - self.gen_opt_list = [optim.Adam(gen.parameters(), lr=cfg.gen_lr) for gen in self.gen_list] + self.gen_opt_list = [ + optim.Adam(gen.parameters(), lr=cfg.gen_lr) for gen in self.gen_list + ] self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) self.clas_opt = optim.Adam(self.clas.parameters(), lr=cfg.clas_lr) @@ -43,16 +67,35 @@ def __init__(self, opt): def init_model(self): if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: for i in range(cfg.k_label): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) self.gen_list[i].load_state_dict( - torch.load(cfg.pretrained_gen_path + '%d' % i, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.pretrained_gen_path + "%d" % i, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.clas_pretrain: - self.log.info('Load pretrained classifier: {}'.format(cfg.pretrained_clas_path)) - self.clas.load_state_dict(torch.load(cfg.pretrained_clas_path, map_location='cuda:%d' % cfg.device)) + self.log.info( + "Load pretrained classifier: {}".format(cfg.pretrained_clas_path) + ) + self.clas.load_state_dict( + torch.load( + cfg.pretrained_clas_path, map_location="cuda:%d" % cfg.device + ) + ) if cfg.CUDA: for i in range(cfg.k_label): @@ -63,46 +106,62 @@ def init_model(self): def _run(self): # ===Pre-train Classifier with real data=== if cfg.use_clas_acc: - self.log.info('Start training Classifier...') + self.log.info("Start training Classifier...") self.train_classifier(cfg.PRE_clas_epoch) # ===PRE-TRAIN GENERATOR=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: for i in range(cfg.k_label): - torch.save(self.gen_list[i].state_dict(), cfg.pretrained_gen_path + '%d' % i) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen_list[i].state_dict(), + cfg.pretrained_gen_path + "%d" % i, + ) + print( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s', 
self.comb_metrics(fmt_str=True)) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s", self.comb_metrics(fmt_str=True)) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -115,18 +174,24 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: for i in range(cfg.k_label): - pre_loss = self.train_gen_epoch(self.gen_list[i], self.train_data_list[i].loader, - self.mle_criterion, self.gen_opt_list[i]) + pre_loss = self.train_gen_epoch( + self.gen_list[i], + self.train_data_list[i].loader, + self.mle_criterion, + self.gen_opt_list[i], + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: if i == cfg.k_label - 1: - self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.comb_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.comb_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -138,18 +203,23 @@ def adv_train_generator(self, g_step): rollout_func = rollout.ROLLOUT(self.gen_list[i], cfg.CUDA) total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen_list[i].sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen_list[i].sample(cfg.batch_size, cfg.batch_size), + gpu=cfg.CUDA, + ) # ===Train=== - rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis, current_k=i) + rewards = rollout_func.get_reward( + target, cfg.rollout_num, self.dis, current_k=i + ) adv_loss = self.gen_list[i].batchPGLoss(inp, target, rewards) self.optimize(self.gen_opt_list[i], adv_loss) total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: %s', self.comb_metrics(fmt_str=True)) + self.log.info("[ADV-GEN]: %s", self.comb_metrics(fmt_str=True)) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
@@ -163,37 +233,52 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): fake_samples = [] for i in range(cfg.k_label): real_samples.append(self.train_samples_list[i]) - fake_samples.append(self.gen_list[i].sample(cfg.samples_num // cfg.k_label, 8 * cfg.batch_size)) + fake_samples.append( + self.gen_list[i].sample( + cfg.samples_num // cfg.k_label, 8 * cfg.batch_size + ) + ) dis_samples_list = [torch.cat(fake_samples, dim=0)] + real_samples dis_data = CatClasDataIter(dis_samples_list) for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f" + % (phase, step, d_loss, train_acc) + ) - if cfg.if_save and not cfg.if_test and phase == 'MLE': + if cfg.if_save and not cfg.if_test and phase == "MLE": torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def cal_metrics_with_label(self, label_i): - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" with torch.no_grad(): # Prepare data for evaluation - eval_samples = self.gen_list[label_i].sample(cfg.samples_num, 8 * cfg.batch_size) + eval_samples = self.gen_list[label_i].sample( + cfg.samples_num, 8 * cfg.batch_size + ) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen_list[label_i].sample(200, 200), self.idx2word_dict) + gen_tokens_s = tensor_to_tokens( + self.gen_list[label_i].sample(200, 200), self.idx2word_dict + ) clas_data = CatClasDataIter([eval_samples], label_i) # Reset metrics - self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) - self.nll_gen.reset(self.gen_list[label_i], self.train_data_list[label_i].loader) + self.bleu.reset( + test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens + ) + self.nll_gen.reset( + self.gen_list[label_i], self.train_data_list[label_i].loader + ) self.nll_div.reset(self.gen_list[label_i], gen_data.loader) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) self.clas_acc.reset(self.clas, clas_data.loader) @@ -204,9 +289,16 @@ def cal_metrics_with_label(self, label_i): def _save(self, phase, epoch): """Save model state dict and generator's samples""" for i in range(cfg.k_label): - if phase != 'ADV': - torch.save(self.gen_list[i].state_dict(), - cfg.save_model_root + 'gen{}_{}_{:05d}.pt'.format(i, phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_d{}_{}_{:05d}.txt'.format(i, phase, epoch) + if phase != "ADV": + torch.save( + self.gen_list[i].state_dict(), + cfg.save_model_root + "gen{}_{}_{:05d}.pt".format(i, phase, epoch), + ) + save_sample_path = ( + cfg.save_samples_root + + "samples_d{}_{}_{:05d}.txt".format(i, phase, epoch) + ) samples = self.gen_list[i].sample(cfg.batch_size, cfg.batch_size) - write_tokens(save_sample_path, tensor_to_tokens(samples, self.idx2word_dict)) + write_tokens( + save_sample_path, tensor_to_tokens(samples, self.idx2word_dict) + ) diff --git a/instructor/real_data/seqgan_instructor.py b/instructor/real_data/seqgan_instructor.py index 061a2607..2241a4d0 100644 --- a/instructor/real_data/seqgan_instructor.py +++ b/instructor/real_data/seqgan_instructor.py @@ -4,7 +4,7 
@@ # @FileName : seqgan_instructor.py # @Time : Created at 2019-06-05 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.SeqGAN_D import SeqGAN_D -from models.SeqGAN_G import SeqGAN_G +from models.discriminators.SeqGAN_D import SeqGAN_D +from models.generators.SeqGAN_G import SeqGAN_G from utils import rollout from utils.data_loader import GenDataIter, DisDataIter @@ -23,9 +23,17 @@ def __init__(self, opt): super(SeqGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = SeqGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = SeqGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = SeqGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = SeqGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # Optimizer @@ -37,40 +45,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -82,16 +99,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -102,7 +123,9 @@ def adv_train_generator(self, g_step): rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA) total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis) @@ -111,9 +134,12 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss, self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. @@ -128,12 +154,15 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f," + % (phase, step, d_loss, train_acc) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) diff --git a/main.py b/main.py index 3ad65402..b882f37f 100644 --- a/main.py +++ b/main.py @@ -8,116 +8,155 @@ # Copyrights (C) 2018. All Rights Reserved. 
from __future__ import print_function +import random +import yaml + import argparse +# import torch +import numpy as np +import wandb + import config as cfg from utils.text_process import load_test_dict, text_process def program_config(parser): # Program - parser.add_argument('--if_test', default=cfg.if_test, type=int) - parser.add_argument('--run_model', default=cfg.run_model, type=str) - parser.add_argument('--k_label', default=cfg.k_label, type=int) - parser.add_argument('--dataset', default=cfg.dataset, type=str) - parser.add_argument('--model_type', default=cfg.model_type, type=str) - parser.add_argument('--loss_type', default=cfg.loss_type, type=str) - parser.add_argument('--mu_type', default=cfg.mu_type, type=str) - parser.add_argument('--eval_type', default=cfg.eval_type, type=str) - parser.add_argument('--d_type', default=cfg.d_type, type=str) - parser.add_argument('--if_real_data', default=cfg.if_real_data, type=int) - parser.add_argument('--cuda', default=cfg.CUDA, type=int) - parser.add_argument('--device', default=cfg.device, type=int) - parser.add_argument('--devices', default=cfg.devices, type=str) - parser.add_argument('--shuffle', default=cfg.data_shuffle, type=int) - parser.add_argument('--gen_init', default=cfg.gen_init, type=str) - parser.add_argument('--dis_init', default=cfg.dis_init, type=str) + parser.add_argument("--if_test", default=cfg.if_test, type=int) + parser.add_argument("--run_model", default=cfg.run_model, type=str) + parser.add_argument("--k_label", default=cfg.k_label, type=int) + parser.add_argument("--dataset", default=cfg.dataset, type=str) + parser.add_argument("--model_type", default=cfg.model_type, type=str) + parser.add_argument("--loss_type", default=cfg.loss_type, type=str) + parser.add_argument("--mu_type", default=cfg.mu_type, type=str) + parser.add_argument("--eval_type", default=cfg.eval_type, type=str) + parser.add_argument("--d_type", default=cfg.d_type, type=str) + parser.add_argument("--if_real_data", default=cfg.if_real_data, type=int) + parser.add_argument("--cuda", default=cfg.CUDA, type=int) + parser.add_argument("--device", default=cfg.device, type=int) + parser.add_argument("--devices", default=cfg.devices, type=str) + parser.add_argument("--shuffle", default=cfg.data_shuffle, type=int) + parser.add_argument("--gen_init", default=cfg.gen_init, type=str) + parser.add_argument("--dis_init", default=cfg.dis_init, type=str) # CatGAN - parser.add_argument('--n_parent', default=cfg.n_parent, type=int) - parser.add_argument('--eval_b_num', default=cfg.eval_b_num, type=int) - parser.add_argument('--lambda_fq', default=cfg.lambda_fq, type=float) - parser.add_argument('--lambda_fd', default=cfg.lambda_fd, type=float) - parser.add_argument('--d_out_mean', default=cfg.d_out_mean, type=int) - parser.add_argument('--freeze_dis', default=cfg.freeze_dis, type=int) - parser.add_argument('--freeze_clas', default=cfg.freeze_clas, type=int) - parser.add_argument('--use_all_real_fake', default=cfg.use_all_real_fake, type=int) - parser.add_argument('--use_population', default=cfg.use_population, type=int) + parser.add_argument("--n_parent", default=cfg.n_parent, type=int) + parser.add_argument("--eval_b_num", default=cfg.eval_b_num, type=int) + parser.add_argument("--lambda_fq", default=cfg.lambda_fq, type=float) + parser.add_argument("--lambda_fd", default=cfg.lambda_fd, type=float) + parser.add_argument("--d_out_mean", default=cfg.d_out_mean, type=int) + parser.add_argument("--freeze_dis", default=cfg.freeze_dis, type=int) + 
parser.add_argument("--freeze_clas", default=cfg.freeze_clas, type=int) + parser.add_argument("--use_all_real_fake", default=cfg.use_all_real_fake, type=int) + parser.add_argument("--use_population", default=cfg.use_population, type=int) + parser.add_argument("--batches_per_epoch", default=cfg.batches_per_epoch, type=int) + parser.add_argument("--noise_size", default=cfg.noise_size, type=int) + parser.add_argument("--max_epochs", default=cfg.max_epochs, type=int) + parser.add_argument("--target_len", default=cfg.target_len, type=int) # Basic Train - parser.add_argument('--samples_num', default=cfg.samples_num, type=int) - parser.add_argument('--vocab_size', default=cfg.vocab_size, type=int) - parser.add_argument('--mle_epoch', default=cfg.MLE_train_epoch, type=int) - parser.add_argument('--clas_pre_epoch', default=cfg.PRE_clas_epoch, type=int) - parser.add_argument('--adv_epoch', default=cfg.ADV_train_epoch, type=int) - parser.add_argument('--inter_epoch', default=cfg.inter_epoch, type=int) - parser.add_argument('--batch_size', default=cfg.batch_size, type=int) - parser.add_argument('--max_seq_len', default=cfg.max_seq_len, type=int) - parser.add_argument('--start_letter', default=cfg.start_letter, type=int) - parser.add_argument('--padding_idx', default=cfg.padding_idx, type=int) - parser.add_argument('--gen_lr', default=cfg.gen_lr, type=float) - parser.add_argument('--gen_adv_lr', default=cfg.gen_adv_lr, type=float) - parser.add_argument('--dis_lr', default=cfg.dis_lr, type=float) - parser.add_argument('--clip_norm', default=cfg.clip_norm, type=float) - parser.add_argument('--pre_log_step', default=cfg.pre_log_step, type=int) - parser.add_argument('--adv_log_step', default=cfg.adv_log_step, type=int) - parser.add_argument('--train_data', default=cfg.train_data, type=str) - parser.add_argument('--test_data', default=cfg.test_data, type=str) - parser.add_argument('--temp_adpt', default=cfg.temp_adpt, type=str) - parser.add_argument('--evo_temp_step', default=cfg.evo_temp_step, type=int) - parser.add_argument('--temperature', default=cfg.temperature, type=int) - parser.add_argument('--ora_pretrain', default=cfg.oracle_pretrain, type=int) - parser.add_argument('--gen_pretrain', default=cfg.gen_pretrain, type=int) - parser.add_argument('--dis_pretrain', default=cfg.dis_pretrain, type=int) + parser.add_argument("--samples_num", default=cfg.samples_num, type=int) + parser.add_argument("--vocab_size", default=cfg.vocab_size, type=int) + parser.add_argument("--mle_epoch", default=cfg.MLE_train_epoch, type=int) + parser.add_argument("--clas_pre_epoch", default=cfg.PRE_clas_epoch, type=int) + parser.add_argument("--adv_epoch", default=cfg.ADV_train_epoch, type=int) + parser.add_argument("--inter_epoch", default=cfg.inter_epoch, type=int) + parser.add_argument("--batch_size", default=cfg.batch_size, type=int) + parser.add_argument("--max_seq_len", default=cfg.max_seq_len, type=int) + parser.add_argument("--start_letter", default=cfg.start_letter, type=int) + parser.add_argument("--padding_idx", default=cfg.padding_idx, type=int) + parser.add_argument("--gen_lr", default=cfg.gen_lr, type=float) + parser.add_argument("--gen_adv_lr", default=cfg.gen_adv_lr, type=float) + parser.add_argument("--dis_lr", default=cfg.dis_lr, type=float) + parser.add_argument("--clip_norm", default=cfg.clip_norm, type=float) + parser.add_argument("--pre_log_step", default=cfg.pre_log_step, type=int) + parser.add_argument("--adv_log_step", default=cfg.adv_log_step, type=int) + parser.add_argument("--train_data", 
default=cfg.train_data, type=str) + parser.add_argument("--test_data", default=cfg.test_data, type=str) + parser.add_argument("--temp_adpt", default=cfg.temp_adpt, type=str) + parser.add_argument("--evo_temp_step", default=cfg.evo_temp_step, type=int) + parser.add_argument("--temperature", default=cfg.temperature, type=int) + parser.add_argument("--ora_pretrain", default=cfg.oracle_pretrain, type=int) + parser.add_argument("--gen_pretrain", default=cfg.gen_pretrain, type=int) + parser.add_argument("--dis_pretrain", default=cfg.dis_pretrain, type=int) # Generator - parser.add_argument('--adv_g_step', default=cfg.ADV_g_step, type=int) - parser.add_argument('--rollout_num', default=cfg.rollout_num, type=int) - parser.add_argument('--gen_embed_dim', default=cfg.gen_embed_dim, type=int) - parser.add_argument('--gen_hidden_dim', default=cfg.gen_hidden_dim, type=int) - parser.add_argument('--goal_size', default=cfg.goal_size, type=int) - parser.add_argument('--step_size', default=cfg.step_size, type=int) - parser.add_argument('--mem_slots', default=cfg.mem_slots, type=int) - parser.add_argument('--num_heads', default=cfg.num_heads, type=int) - parser.add_argument('--head_size', default=cfg.head_size, type=int) + parser.add_argument("--adv_g_step", default=cfg.ADV_g_step, type=int) + parser.add_argument("--rollout_num", default=cfg.rollout_num, type=int) + parser.add_argument("--gen_embed_dim", default=cfg.gen_embed_dim, type=int) + parser.add_argument("--gen_hidden_dim", default=cfg.gen_hidden_dim, type=int) + parser.add_argument("--goal_size", default=cfg.goal_size, type=int) + parser.add_argument("--step_size", default=cfg.step_size, type=int) + parser.add_argument("--mem_slots", default=cfg.mem_slots, type=int) + parser.add_argument("--num_heads", default=cfg.num_heads, type=int) + parser.add_argument("--head_size", default=cfg.head_size, type=int) + parser.add_argument( + "--generator_complexity", default=cfg.generator_complexity, type=int + ) # Discriminator - parser.add_argument('--d_step', default=cfg.d_step, type=int) - parser.add_argument('--d_epoch', default=cfg.d_epoch, type=int) - parser.add_argument('--adv_d_step', default=cfg.ADV_d_step, type=int) - parser.add_argument('--adv_d_epoch', default=cfg.ADV_d_epoch, type=int) - parser.add_argument('--dis_embed_dim', default=cfg.dis_embed_dim, type=int) - parser.add_argument('--dis_hidden_dim', default=cfg.dis_hidden_dim, type=int) - parser.add_argument('--num_rep', default=cfg.num_rep, type=int) + parser.add_argument("--d_step", default=cfg.d_step, type=int) + parser.add_argument("--d_epoch", default=cfg.d_epoch, type=int) + parser.add_argument("--adv_d_step", default=cfg.ADV_d_step, type=int) + parser.add_argument("--adv_d_epoch", default=cfg.ADV_d_epoch, type=int) + parser.add_argument("--dis_embed_dim", default=cfg.dis_embed_dim, type=int) + parser.add_argument("--dis_hidden_dim", default=cfg.dis_hidden_dim, type=int) + parser.add_argument("--num_rep", default=cfg.num_rep, type=int) + parser.add_argument( + "--discriminator_complexity", default=cfg.discriminator_complexity, type=int + ) + + # W2V embeddings + parser.add_argument( + "--w2v_embedding_size", default=cfg.w2v_embedding_size, type=int + ) + parser.add_argument("--w2v_window", default=cfg.w2v_window, type=int) + parser.add_argument("--w2v_min_count", default=cfg.w2v_min_count, type=int) + parser.add_argument("--w2v_workers", default=cfg.w2v_workers, type=int) + parser.add_argument("--w2v_samples_num", default=cfg.w2v_samples_num, type=int) # Metrics - 
parser.add_argument('--use_nll_oracle', default=cfg.use_nll_oracle, type=int) - parser.add_argument('--use_nll_gen', default=cfg.use_nll_gen, type=int) - parser.add_argument('--use_nll_div', default=cfg.use_nll_div, type=int) - parser.add_argument('--use_bleu', default=cfg.use_bleu, type=int) - parser.add_argument('--use_self_bleu', default=cfg.use_self_bleu, type=int) - parser.add_argument('--use_clas_acc', default=cfg.use_clas_acc, type=int) - parser.add_argument('--use_ppl', default=cfg.use_ppl, type=int) + parser.add_argument("--use_nll_oracle", default=cfg.use_nll_oracle, type=int) + parser.add_argument("--use_nll_gen", default=cfg.use_nll_gen, type=int) + parser.add_argument("--use_nll_div", default=cfg.use_nll_div, type=int) + parser.add_argument("--use_bleu", default=cfg.use_bleu, type=int) + parser.add_argument("--use_self_bleu", default=cfg.use_self_bleu, type=int) + parser.add_argument("--use_clas_acc", default=cfg.use_clas_acc, type=int) + parser.add_argument("--use_ppl", default=cfg.use_ppl, type=int) # Log - parser.add_argument('--log_file', default=cfg.log_filename, type=str) - parser.add_argument('--save_root', default=cfg.save_root, type=str) - parser.add_argument('--signal_file', default=cfg.signal_file, type=str) - parser.add_argument('--tips', default=cfg.tips, type=str) - + parser.add_argument("--log_file", default=cfg.log_filename, type=str) + parser.add_argument("--save_root", default=cfg.save_root, type=str) + parser.add_argument("--signal_file", default=cfg.signal_file, type=str) + parser.add_argument("--tips", default=cfg.tips, type=str) + + # Loss coefficients + parser.add_argument("--real_fake_coeff", default=1.0, type=float) + parser.add_argument("--labels_coeff", default=1.0, type=float) + parser.add_argument("--diversity_coeff", default=1.0, type=float) return parser # MAIN -if __name__ == '__main__': +if __name__ == "__main__": + # seed everything + # torch.manual_seed(0) + random.seed(0) + np.random.seed(0) + # Hyper Parameters parser = argparse.ArgumentParser() parser = program_config(parser) opt = parser.parse_args() if opt.if_real_data: - opt.max_seq_len, opt.vocab_size = text_process('dataset/' + opt.dataset + '.txt') - cfg.extend_vocab_size = len(load_test_dict(opt.dataset)[0]) # init classifier vocab_size + opt.max_seq_len, opt.vocab_size = text_process( + "dataset/" + opt.dataset + ".txt" + ) + cfg.extend_vocab_size = len( + load_test_dict(opt.dataset)[0] + ) # init classifier vocab_size cfg.init_param(opt) opt.save_root = cfg.save_root opt.train_data = cfg.train_data @@ -136,6 +175,7 @@ def program_config(parser): from instructor.real_data.catgan_instructor import CatGANInstructor from instructor.real_data.dgsan_instructor import DGSANInstructor from instructor.real_data.cot_instructor import CoTInstructor + from instructor.real_data.fixem_instructor import FixemGANInstructor else: from instructor.oracle_data.seqgan_instructor import SeqGANInstructor @@ -149,23 +189,46 @@ def program_config(parser): from instructor.oracle_data.catgan_instructor import CatGANInstructor from instructor.oracle_data.dgsan_instructor import DGSANInstructor from instructor.oracle_data.cot_instructor import CoTInstructor + from instructor.oracle_data.fixem_instructor import FixemGANInstructor instruction_dict = { - 'seqgan': SeqGANInstructor, - 'leakgan': LeakGANInstructor, - 'maligan': MaliGANInstructor, - 'jsdgan': JSDGANInstructor, - 'dpgan': DPGANInstructor, - 'relgan': RelGANInstructor, - 'sentigan': SentiGANInstructor, - 'evogan': EvoGANInstructor, - 'catgan': 
CatGANInstructor, - 'dgsan': DGSANInstructor, - 'cot': CoTInstructor, + "seqgan": SeqGANInstructor, + "leakgan": LeakGANInstructor, + "maligan": MaliGANInstructor, + "jsdgan": JSDGANInstructor, + "dpgan": DPGANInstructor, + "relgan": RelGANInstructor, + "sentigan": SentiGANInstructor, + "evogan": EvoGANInstructor, + "catgan": CatGANInstructor, + "dgsan": DGSANInstructor, + "cot": CoTInstructor, + "fixemgan": FixemGANInstructor, + "cat_fixemgan": FixemGANInstructor, } - inst = instruction_dict[cfg.run_model](opt) - if not cfg.if_test: + # Example sweep configuration + with open("sweep.yml") as sweep_yml: + sweep_configuration = yaml.safe_load(sweep_yml) + print("sweep_configuration", sweep_configuration) + + sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") + # sweep_id = "7g6po2bd" + print("sweep_id", sweep_id) + + def full_train_run(opt): + inst = instruction_dict[cfg.run_model](opt) inst._run() - else: - inst._test() + + def function_for_parameters_sweep(): + run = wandb.init() # Initialize a new wandb run + config = run.config # Get the config dictionary for the current run + print("config", config) + + # Update 'opt' with the hyperparameters from 'config' + for name, value in config.items(): + setattr(opt, name, value) + full_train_run(opt) + run.finish() # Make sure to finish the run + + wandb.agent(sweep_id=sweep_id, function=function_for_parameters_sweep) diff --git a/metrics/basic.py b/metrics/basic.py index b96377c0..6df50188 100644 --- a/metrics/basic.py +++ b/metrics/basic.py @@ -4,26 +4,41 @@ # @FileName : basic.py # @Time : Created at 2019-05-14 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. from abc import abstractmethod class Metrics: - def __init__(self, name='Metric'): + def __init__(self, name, weight, if_use): self.name = name + # represents effect on final score + # ex.: self-bleu has weight = -1 (less is better) + # bleu has weight = 1 (more is better) + # weights needed for combined metric evaluation + self.weight = weight + self.if_use = if_use + self.metric_value_with_current_state = None - def get_name(self): - return self.name + def get_score(self): + if not self.if_use: + return 0 - def set_name(self, name): - self.name = name + if self.metric_value_with_current_state is not None: + return self.metric_value_with_current_state + + self.metric_value_with_current_state = self.calculate_metric() + return self.metric_value_with_current_state + + def reset(self, *args, **kwargs): + self.metric_value_with_current_state = None + self._reset(*args, **kwargs) @abstractmethod - def get_score(self): + def calculate_metric(self): pass @abstractmethod - def reset(self): + def _reset(self): pass diff --git a/metrics/bleu.py b/metrics/bleu.py index 3cc7d5d4..a8e0e6fe 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -4,102 +4,106 @@ # @FileName : bleu.py # @Time : Created at 2019-05-31 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. from multiprocessing import Pool -import nltk import os import random + +import nltk from nltk.translate.bleu_score import SmoothingFunction +from tqdm import tqdm from metrics.basic import Metrics class BLEU(Metrics): - def __init__(self, name=None, test_text=None, real_text=None, gram=3, portion=1, if_use=False): - assert type(gram) == int or type(gram) == list, 'Gram format error!' - super(BLEU, self).__init__('%s-%s' % (name, gram)) + """ + Get BLEU scores. 
+ :param is_fast: Fast mode + :param given_gram: Calculate specific n-gram BLEU score + """ + + def __init__( + self, + name=None, + weight=1, + test_text=None, + real_text=None, + gram=3, + portion=1, + if_use=False, + ): + assert type(gram) == int or type(gram) == list, "Gram format error!" + super(BLEU, self).__init__("%s-%s" % (name, gram), weight, if_use) self.if_use = if_use self.test_text = test_text self.real_text = real_text - self.gram = [gram] if type(gram) == int else gram - self.sample_size = 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 - self.reference = None - self.is_first = True + self.gram = gram if type(gram) == int else gram + self.sample_size = ( + 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 + ) self.portion = portion # how many portions to use in the evaluation, default to use the whole test dataset - def get_score(self, is_fast=True, given_gram=None): - """ - Get BLEU scores. - :param is_fast: Fast mode - :param given_gram: Calculate specific n-gram BLEU score - """ - if not self.if_use: - return 0 - if self.is_first: - self.get_reference() - self.is_first = False - if is_fast: - return self.get_bleu_fast(given_gram) - return self.get_bleu(given_gram) - - def reset(self, test_text=None, real_text=None): - self.test_text = test_text if test_text else self.test_text - self.real_text = real_text if real_text else self.real_text + def _reset(self, test_text=None, real_text=None): + self.test_text = test_text if test_text is not None else self.test_text + self.real_text = real_text if real_text is not None else self.real_text def get_reference(self): reference = self.real_text.copy() - # randomly choose a portion of test data # In-place shuffle random.shuffle(reference) len_ref = len(reference) - reference = reference[:int(self.portion * len_ref)] - self.reference = reference + reference = reference[: int(self.portion * len_ref)] return reference def get_bleu(self, given_gram=None): - if given_gram is not None: # for single gram - bleu = list() - reference = self.get_reference() - weight = tuple((1. / given_gram for _ in range(given_gram))) - for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): - bleu.append(self.cal_bleu(reference, hypothesis, weight)) - return round(sum(bleu) / len(bleu), 3) - else: # for multiple gram - all_bleu = [] - for ngram in self.gram: - bleu = list() - reference = self.get_reference() - weight = tuple((1. 
/ ngram for _ in range(ngram))) - for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): - bleu.append(self.cal_bleu(reference, hypothesis, weight)) - all_bleu.append(round(sum(bleu) / len(bleu), 3)) - return all_bleu + if type(self.gram) == int: # for single gram + return self.get_blue_for_single_gram(self.gram) + # for multiple gram + all_bleu = [] + for ngram in self.gram: + all_bleu.append(self.get_blue_for_single_gram(ngram)) + return all_bleu + + def get_blue_for_single_gram(self, ngram): + bleu = list() + reference = self.get_reference() + weight = tuple((1.0 / ngram for _ in range(ngram))) + for idx, hypothesis in enumerate(self.test_text[: self.sample_size]): + bleu.append(self.cal_bleu(reference, hypothesis, weight)) + return round(sum(bleu) / len(bleu), 3) @staticmethod def cal_bleu(reference, hypothesis, weight): - return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight, - smoothing_function=SmoothingFunction().method1) + return nltk.translate.bleu_score.sentence_bleu( + reference, + hypothesis, + weight, + smoothing_function=SmoothingFunction().method1, + ) - def get_bleu_fast(self, given_gram=None): + def calculate_metric(self): + if type(self.gram) == int: # for single gram + return self.get_blue_for_single_gram(self.gram) + # for multiple gram reference = self.get_reference() - if given_gram is not None: # for single gram - return self.get_bleu_parallel(ngram=given_gram, reference=reference) - else: # for multiple gram - all_bleu = [] - for ngram in self.gram: - all_bleu.append(self.get_bleu_parallel(ngram=ngram, reference=reference)) - return all_bleu + all_bleu = [] + for ngram in self.gram: + all_bleu.append(self.get_bleu_parallel(ngram=ngram, reference=reference)) + return all_bleu def get_bleu_parallel(self, ngram, reference): - weight = tuple((1. / ngram for _ in range(ngram))) + weight = tuple((1.0 / ngram for _ in range(ngram))) pool = Pool(os.cpu_count()) result = list() - for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): - result.append(pool.apply_async(self.cal_bleu, args=(reference, hypothesis, weight))) + for idx, hypothesis in enumerate(self.test_text[: self.sample_size]): + result.append( + pool.apply_async(self.cal_bleu, args=(reference, hypothesis, weight)) + ) score = 0.0 cnt = 0 for i in result: diff --git a/metrics/clas_acc.py b/metrics/clas_acc.py index 01877055..bed1f594 100644 --- a/metrics/clas_acc.py +++ b/metrics/clas_acc.py @@ -4,7 +4,7 @@ # @FileName : clas_acc.py # @Time : Created at 2019/12/4 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -13,35 +13,28 @@ class ACC(Metrics): - def __init__(self, if_use=True, gpu=True): - super(ACC, self).__init__('clas_acc') + def __init__(self, weight, if_use=True, gpu=True): + super(ACC, self).__init__("clas_acc", weight, if_use) self.if_use = if_use self.model = None self.data_loader = None self.gpu = gpu - def get_score(self): - if not self.if_use: - return 0 - assert self.model and self.data_loader, 'Need to reset() before get_score()!' 
- - return self.cal_acc(self.model, self.data_loader) - - def reset(self, model=None, data_loader=None): + def _reset(self, model=None, data_loader=None): self.model = model self.data_loader = data_loader - def cal_acc(self, model, data_loader): + def calculate_metric(self, model, data_loader): total_acc = 0 total_num = 0 with torch.no_grad(): - for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + for i, data in enumerate(self.data_loader): + inp, target = data["input"], data["target"] if self.gpu: inp, target = inp.cuda(), target.cuda() - pred = model.forward(inp) + pred = self.model.forward(inp) total_acc += torch.sum((pred.argmax(dim=-1) == target)).item() total_num += inp.size(0) return round(total_acc / total_num, 4) diff --git a/metrics/dummy.py b/metrics/dummy.py new file mode 100644 index 00000000..5e1a07d6 --- /dev/null +++ b/metrics/dummy.py @@ -0,0 +1,14 @@ +from metrics.basic import Metrics + + +class Dummy(Metrics): + """ + Dummy score to make Overal score positive and easy to read + """ + + def __init__(self, name=None, weight=1, value=5, if_use=True): + super(Dummy, self).__init__("Dummy", weight, if_use) + self.value = 5 + + def calculate_metric(self): + return self.value diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py new file mode 100644 index 00000000..47761995 --- /dev/null +++ b/metrics/gpt_nll.py @@ -0,0 +1,61 @@ +from collections import Counter +from itertools import chain +import os +import random + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import GPT2LMHeadModel, GPT2Tokenizer +from tqdm import tqdm + +from metrics.basic import Metrics + + +class GPTNLL(Metrics): + def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=True): + super(GPTNLL, self).__init__("GPT2 as oracle", weight, if_use) + + self.if_use = if_use + self.test_text = test_text + + self.NLLloss = torch.nn.NLLLoss() + self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + self.model = GPT2LMHeadModel.from_pretrained("gpt2") + print("Calculating dataset NLL") + self.real_text_nll = ( + self.calcualte_NLL(random.sample(real_text.tokens, 500)) + if real_text + else None + ) + if self.real_text_nll: + print(f"dataset NLL based on GPT2 is {self.real_text_nll}") + print("GPT2 as oracle metric will be calculated relative to this value") + + def _reset(self, test_text=None, real_text=None): + self.test_text = test_text if test_text is not None else self.test_text + self.real_text_nll = ( + self.calcualte_NLL(real_text.tokens) + if real_text is not None + else self.real_text_nll + ) + + def calculate_metric(self): + """Get gpt2 NLL score difference with dataset NLL.""" + return self.calcualte_NLL(self.test_text) - self.real_text_nll + + def calcualte_NLL(self, messages): + if type(messages[0]) == list: # we received list of tokens + messages = [" ".join(msg) for msg in messages] + + all_logits = [] + for message in messages: + message = self.tokenizer.eos_token + message + self.tokenizer.eos_token + inputs = self.tokenizer(message, return_tensors="pt") + logits = self.model(**inputs)[0][0] + logits = F.log_softmax(logits, dim=1) + # calculating NLL loss on token appearing on it's position + all_logits.append( + self.NLLloss(logits[:-1], inputs["input_ids"][0][1:]).detach().numpy() + ) + return np.mean(all_logits) diff --git a/metrics/ioc.py b/metrics/ioc.py new file mode 100644 index 00000000..4cf3f81b --- /dev/null +++ b/metrics/ioc.py @@ -0,0 +1,41 @@ +from collections import Counter +from itertools 
import chain + +import nltk +import os +import random + +from metrics.basic import Metrics + + +class IOC(Metrics): + def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=True): + super(IOC, self).__init__("Index of Coincidence", weight, if_use) + + self.if_use = if_use + self.test_text = test_text + self.real_text_ioc = self.calculate_ioc(real_text.tokens) if real_text else None + if self.real_text_ioc: + print(f"Dataset Index of coincidence: {self.real_text_ioc}") + self.reference = None + self.is_first = True + + def _reset(self, test_text=None, real_text=None): + self.test_text = test_text if test_text is not None else self.test_text + self.real_text_ioc = ( + self.get_ioc(real_text.tokens) + if real_text is not None + else self.real_text_ioc + ) + + def calculate_metric(self): + return self.calculate_ioc(self.test_text) / self.real_text_ioc + + def calculate_ioc(self, tokenized_text): + """Index Of coincidence: probability of 2 random tokens in text to equal.""" + tokenized_text = [[str(token) for token in tokens] for tokens in tokenized_text] + tokens = list(chain(*tokenized_text)) + counts = Counter(tokens) + total = sum(ni * (ni - 1) for ni in counts.values()) + N = len(tokens) + return total / N / (N - 1) diff --git a/metrics/nll.py b/metrics/nll.py index a4d09983..78f19f4d 100644 --- a/metrics/nll.py +++ b/metrics/nll.py @@ -4,7 +4,7 @@ # @FileName : nll.py # @Time : Created at 2019-05-31 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -15,9 +15,8 @@ class NLL(Metrics): - def __init__(self, name, if_use=False, gpu=False): - super(NLL, self).__init__(name) - + def __init__(self, name, weight, if_use=False, gpu=False): + super(NLL, self).__init__(name, weight, if_use) self.if_use = if_use self.model = None self.data_loader = None @@ -26,21 +25,19 @@ def __init__(self, name, if_use=False, gpu=False): self.gpu = gpu self.criterion = nn.NLLLoss() - def get_score(self): + def calculate_metric(self): """note that NLL score need the updated model and data loader each time, use reset() before get_score()""" - if not self.if_use: - return 0 - assert self.model and self.data_loader, 'Need to reset() before get_score()!' 
- if self.leak_dis is not None: # For LeakGAN - return self.cal_nll_with_leak_dis(self.model, self.data_loader, self.leak_dis, self.gpu) - elif self.label_i is not None: # For category text generation - return self.cal_nll_with_label(self.model, self.data_loader, self.label_i, - self.criterion, self.gpu) - else: - return self.cal_nll(self.model, self.data_loader, self.criterion, self.gpu) + return self.cal_nll_with_leak_dis( + self.model, self.data_loader, self.leak_dis, self.gpu + ) + if self.label_i is not None: # For category text generation + return self.cal_nll_with_label( + self.model, self.data_loader, self.label_i, self.criterion, self.gpu + ) + return self.cal_nll(self.model, self.data_loader, self.criterion, self.gpu) - def reset(self, model=None, data_loader=None, label_i=None, leak_dis=None): + def _reset(self, model=None, data_loader=None, label_i=None, leak_dis=None): self.model = model self.data_loader = data_loader self.label_i = label_i @@ -52,7 +49,7 @@ def cal_nll(model, data_loader, criterion, gpu=cfg.CUDA): total_loss = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if gpu: inp, target = inp.cuda(), target.cuda() @@ -65,17 +62,17 @@ def cal_nll(model, data_loader, criterion, gpu=cfg.CUDA): @staticmethod def cal_nll_with_label(model, data_loader, label_i, criterion, gpu=cfg.CUDA): """NLL score for category text generation model.""" - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" total_loss = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] label = torch.LongTensor([label_i] * data_loader.batch_size) if gpu: inp, target, label = inp.cuda(), target.cuda(), label.cuda() hidden = model.init_hidden(data_loader.batch_size) - if model.name == 'oracle': + if model.name == "oracle": pred = model.forward(inp, hidden) else: pred = model.forward(inp, hidden, label) @@ -89,7 +86,7 @@ def cal_nll_with_leak_dis(model, data_loader, leak_dis, gpu=cfg.CUDA): total_loss = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if gpu: inp, target = inp.cuda(), target.cuda() diff --git a/metrics/ppl.py b/metrics/ppl.py index a1049a7d..0fc986af 100644 --- a/metrics/ppl.py +++ b/metrics/ppl.py @@ -4,7 +4,7 @@ # @FileName : ppl.py # @Time : Created at 2019/12/5 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import math import string @@ -17,11 +17,11 @@ from metrics.basic import Metrics from utils.text_process import write_tokens -kenlm_path = '/home/zhiwei/kenlm' # specify the kenlm path +kenlm_path = "/home/zhiwei/kenlm" # specify the kenlm path class PPL(Metrics): - def __init__(self, train_data, test_data, n_gram=5, if_use=False): + def __init__(self, train_data, test_data, weight, n_gram=5, if_use=False): """ Calculate Perplexity scores, including forward and reverse. 
PPL-F: PPL_forward, PPL-R: PPL_reverse @@ -30,7 +30,7 @@ def __init__(self, train_data, test_data, n_gram=5, if_use=False): @param n_gram: calculate with n-gram @param if_use: if use """ - super(PPL, self).__init__('[PPL-F, PPL-R]') + super(PPL, self).__init__("[PPL-F, PPL-R]", weight, if_use) self.n_gram = n_gram self.if_use = if_use @@ -39,30 +39,37 @@ def __init__(self, train_data, test_data, n_gram=5, if_use=False): self.train_data = train_data self.test_data = test_data - def get_score(self): - if not self.if_use: - return 0 - return self.cal_ppl() - - def reset(self, gen_tokens=None): + def _reset(self, gen_tokens=None): self.gen_tokens = gen_tokens - def cal_ppl(self): - save_path = os.path.join("/tmp", ''.join(random.choice( - string.ascii_uppercase + string.digits) for _ in range(6))) + def calculate_metric(self): + save_path = os.path.join( + "/tmp", + "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(6) + ), + ) output_path = save_path + ".arpa" write_tokens(save_path, self.gen_tokens) # save to file # forward ppl - for_lm = self.train_ngram_lm(kenlm_path=kenlm_path, data_path=cfg.test_data, - output_path=output_path, n_gram=self.n_gram) + for_lm = self.train_ngram_lm( + kenlm_path=kenlm_path, + data_path=cfg.test_data, + output_path=output_path, + n_gram=self.n_gram, + ) for_ppl = self.get_ppl(for_lm, self.gen_tokens) # reverse ppl try: - rev_lm = self.train_ngram_lm(kenlm_path=kenlm_path, data_path=save_path, - output_path=output_path, n_gram=self.n_gram) + rev_lm = self.train_ngram_lm( + kenlm_path=kenlm_path, + data_path=save_path, + output_path=output_path, + n_gram=self.n_gram, + ) rev_ppl = self.get_ppl(rev_lm, self.test_data.tokens) except: @@ -81,14 +88,21 @@ def train_ngram_lm(self, kenlm_path, data_path, output_path, n_gram): # create .arpa and .bin file of n-grams curdir = os.path.abspath(os.path.curdir) - cd_command = "cd " + os.path.join(kenlm_path, 'build') - command_1 = "bin/lmplz -o {} <{} >{} --discount_fallback &".format(str(n_gram), os.path.join(curdir, data_path), - output_path) - command_2 = "bin/build_binary -s {} {} &".format(output_path, output_path + ".bin") + cd_command = "cd " + os.path.join(kenlm_path, "build") + command_1 = "bin/lmplz -o {} <{} >{} --discount_fallback &".format( + str(n_gram), os.path.join(curdir, data_path), output_path + ) + command_2 = "bin/build_binary -s {} {} &".format( + output_path, output_path + ".bin" + ) while True: - subprocess.getstatusoutput(cd_command + " && " + command_1) # call without logging output - subprocess.getstatusoutput(cd_command + " && " + command_2) # call without logging output + subprocess.getstatusoutput( + cd_command + " && " + command_1 + ) # call without logging output + subprocess.getstatusoutput( + cd_command + " && " + command_2 + ) # call without logging output if os.path.exists(output_path + ".bin"): break @@ -104,8 +118,14 @@ def get_ppl(self, lm, tokens): total_nll = 0 total_wc = 0 for words in tokens: - nll = np.sum([-math.log(math.pow(10.0, score)) - for score, _, _ in lm.full_scores(' '.join(words), bos=True, eos=False)]) + nll = np.sum( + [ + -math.log(math.pow(10.0, score)) + for score, _, _ in lm.full_scores( + " ".join(words), bos=True, eos=False + ) + ] + ) total_wc += len(words) total_nll += nll ppl = np.exp(total_nll / total_wc) diff --git a/models/CatGAN_D.py b/models/CatGAN_D.py deleted file mode 100644 index d9f5310b..00000000 --- a/models/CatGAN_D.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author : William -# @Project : 
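As a sanity check on the perplexity arithmetic in `get_ppl` above (convert KenLM's log10 scores to nats, average per word, exponentiate), here is a minimal sketch with made-up per-token scores standing in for `lm.full_scores` output, so it runs without KenLM:

```python
import math
import numpy as np

# hypothetical per-token log10 probabilities for two sentences,
# standing in for the scores returned by lm.full_scores(...)
log10_scores = [
    [-1.2, -0.8, -2.0],   # sentence 1, 3 tokens
    [-0.5, -1.5],         # sentence 2, 2 tokens
]

total_nll, total_wc = 0.0, 0
for scores in log10_scores:
    # same conversion as get_ppl: -ln(10**score) per token
    total_nll += np.sum([-math.log(math.pow(10.0, s)) for s in scores])
    total_wc += len(scores)

ppl = np.exp(total_nll / total_wc)   # corpus-level perplexity, exp(mean NLL in nats)
print(round(ppl, 2))
```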
TextGAN-william -# @FileName : CatGAN_D.py -# @Time : Created at 2019-05-28 -# @Blog : http://zhiweil.ml/ -# @Description : -# Copyrights (C) 2018. All Rights Reserved. - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.discriminator import CNNDiscriminator, CNNClassifier - -dis_filter_sizes = [2, 3, 4, 5] -dis_num_filters = [300, 300, 300, 300] -clas_filter_sizes = [2, 3, 4, 5] -clas_num_filters = [200] - - -# Discriminator -class CatGAN_D(CNNDiscriminator): - def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(CatGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, - gpu, dropout) - - self.embed_dim = embed_dim - self.max_seq_len = max_seq_len - self.feature_dim = sum(dis_num_filters) - self.emb_dim_single = int(embed_dim / num_rep) - - self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) - - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single)) for (n, f) in - zip(dis_num_filters, dis_filter_sizes) - ]) - - self.highway = nn.Linear(self.feature_dim, self.feature_dim) - self.feature2out = nn.Linear(self.feature_dim, 100) # origin - self.out2logits = nn.Linear(100, 1) # origin - self.dropout = nn.Dropout(dropout) - - self.init_params() - - def forward(self, inp): - """ - Get logits of discriminator - :param inp: batch_size * seq_len * vocab_size - :return logits: [batch_size * num_rep] (1-D tensor) - """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim - - cons = [F.relu(conv(emb)) for conv in self.convs] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] - pools = [F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons] # [batch_size * num_filter * num_rep] - - pred = torch.cat(pools, 1) # batch_size * feature_dim * num_rep - pred = pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) # (batch_size * num_rep) * feature_dim - highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. - torch.sigmoid(highway)) * pred # highway, same dim - - pred = self.feature2out(self.dropout(pred)) - logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] - - return logits - - -# Classifier -class CatGAN_C(CNNClassifier): - def __init__(self, k_label, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(CatGAN_C, self).__init__(k_label, embed_dim, max_seq_len, num_rep, vocab_size, clas_filter_sizes, - clas_num_filters, padding_idx, gpu, dropout) - - # Use Glove - # self.embeddings.from_pretrained(build_embedding_matrix(cfg.dataset)) diff --git a/models/DGSAN_G.py b/models/DGSAN_G.py deleted file mode 100644 index 40646e0b..00000000 --- a/models/DGSAN_G.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author : William -# @Project : TextGAN-william -# @FileName : DGSAN_G.py -# @Time : Created at 2020/4/12 -# @Blog : http://zhiweil.ml/ -# @Description : -# Copyrights (C) 2018. All Rights Reserved. 
- -from models.generator import LSTMGenerator - - -class DGSAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(DGSAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'dgsan' diff --git a/models/EvoGAN_D.py b/models/EvoGAN_D.py deleted file mode 100644 index 6904ca7c..00000000 --- a/models/EvoGAN_D.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author : William -# @Project : TextGAN-william -# @FileName : EvoGAN_D.py -# @Time : Created at 2019-07-09 -# @Blog : http://zhiweil.ml/ -# @Description : -# Copyrights (C) 2018. All Rights Reserved. - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.discriminator import CNNDiscriminator - -dis_filter_sizes = [2, 3, 4, 5] -dis_num_filters = [300, 300, 300, 300] - - -class EvoGAN_D(CNNDiscriminator): - def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(EvoGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, - gpu, dropout) - - self.embed_dim = embed_dim - self.max_seq_len = max_seq_len - self.feature_dim = sum(dis_num_filters) - self.emb_dim_single = int(embed_dim / num_rep) - - self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) - - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single)) for (n, f) in - zip(dis_num_filters, dis_filter_sizes) - ]) - - self.highway = nn.Linear(self.feature_dim, self.feature_dim) - self.feature2out = nn.Linear(self.feature_dim, 100) # origin - self.out2logits = nn.Linear(100, 1) # origin - self.dropout = nn.Dropout(dropout) - - self.init_params() - - def forward(self, inp): - """ - Get logits of discriminator - :param inp: batch_size * seq_len * vocab_size - :return logits: [batch_size * num_rep] (1-D tensor) - """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim - - cons = [F.relu(conv(emb)) for conv in self.convs] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] - pools = [F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons] # [batch_size * num_filter * num_rep] - - pred = torch.cat(pools, 1) # batch_size * feature_dim * num_rep - pred = pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) # (batch_size * num_rep) * feature_dim - highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. - torch.sigmoid(highway)) * pred # highway, same dim - - pred = self.feature2out(self.dropout(pred)) - logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] - - return logits diff --git a/models/Oracle.py b/models/Oracle.py index 0ea46c8e..2b0a2a6f 100644 --- a/models/Oracle.py +++ b/models/Oracle.py @@ -4,18 +4,28 @@ # @FileName : Oracle.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
-from models.generator import LSTMGenerator +import torch +from models.generators.generator import LSTMGenerator -class Oracle(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(Oracle, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'oracle' +class Oracle(LSTMGenerator): + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(Oracle, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "oracle" # initialise oracle network with N(0,1) # otherwise variance of initialisation is very small => high NLL for loader sampled from the same model self.init_oracle() + + def init_oracle(self): + for param in self.parameters(): + if param.requires_grad: + torch.nn.init.normal_(param, mean=0, std=1) diff --git a/models/RelGAN_D.py b/models/RelGAN_D.py deleted file mode 100644 index b62334d6..00000000 --- a/models/RelGAN_D.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author : William -# @Project : TextGAN-william -# @FileName : RelGAN_D.py -# @Time : Created at 2019-04-25 -# @Blog : http://zhiweil.ml/ -# @Description : -# Copyrights (C) 2018. All Rights Reserved. - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.discriminator import CNNDiscriminator - -dis_filter_sizes = [2, 3, 4, 5] -dis_num_filters = [300, 300, 300, 300] - - -class RelGAN_D(CNNDiscriminator): - def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(RelGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, - gpu, dropout) - - self.embed_dim = embed_dim - self.max_seq_len = max_seq_len - self.feature_dim = sum(dis_num_filters) - self.emb_dim_single = int(embed_dim / num_rep) - - self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) - - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single)) for (n, f) in - zip(dis_num_filters, dis_filter_sizes) - ]) - - self.highway = nn.Linear(self.feature_dim, self.feature_dim) - self.feature2out = nn.Linear(self.feature_dim, 100) - self.out2logits = nn.Linear(100, 1) - self.dropout = nn.Dropout(dropout) - - self.init_params() - - def forward(self, inp): - """ - Get logits of discriminator - :param inp: batch_size * seq_len * vocab_size - :return logits: [batch_size * num_rep] (1-D tensor) - """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim - - cons = [F.relu(conv(emb)) for conv in self.convs] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] - pools = [F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons] # [batch_size * num_filter * num_rep] - pred = torch.cat(pools, 1) - pred = pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) # (batch_size * num_rep) * feature_dim - highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. 
- torch.sigmoid(highway)) * pred # highway - - pred = self.feature2out(self.dropout(pred)) - logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] - - return logits diff --git a/models/SentiGAN_D.py b/models/SentiGAN_D.py deleted file mode 100644 index 7daccfb7..00000000 --- a/models/SentiGAN_D.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author : William -# @Project : TextGAN-william -# @FileName : SentiGAN_D.py -# @Time : Created at 2019-07-26 -# @Blog : http://zhiweil.ml/ -# @Description : -# Copyrights (C) 2018. All Rights Reserved. - -import torch.nn as nn - -from models.discriminator import CNNDiscriminator, CNNClassifier - -dis_filter_sizes = [2, 3, 4, 5] -dis_num_filters = [200, 200, 200, 200] - -clas_filter_sizes = [2, 3, 4, 5] -clas_num_filters = [200] - - -class SentiGAN_D(CNNDiscriminator): - def __init__(self, k_label, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.2): - super(SentiGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, gpu, - dropout) - - self.feature2out = nn.Linear(self.feature_dim, k_label + 1) - - self.init_params() - - -# Classifier -class SentiGAN_C(CNNClassifier): - def __init__(self, k_label, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(SentiGAN_C, self).__init__(k_label, embed_dim, max_seq_len, num_rep, vocab_size, clas_filter_sizes, - clas_num_filters, padding_idx, gpu, dropout) - - # Use Glove - # self.embeddings.from_pretrained(build_embedding_matrix(cfg.dataset)) diff --git a/models/discriminators/CatGAN_D.py b/models/discriminators/CatGAN_D.py new file mode 100644 index 00000000..1e977c6c --- /dev/null +++ b/models/discriminators/CatGAN_D.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# @Author : William +# @Project : TextGAN-william +# @FileName : CatGAN_D.py +# @Time : Created at 2019-05-28 +# @Blog : http://zhiweil.ml/ +# @Description : +# Copyrights (C) 2018. All Rights Reserved. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.discriminators.discriminator import CNNDiscriminator, CNNClassifier + +dis_filter_sizes = [2, 3, 4, 5] +dis_num_filters = [300, 300, 300, 300] +clas_filter_sizes = [2, 3, 4, 5] +clas_num_filters = [200] + + +# Discriminator +class CatGAN_D(CNNDiscriminator): + def __init__( + self, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(CatGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) + + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + self.feature_dim = sum(dis_num_filters) + self.emb_dim_single = int(embed_dim / num_rep) + + self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) + + self.convs = nn.ModuleList( + [ + nn.Conv2d( + 1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single) + ) + for (n, f) in zip(dis_num_filters, dis_filter_sizes) + ] + ) + + self.highway = nn.Linear(self.feature_dim, self.feature_dim) + self.feature2out = nn.Linear(self.feature_dim, 100) # origin + self.out2logits = nn.Linear(100, 1) # origin + self.dropout = nn.Dropout(dropout) + + self.init_params() + + def forward(self, inp): + """ + Get logits of discriminator + :param inp: batch_size * seq_len * vocab_size + :return logits: [batch_size * num_rep] (1-D tensor) + """ + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim + + cons = [ + F.relu(conv(emb)) for conv in self.convs + ] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] + pools = [ + F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons + ] # [batch_size * num_filter * num_rep] + + pred = torch.cat(pools, 1) # batch_size * feature_dim * num_rep + pred = ( + pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) + ) # (batch_size * num_rep) * feature_dim + highway = self.highway(pred) + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway, same dim + + pred = self.feature2out(self.dropout(pred)) + logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] + + return logits + + +# Classifier +class CatGAN_C(CNNClassifier): + def __init__( + self, + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(CatGAN_C, self).__init__( + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + clas_filter_sizes, + clas_num_filters, + padding_idx, + gpu, + dropout, + ) + + # Use Glove + # self.embeddings.from_pretrained(build_embedding_matrix(cfg.dataset)) diff --git a/models/CoT_D.py b/models/discriminators/CoT_D.py similarity index 64% rename from models/CoT_D.py rename to models/discriminators/CoT_D.py index 325c3adc..0eab5e3b 100644 --- a/models/CoT_D.py +++ b/models/discriminators/CoT_D.py @@ -4,18 +4,22 @@ # @FileName : CoT_Medicator.py # @Time : Created at 2020/4/20 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
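The Conv2d layers in CatGAN_D (and the other RelGAN-style discriminators in this PR) use a kernel width and stride of `emb_dim_single`, so the embedding axis is sliced into `num_rep` independent chunks and every sentence ends up with `num_rep` logits. A toy shape walk-through, with sizes invented for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# toy sizes illustrating the num_rep trick used by CatGAN_D / RelGAN_D / EvoGAN_D
batch, seq_len, vocab, num_rep, embed_dim, n_filter, f = 4, 20, 50, 8, 64, 300, 3
emb_dim_single = embed_dim // num_rep  # 8

onehot_like = torch.rand(batch, seq_len, vocab)  # generator output: probabilities over the vocab
emb = nn.Linear(vocab, embed_dim, bias=False)(onehot_like).unsqueeze(1)  # batch x 1 x seq_len x embed_dim

conv = nn.Conv2d(1, n_filter, (f, emb_dim_single), stride=(1, emb_dim_single))
out = F.relu(conv(emb))                                # batch x n_filter x (seq_len-f+1) x num_rep
pool = F.max_pool2d(out, (out.size(2), 1)).squeeze(2)  # batch x n_filter x num_rep
print(out.shape, pool.shape)
# torch.Size([4, 300, 18, 8]) torch.Size([4, 300, 8])
```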
import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class Cot_D(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(Cot_D, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(Cot_D, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) def get_pred(self, input, target): pred = self.forward(input, self.init_hidden(input.size(0))) diff --git a/models/DPGAN_D.py b/models/discriminators/DPGAN_D.py similarity index 70% rename from models/DPGAN_D.py rename to models/discriminators/DPGAN_D.py index 4b8827af..77c9a527 100644 --- a/models/DPGAN_D.py +++ b/models/discriminators/DPGAN_D.py @@ -11,14 +11,18 @@ import torch.nn.functional as F import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator from utils.data_loader import GenDataIter class DPGAN_D(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(DPGAN_D, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'dpgan_d' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(DPGAN_D, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "dpgan_d" def getReward(self, samples): """ @@ -30,7 +34,9 @@ def getReward(self, samples): hidden = self.init_hidden(batch_size) pred = self.forward(inp, hidden) - word_reward = F.nll_loss(pred, target.view(-1), reduction='none').view(batch_size, -1) + word_reward = F.nll_loss(pred, target.view(-1), reduction="none").view( + batch_size, -1 + ) sentence_reward = torch.mean(word_reward, dim=-1, keepdim=True) return word_reward, sentence_reward diff --git a/models/discriminators/EvoGAN_D.py b/models/discriminators/EvoGAN_D.py new file mode 100644 index 00000000..6f64f8e6 --- /dev/null +++ b/models/discriminators/EvoGAN_D.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# @Author : William +# @Project : TextGAN-william +# @FileName : EvoGAN_D.py +# @Time : Created at 2019-07-09 +# @Blog : http://zhiweil.ml/ +# @Description : +# Copyrights (C) 2018. All Rights Reserved. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.discriminators.discriminator import CNNDiscriminator + +dis_filter_sizes = [2, 3, 4, 5] +dis_num_filters = [300, 300, 300, 300] + + +class EvoGAN_D(CNNDiscriminator): + def __init__( + self, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(EvoGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) + + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + self.feature_dim = sum(dis_num_filters) + self.emb_dim_single = int(embed_dim / num_rep) + + self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) + + self.convs = nn.ModuleList( + [ + nn.Conv2d( + 1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single) + ) + for (n, f) in zip(dis_num_filters, dis_filter_sizes) + ] + ) + + self.highway = nn.Linear(self.feature_dim, self.feature_dim) + self.feature2out = nn.Linear(self.feature_dim, 100) # origin + self.out2logits = nn.Linear(100, 1) # origin + self.dropout = nn.Dropout(dropout) + + self.init_params() + + def forward(self, inp): + """ + Get logits of discriminator + :param inp: batch_size * seq_len * vocab_size + :return logits: [batch_size * num_rep] (1-D tensor) + """ + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim + + cons = [ + F.relu(conv(emb)) for conv in self.convs + ] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] + pools = [ + F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons + ] # [batch_size * num_filter * num_rep] + + pred = torch.cat(pools, 1) # batch_size * feature_dim * num_rep + pred = ( + pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) + ) # (batch_size * num_rep) * feature_dim + highway = self.highway(pred) + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway, same dim + + pred = self.feature2out(self.dropout(pred)) + logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] + + return logits diff --git a/models/discriminators/FixemGAN_D.py b/models/discriminators/FixemGAN_D.py new file mode 100644 index 00000000..dd0a6f6f --- /dev/null +++ b/models/discriminators/FixemGAN_D.py @@ -0,0 +1,88 @@ +import torch.nn as nn + +import config as cfg +from utils.nn_helpers import ( + get_optimizer, + MyConvLayer, + MyTransformerEncoderLayer, + Flatten, + Dummy, +) +from models.discriminators.discriminator import CNNDiscriminator + + +class Discriminator(nn.Module): + def __init__(self, complexity): + super(Discriminator, self).__init__() + alpha = 0.2 + drop_rate = 0.0 + include_transformer = False + + self.main = nn.Sequential( + # 1 layer + MyConvLayer( + cfg.w2v_embedding_size, complexity, alpha=alpha, drop_rate=drop_rate + ), + # 2 layer + MyConvLayer( + complexity, + complexity, + alpha=alpha, + drop_rate=drop_rate, + ), + # 3 layer + MyConvLayer(complexity, complexity, alpha=alpha, drop_rate=drop_rate), + # MyLSTMLayer(complexity, complexity//2), + # 4 layer + MyConvLayer(complexity, complexity, alpha=alpha, drop_rate=drop_rate), + # 5 layer + MyTransformerEncoderLayer( + d_model=complexity, + n_layers=3, + ) + if include_transformer + else Dummy(), + # 6 layer + MyConvLayer(complexity, complexity, alpha=alpha, drop_rate=drop_rate), + # MyLSTMLayer(complexity, complexity//2), + # 7 layer + MyConvLayer( + complexity, + complexity, + stride=2, + padding=1, + alpha=alpha, + drop_rate=drop_rate, + ), 
+ MyConvLayer( + complexity, + complexity, + stride=2, + padding=1, + alpha=alpha, + drop_rate=drop_rate, + ), + # 8 layer + Flatten(), + nn.Linear(complexity * cfg.target_len // 2 // 2, complexity), + nn.LeakyReLU(alpha), + nn.Dropout(drop_rate), + ) + + self.real_fake = nn.Sequential( + nn.Linear(complexity, 1), + ) + self.labels = nn.Sequential( + nn.Linear(complexity, cfg.k_label), + ) + self.optimizer = get_optimizer(self.parameters()) + # maybe it will help! + # self.init_params() + + @property + def nb_of_parameters(self): + return number_of_parameters(self.parameters()) + + def forward(self, x): + x = self.main(x) + return self.real_fake(x), self.labels(x) diff --git a/models/LeakGAN_D.py b/models/discriminators/LeakGAN_D.py similarity index 62% rename from models/LeakGAN_D.py rename to models/discriminators/LeakGAN_D.py index afa398db..f5dc3840 100644 --- a/models/LeakGAN_D.py +++ b/models/discriminators/LeakGAN_D.py @@ -4,10 +4,10 @@ # @FileName : LeakGAN_D.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] @@ -15,5 +15,12 @@ class LeakGAN_D(CNNDiscriminator): def __init__(self, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.2): - super(LeakGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, - gpu, dropout) + super(LeakGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) diff --git a/models/MaliGAN_D.py b/models/discriminators/MaliGAN_D.py similarity index 62% rename from models/MaliGAN_D.py rename to models/discriminators/MaliGAN_D.py index a23eb8a0..37b3cfac 100644 --- a/models/MaliGAN_D.py +++ b/models/discriminators/MaliGAN_D.py @@ -4,10 +4,10 @@ # @FileName : MaliGAN_D.py # @Time : Created at 2019/10/17 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] @@ -15,5 +15,12 @@ class MaliGAN_D(CNNDiscriminator): def __init__(self, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(MaliGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, gpu, - dropout) + super(MaliGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) diff --git a/models/discriminators/RelGAN_D.py b/models/discriminators/RelGAN_D.py new file mode 100644 index 00000000..e3a109bd --- /dev/null +++ b/models/discriminators/RelGAN_D.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# @Author : William +# @Project : TextGAN-william +# @FileName : RelGAN_D.py +# @Time : Created at 2019-04-25 +# @Blog : http://zhiweil.ml/ +# @Description : +# Copyrights (C) 2018. All Rights Reserved. 
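The FixemGAN discriminator above returns two heads: a real/fake logit and per-class logits. Its training loop lives in the instructors, which are not part of this hunk, so the loss below is only a plausible sketch of how such a two-headed (auxiliary-classifier style) discriminator is typically trained; the tensor shapes, argument names, and label conventions are assumptions, not the repo's actual code:

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()   # adversarial real/fake head
ce = nn.CrossEntropyLoss()     # auxiliary class head

def disc_loss(disc, real_emb, fake_emb, real_labels, fake_labels):
    # real_emb / fake_emb: (batch, w2v_embedding_size, target_len) sentence embeddings
    rf_real, cls_real = disc(real_emb)
    rf_fake, cls_fake = disc(fake_emb)
    adv = bce(rf_real, torch.ones_like(rf_real)) + bce(rf_fake, torch.zeros_like(rf_fake))
    aux = ce(cls_real, real_labels) + ce(cls_fake, fake_labels)
    return adv + aux
```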
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.discriminators.discriminator import CNNDiscriminator + +dis_filter_sizes = [2, 3, 4, 5] +dis_num_filters = [300, 300, 300, 300] + + +class RelGAN_D(CNNDiscriminator): + def __init__( + self, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(RelGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) + + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + self.feature_dim = sum(dis_num_filters) + self.emb_dim_single = int(embed_dim / num_rep) + + self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) + + self.convs = nn.ModuleList( + [ + nn.Conv2d( + 1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single) + ) + for (n, f) in zip(dis_num_filters, dis_filter_sizes) + ] + ) + + self.highway = nn.Linear(self.feature_dim, self.feature_dim) + self.feature2out = nn.Linear(self.feature_dim, 100) + self.out2logits = nn.Linear(100, 1) + self.dropout = nn.Dropout(dropout) + + self.init_params() + + def forward(self, inp): + """ + Get logits of discriminator + :param inp: batch_size * seq_len * vocab_size + :return logits: [batch_size * num_rep] (1-D tensor) + """ + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim + + cons = [ + F.relu(conv(emb)) for conv in self.convs + ] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] + pools = [ + F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons + ] # [batch_size * num_filter * num_rep] + pred = torch.cat(pools, 1) + pred = ( + pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) + ) # (batch_size * num_rep) * feature_dim + highway = self.highway(pred) + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway + + pred = self.feature2out(self.dropout(pred)) + logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] + + return logits diff --git a/models/discriminators/SentiGAN_D.py b/models/discriminators/SentiGAN_D.py new file mode 100644 index 00000000..f54ae108 --- /dev/null +++ b/models/discriminators/SentiGAN_D.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# @Author : William +# @Project : TextGAN-william +# @FileName : SentiGAN_D.py +# @Time : Created at 2019-07-26 +# @Blog : http://zhiweil.ml/ +# @Description : +# Copyrights (C) 2018. All Rights Reserved. 
+ +import torch.nn as nn + +from models.discriminators.discriminator import CNNDiscriminator, CNNClassifier + +dis_filter_sizes = [2, 3, 4, 5] +dis_num_filters = [200, 200, 200, 200] + +clas_filter_sizes = [2, 3, 4, 5] +clas_num_filters = [200] + + +class SentiGAN_D(CNNDiscriminator): + def __init__( + self, k_label, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.2 + ): + super(SentiGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) + + self.feature2out = nn.Linear(self.feature_dim, k_label + 1) + + self.init_params() + + +# Classifier +class SentiGAN_C(CNNClassifier): + def __init__( + self, + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(SentiGAN_C, self).__init__( + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + clas_filter_sizes, + clas_num_filters, + padding_idx, + gpu, + dropout, + ) + + # Use Glove + # self.embeddings.from_pretrained(build_embedding_matrix(cfg.dataset)) diff --git a/models/SeqGAN_D.py b/models/discriminators/SeqGAN_D.py similarity index 62% rename from models/SeqGAN_D.py rename to models/discriminators/SeqGAN_D.py index 9e63b823..25ccad90 100644 --- a/models/SeqGAN_D.py +++ b/models/discriminators/SeqGAN_D.py @@ -4,10 +4,10 @@ # @FileName : SeqGAN_D.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] @@ -15,5 +15,12 @@ class SeqGAN_D(CNNDiscriminator): def __init__(self, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(SeqGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, gpu, - dropout) + super(SeqGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) diff --git a/models/discriminator.py b/models/discriminators/discriminator.py similarity index 65% rename from models/discriminator.py rename to models/discriminators/discriminator.py index 77738b7d..b3331977 100644 --- a/models/discriminator.py +++ b/models/discriminators/discriminator.py @@ -18,8 +18,16 @@ class CNNDiscriminator(nn.Module): - def __init__(self, embed_dim, vocab_size, filter_sizes, num_filters, padding_idx, gpu=False, - dropout=0.2): + def __init__( + self, + embed_dim, + vocab_size, + filter_sizes, + num_filters, + padding_idx, + gpu=False, + dropout=0.2, + ): super(CNNDiscriminator, self).__init__() self.embedding_dim = embed_dim self.vocab_size = vocab_size @@ -28,9 +36,12 @@ def __init__(self, embed_dim, vocab_size, filter_sizes, num_filters, padding_idx self.gpu = gpu self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx) - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, embed_dim)) for (n, f) in zip(num_filters, filter_sizes) - ]) + self.convs = nn.ModuleList( + [ + nn.Conv2d(1, n, (f, embed_dim)) + for (n, f) in zip(num_filters, filter_sizes) + ] + ) self.highway = nn.Linear(self.feature_dim, self.feature_dim) self.feature2out = nn.Linear(self.feature_dim, 2) self.dropout = nn.Dropout(dropout) @@ -54,12 +65,21 @@ def get_feature(self, inp): :param inp: batch_size * max_seq_len :return: batch_size * feature_dim """ - emb = 
self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim - convs = [F.relu(conv(emb)).squeeze(3) for conv in self.convs] # [batch_size * num_filter * length] - pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs] # [batch_size * num_filter] + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim + convs = [ + F.relu(conv(emb)).squeeze(3) for conv in self.convs + ] # [batch_size * num_filter * length] + pools = [ + F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs + ] # [batch_size * num_filter] pred = torch.cat(pools, 1) # tensor: batch_size * feature_dim highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. - torch.sigmoid(highway)) * pred # highway + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway return pred @@ -67,18 +87,26 @@ def init_params(self): for param in self.parameters(): if param.requires_grad and len(param.shape) > 0: stddev = 1 / math.sqrt(param.shape[0]) - if cfg.dis_init == 'uniform': + if cfg.dis_init == "uniform": torch.nn.init.uniform_(param, a=-0.05, b=0.05) - elif cfg.dis_init == 'normal': + elif cfg.dis_init == "normal": torch.nn.init.normal_(param, std=stddev) - elif cfg.dis_init == 'truncated_normal': + elif cfg.dis_init == "truncated_normal": truncated_normal_(param, std=stddev) class GRUDiscriminator(nn.Module): - - def __init__(self, embedding_dim, vocab_size, hidden_dim, feature_dim, max_seq_len, padding_idx, - gpu=False, dropout=0.2): + def __init__( + self, + embedding_dim, + vocab_size, + hidden_dim, + feature_dim, + max_seq_len, + padding_idx, + gpu=False, + dropout=0.2, + ): super(GRUDiscriminator, self).__init__() self.hidden_dim = hidden_dim self.embedding_dim = embedding_dim @@ -86,8 +114,12 @@ def __init__(self, embedding_dim, vocab_size, hidden_dim, feature_dim, max_seq_l self.padding_idx = padding_idx self.gpu = gpu - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) - self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=dropout) + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) + self.gru = nn.GRU( + embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=dropout + ) self.gru2hidden = nn.Linear(2 * 2 * hidden_dim, feature_dim) self.feature2out = nn.Linear(feature_dim, 2) self.dropout = nn.Dropout(dropout) @@ -125,7 +157,9 @@ def get_feature(self, inp): emb = emb.permute(1, 0, 2) # seq_len * batch_size * embedding_dim _, hidden = self.gru(emb, hidden) # 4 * batch_size * hidden_dim hidden = hidden.permute(1, 0, 2).contiguous() # batch_size * 4 * hidden_dim - out = self.gru2hidden(hidden.view(-1, 4 * self.hidden_dim)) # batch_size * 4 * hidden_dim + out = self.gru2hidden( + hidden.view(-1, 4 * self.hidden_dim) + ) # batch_size * 4 * hidden_dim feature = torch.tanh(out) # batch_size * feature_dim return feature @@ -134,20 +168,32 @@ def init_params(self): for param in self.parameters(): if param.requires_grad and len(param.shape) > 0: stddev = 1 / math.sqrt(param.shape[0]) - if cfg.dis_init == 'uniform': + if cfg.dis_init == "uniform": torch.nn.init.uniform_(param, a=-0.05, b=0.05) - elif cfg.dis_init == 'normal': + elif cfg.dis_init == "normal": torch.nn.init.normal_(param, std=stddev) - elif cfg.dis_init == 'truncated_normal': + elif cfg.dis_init == "truncated_normal": truncated_normal_(param, std=stddev) # Classifier class CNNClassifier(CNNDiscriminator): - 
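The gated combination that keeps reappearing in these discriminators, `sigmoid(h) * relu(h) + (1 - sigmoid(h)) * x`, is a highway connection. Factored out as a standalone module (a sketch only; no such module exists in the repo) it is just a few lines:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Highway(nn.Module):
    """y = g * relu(h) + (1 - g) * x, with h = Wx and gate g = sigmoid(h)."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        h = self.linear(x)
        gate = torch.sigmoid(h)
        return gate * F.relu(h) + (1.0 - gate) * x

x = torch.randn(4, 1200)       # e.g. feature_dim = sum(dis_num_filters)
print(Highway(1200)(x).shape)  # torch.Size([4, 1200])
```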
def __init__(self, k_label, embed_dim, max_seq_len, num_rep, vocab_size, filter_sizes, num_filters, padding_idx, - gpu=False, dropout=0.25): - super(CNNClassifier, self).__init__(embed_dim, vocab_size, filter_sizes, num_filters, padding_idx, - gpu, dropout) + def __init__( + self, + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + filter_sizes, + num_filters, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(CNNClassifier, self).__init__( + embed_dim, vocab_size, filter_sizes, num_filters, padding_idx, gpu, dropout + ) self.k_label = k_label self.embed_dim = embed_dim @@ -157,9 +203,12 @@ def __init__(self, k_label, embed_dim, max_seq_len, num_rep, vocab_size, filter_ self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx) - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, embed_dim)) for (n, f) in zip(num_filters, filter_sizes) - ]) # vanilla + self.convs = nn.ModuleList( + [ + nn.Conv2d(1, n, (f, embed_dim)) + for (n, f) in zip(num_filters, filter_sizes) + ] + ) # vanilla # self.convs = nn.ModuleList([ # nn.Conv2d(1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single)) for (n, f) in # zip(num_filters, filter_sizes) @@ -179,11 +228,17 @@ def forward(self, inp): :param inp: batch_size * seq_len * vocab_size :return logits: [batch_size * num_rep] (1-D tensor) """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim # vanilla - convs = [F.relu(conv(emb)).squeeze(3) for conv in self.convs] # [batch_size * num_filter * length] - pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs] # [batch_size * num_filter] + convs = [ + F.relu(conv(emb)).squeeze(3) for conv in self.convs + ] # [batch_size * num_filter * length] + pools = [ + F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs + ] # [batch_size * num_filter] # RelGAN # cons = [F.relu(conv(emb)) for conv in self.convs] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] # pools = [F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons] # [batch_size * num_filter * num_rep] @@ -191,10 +246,15 @@ def forward(self, inp): pred = torch.cat(pools, 1) # batch_size * feature_dim # pred = pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) # RelGAN highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. - torch.sigmoid(highway)) * pred # highway, same dim + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway, same dim pred = self.feature2out(self.dropout(pred)) - logits = self.out2logits(self.dropout(pred)).squeeze(1) # vanilla, batch_size * k_label + logits = self.out2logits(self.dropout(pred)).squeeze( + 1 + ) # vanilla, batch_size * k_label # logits = self.out2logits(self.dropout(pred.view(inp.size(0), -1))).squeeze(1) # RelGAN, batch_size * k_label return logits diff --git a/models/CatGAN_G.py b/models/generators/CatGAN_G.py similarity index 66% rename from models/CatGAN_G.py rename to models/generators/CatGAN_G.py index 60493e81..77743e5f 100644 --- a/models/CatGAN_G.py +++ b/models/generators/CatGAN_G.py @@ -4,7 +4,7 @@ # @FileName : CatGAN_G.py # @Time : Created at 2019-07-18 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch @@ -12,40 +12,63 @@ import torch.nn.functional as F import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator from models.relational_rnn_general import RelationalMemory class CatGAN_G(LSTMGenerator): - def __init__(self, k_label, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, - padding_idx, - gpu=False): - super(CatGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'catgan' + def __init__( + self, + k_label, + mem_slots, + num_heads, + head_size, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + gpu=False, + ): + super(CatGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "catgan" self.k_label = k_label - self.temperature = nn.Parameter(torch.Tensor([1.0]), requires_grad=False) # init value is 1.0 + self.temperature = nn.Parameter( + torch.Tensor([1.0]), requires_grad=False + ) # init value is 1.0 # Category matrix # self.cat_mat = nn.Parameter(torch.rand(self.k_label, embedding_dim), requires_grad=True) self.cat_mat = nn.Parameter(torch.eye(k_label), requires_grad=False) - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) - if cfg.model_type == 'LSTM': + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) + if cfg.model_type == "LSTM": # LSTM self.hidden_dim = hidden_dim - self.lstm = nn.LSTM(k_label + embedding_dim, self.hidden_dim, batch_first=True) + self.lstm = nn.LSTM( + k_label + embedding_dim, self.hidden_dim, batch_first=True + ) self.lstm2out = nn.Linear(self.hidden_dim, vocab_size) else: # RMC self.hidden_dim = mem_slots * num_heads * head_size - self.lstm = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=k_label + embedding_dim, - num_heads=num_heads, return_all_outputs=True) + self.lstm = RelationalMemory( + mem_slots=mem_slots, + head_size=head_size, + input_size=k_label + embedding_dim, + num_heads=num_heads, + return_all_outputs=True, + ) self.lstm2out = nn.Linear(self.hidden_dim, vocab_size) self.init_params() def init_hidden(self, batch_size=cfg.batch_size): - if cfg.model_type == 'LSTM': + if cfg.model_type == "LSTM": h = torch.zeros(1, batch_size, self.hidden_dim) c = torch.zeros(1, batch_size, self.hidden_dim) @@ -67,17 +90,25 @@ def forward(self, inp, hidden, label=None, need_hidden=False): :param hidden: memory size :param need_hidden: if return hidden, use for sampling """ - assert type(label) == torch.Tensor, 'missing label' + assert type(label) == torch.Tensor, "missing label" emb = self.embeddings(inp) # batch_size * len * embedding_dim # cat category vector label_onehot = F.one_hot(label, self.k_label).float() # batch_size * k_label - label_onehot_ex = label_onehot.unsqueeze(1).expand(-1, inp.size(1), -1) # batch_size * len * k_label - label_vec = torch.bmm(label_onehot_ex, self.cat_mat.expand(inp.size(0), -1, -1)) # batch_size * len * embed_dim - emb = torch.cat((emb, label_vec), dim=-1) # batch_sie * len * (k_label + embed_dim) + label_onehot_ex = label_onehot.unsqueeze(1).expand( + -1, inp.size(1), -1 + ) # batch_size * len * k_label + label_vec = torch.bmm( + label_onehot_ex, self.cat_mat.expand(inp.size(0), -1, -1) + ) # batch_size * len * embed_dim + emb = torch.cat( + (emb, label_vec), dim=-1 + ) # batch_sie * len * (k_label + embed_dim) out, hidden = self.lstm(emb, hidden) # out: batch_size * seq_len * hidden_dim 
- out = out.contiguous().view(-1, self.hidden_dim) # out: (batch_size * len) * hidden_dim + out = out.contiguous().view( + -1, self.hidden_dim + ) # out: (batch_size * len) * hidden_dim out = self.lstm2out(out) # batch_size * seq_len * vocab_size # out = self.temperature * out # temperature pred = self.softmax(out) @@ -98,14 +129,20 @@ def step(self, inp, hidden, label=None): - hidden: next hidden - next_token: [batch_size], next sentence token """ - assert type(label) == torch.Tensor, 'missing label' + assert type(label) == torch.Tensor, "missing label" emb = self.embeddings(inp).unsqueeze(1) # cat category vector label_onehot = F.one_hot(label, self.k_label).float() # batch_size * k_label - label_onehot_ex = label_onehot.unsqueeze(1).expand(-1, 1, -1) # batch_size * 1 * k_label - label_vec = torch.bmm(label_onehot_ex, self.cat_mat.expand(inp.size(0), -1, -1)) # batch_size * 1 * embed_dim - emb = torch.cat((emb, label_vec), dim=-1) # batch_sie * len * (k_label + embed_dim) + label_onehot_ex = label_onehot.unsqueeze(1).expand( + -1, 1, -1 + ) # batch_size * 1 * k_label + label_vec = torch.bmm( + label_onehot_ex, self.cat_mat.expand(inp.size(0), -1, -1) + ) # batch_size * 1 * embed_dim + emb = torch.cat( + (emb, label_vec), dim=-1 + ) # batch_sie * len * (k_label + embed_dim) out, hidden = self.lstm(emb, hidden) gumbel_t = self.add_gumbel(self.lstm2out(out.squeeze(1))) @@ -115,8 +152,14 @@ def step(self, inp, hidden, label=None): return pred, hidden, next_token - def sample(self, num_samples, batch_size, one_hot=False, label_i=None, - start_letter=cfg.start_letter): + def sample( + self, + num_samples, + batch_size, + one_hot=False, + label_i=None, + start_letter=cfg.start_letter, + ): """ Sample from RelGAN Generator - one_hot: if return pred of RelGAN, used for adversarial training @@ -126,7 +169,7 @@ def sample(self, num_samples, batch_size, one_hot=False, label_i=None, - samples: all samples """ global all_preds - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1 samples = torch.zeros(num_batch * batch_size, self.max_seq_len).long() if one_hot: @@ -144,7 +187,7 @@ def sample(self, num_samples, batch_size, one_hot=False, label_i=None, for i in range(self.max_seq_len): pred, hidden, next_token = self.step(inp, hidden, label_t) - samples[b * batch_size:(b + 1) * batch_size, i] = next_token + samples[b * batch_size : (b + 1) * batch_size, i] = next_token if one_hot: all_preds[:, i] = pred inp = next_token diff --git a/models/CoT_G.py b/models/generators/CoT_G.py similarity index 58% rename from models/CoT_G.py rename to models/generators/CoT_G.py index 357ac73c..ab294c3c 100644 --- a/models/CoT_G.py +++ b/models/generators/CoT_G.py @@ -4,19 +4,23 @@ # @FileName : CoT_G.py # @Time : Created at 2020/4/20 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
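CatGAN_G conditions every step on the sentence label by projecting a one-hot label through `cat_mat` and concatenating the result to the token embedding, which is why its LSTM input size is `k_label + embedding_dim`. A stripped-down shape walk-through with toy sizes, outside the model:

```python
import torch
import torch.nn.functional as F

batch, seq_len, k_label, embed_dim = 4, 16, 2, 32
emb = torch.randn(batch, seq_len, embed_dim)   # token embeddings
label = torch.randint(0, k_label, (batch,))    # one class id per sentence
cat_mat = torch.eye(k_label)                   # fixed identity matrix, as in CatGAN_G

label_onehot = F.one_hot(label, k_label).float()                       # batch x k_label
label_onehot_ex = label_onehot.unsqueeze(1).expand(-1, seq_len, -1)    # batch x seq_len x k_label
label_vec = torch.bmm(label_onehot_ex, cat_mat.expand(batch, -1, -1))  # batch x seq_len x k_label
inp = torch.cat((emb, label_vec), dim=-1)      # batch x seq_len x (embed_dim + k_label)
print(inp.shape)  # torch.Size([4, 16, 34])
```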
import torch -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class CoT_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(CoT_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'cot' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(CoT_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "cot" def get_loss(self, input, rewards): """ @@ -25,7 +29,9 @@ def get_loss(self, input, rewards): @param rewards: rewards form mediator, (batch size * seq_len) * vocab_size @return: """ - log_pred = self.forward(input, self.init_hidden(input.size(0))) # (batch_size * seq_len) * vocab_size + log_pred = self.forward( + input, self.init_hidden(input.size(0)) + ) # (batch_size * seq_len) * vocab_size g_pred = torch.exp(log_pred) loss = -torch.sum(g_pred * (rewards - log_pred)) / rewards.size(0) return loss diff --git a/models/generators/DGSAN_G.py b/models/generators/DGSAN_G.py new file mode 100644 index 00000000..04491282 --- /dev/null +++ b/models/generators/DGSAN_G.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# @Author : William +# @Project : TextGAN-william +# @FileName : DGSAN_G.py +# @Time : Created at 2020/4/12 +# @Blog : http://zhiweil.ml/ +# @Description : +# Copyrights (C) 2018. All Rights Reserved. + +from models.generators.generator import LSTMGenerator + + +class DGSAN_G(LSTMGenerator): + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(DGSAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "dgsan" diff --git a/models/DPGAN_G.py b/models/generators/DPGAN_G.py similarity index 72% rename from models/DPGAN_G.py rename to models/generators/DPGAN_G.py index 121b62a3..c32d9dca 100644 --- a/models/DPGAN_G.py +++ b/models/generators/DPGAN_G.py @@ -10,13 +10,17 @@ import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class DPGAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(DPGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'dpgan_g' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(DPGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "dpgan_g" def sample_teacher_forcing(self, inp): """ @@ -32,7 +36,9 @@ def sample_teacher_forcing(self, inp): pred = self.forward(inp, hidden) samples = torch.argmax(pred, dim=-1).view(batch_size, -1) - log_prob = F.nll_loss(pred, samples.view(-1), reduction='none').view(batch_size, -1) + log_prob = F.nll_loss(pred, samples.view(-1), reduction="none").view( + batch_size, -1 + ) # samples = torch.multinomial(torch.exp(log_prob), 1) return samples, log_prob diff --git a/models/EvoGAN_G.py b/models/generators/EvoGAN_G.py similarity index 77% rename from models/EvoGAN_G.py rename to models/generators/EvoGAN_G.py index b6ad52ea..b195a504 100644 --- a/models/EvoGAN_G.py +++ b/models/generators/EvoGAN_G.py @@ -4,7 +4,7 @@ # @FileName : EvoGAN_G.py # @Time : Created at 2019-07-09 # @Blog : http://zhiweil.ml/ -# @Description : +# 
@Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,20 +12,36 @@ import torch.nn.functional as F import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator from models.relational_rnn_general import RelationalMemory class EvoGAN_G(LSTMGenerator): - def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, - gpu=False): - super(EvoGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'evogan' - - self.temperature = nn.Parameter(torch.Tensor([1.0]), requires_grad=False) # init value is 1.0 - - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) - if cfg.model_type == 'LSTM': + def __init__( + self, + mem_slots, + num_heads, + head_size, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + gpu=False, + ): + super(EvoGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "evogan" + + self.temperature = nn.Parameter( + torch.Tensor([1.0]), requires_grad=False + ) # init value is 1.0 + + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) + if cfg.model_type == "LSTM": # LSTM self.hidden_dim = hidden_dim self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, batch_first=True) @@ -33,14 +49,19 @@ def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, v else: # RMC self.hidden_dim = mem_slots * num_heads * head_size - self.lstm = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=embedding_dim, - num_heads=num_heads, return_all_outputs=True) + self.lstm = RelationalMemory( + mem_slots=mem_slots, + head_size=head_size, + input_size=embedding_dim, + num_heads=num_heads, + return_all_outputs=True, + ) self.lstm2out = nn.Linear(self.hidden_dim, vocab_size) self.init_params() def init_hidden(self, batch_size=cfg.batch_size): - if cfg.model_type == 'LSTM': + if cfg.model_type == "LSTM": h = torch.zeros(1, batch_size, self.hidden_dim) c = torch.zeros(1, batch_size, self.hidden_dim) @@ -79,7 +100,9 @@ def step(self, inp, hidden): return pred, hidden, next_token, next_token_onehot, next_o - def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter): + def sample( + self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter + ): """ Sample from RelGAN Generator - one_hot: if return pred of RelGAN, used for adversarial training @@ -103,7 +126,7 @@ def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_ for i in range(self.max_seq_len): pred, hidden, next_token, _, _ = self.step(inp, hidden) - samples[b * batch_size:(b + 1) * batch_size, i] = next_token + samples[b * batch_size : (b + 1) * batch_size, i] = next_token if one_hot: all_preds[:, i] = pred inp = next_token diff --git a/models/generators/FixemGAN_G.py b/models/generators/FixemGAN_G.py new file mode 100644 index 00000000..2c72f754 --- /dev/null +++ b/models/generators/FixemGAN_G.py @@ -0,0 +1,168 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +from utils.nn_helpers import ( + get_optimizer, + create_noise, + Concatenate, + Reshape, + MyConvLayerNorm, + MyConvTransposeLayer, + PositionalEncoding, + MyLSTMLayerNorm, + Dummy, +) + +import config as cfg +from models.generators.generator import LSTMGenerator + + +class Generator(LSTMGenerator): + def __init__(self, 
complexity, noise_size, w2v, w2v_embedding_size): + super(Generator, self).__init__( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.target_len, + cfg.padding_idx, + ) + alpha = 0.2 + added_dim_pe = 0 + include_batch_norm = True + include_transformer = False + include_lstm = True + self.noise_size = noise_size + self.w2v = w2v + self.embedding_size = w2v_embedding_size + + self.main = nn.Sequential( + # 1 layer + Concatenate(1), + nn.Linear( + cfg.noise_size + cfg.k_label, cfg.target_len // 2 // 2 * complexity + ), + nn.BatchNorm1d(cfg.target_len // 2 // 2 * complexity), + nn.LeakyReLU(alpha), + Reshape(complexity, cfg.target_len // 2 // 2), + # 2 layer + MyConvLayerNorm(complexity, complexity, alpha=alpha), + # 3 layer + MyConvTransposeLayer( + complexity, + complexity, + stride=2, + output_padding=1, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # 4 layer + MyConvTransposeLayer( + complexity, + complexity, + stride=2, + output_padding=1, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # adding/concatenating positional encoding + PositionalEncoding( + dim_pe=complexity, + max_len=cfg.target_len, + concatenate_pe=False, + ), + # 5 layer + MyConvLayerNorm( + complexity + added_dim_pe, + complexity, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # adding/concatenating positional encoding + PositionalEncoding( + dim_pe=complexity, + max_len=cfg.target_len, + concatenate_pe=False, + ), + # 6 layer + MyTransformerEncoderLayer( + d_model=complexity + added_dim_pe, + n_layers=3, + ) + if include_transformer + else Dummy(), + # 7 layer + MyConvTransposeLayer( + complexity + added_dim_pe, + complexity, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # 8 layer + MyLSTMLayerNorm( + complexity, + complexity // 2, + ) + if include_lstm + else Dummy(), + # 9 layer + MyConvTransposeLayer( + complexity, + complexity, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # 10 layer + MyLSTMLayerNorm( + complexity, + complexity // 2, + ) + if include_lstm + else Dummy(), + # 11 layer + MyConvTransposeLayer( + complexity, + complexity, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # 12 layer + nn.Conv1d( + complexity, + cfg.w2v_embedding_size, + kernel_size=1, + stride=1, + padding=0, + ), + ) + self.optimizer = get_optimizer(self.parameters()) + + def forward(self, noise, target_labels): + target_labels = torch.nn.functional.one_hot( + target_labels, num_classes=cfg.k_label + ) + return self.main([noise, target_labels]) + + def sample( + self, num_samples, batch_size, label_i="random", start_letter=cfg.start_letter + ): + noise = create_noise(num_samples, self.noise_size, cfg.k_label) + if label_i != "random": + noise = (noise[0], torch.tensor(label_i).expand_as(noise[1])) + + if cfg.CUDA: + noise = tuple(tt.cuda() for tt in noise) + fakes = self.forward(*noise) + fakes = fakes.detach().cpu().numpy() + assert len(fakes.shape) == 3 + return [self.recover_sentence(fake) for fake in fakes] + + def recover_sentence(self, fake): + fake = fake.T + tokens = [] + for token_vector in fake: + token = self.w2v.wv.most_similar([token_vector])[0][0] + if token == cfg.padding_token: + continue + tokens.append(token) + return " ".join(tokens).strip() diff --git a/models/JSDGAN_G.py b/models/generators/JSDGAN_G.py similarity index 64% rename from models/JSDGAN_G.py rename to models/generators/JSDGAN_G.py index a35cf35b..a84d58c7 100644 --- a/models/JSDGAN_G.py +++ b/models/generators/JSDGAN_G.py @@ -4,21 +4,33 @@ # 
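FixemGAN's generator never emits token ids: it produces one embedding vector per position, and `recover_sentence` decodes each vector to its nearest vocabulary word via the pretrained word2vec model. A gensim-free sketch of the same nearest-neighbour decoding, with a made-up four-word vocabulary and vectors:

```python
import numpy as np

# made-up vocabulary embeddings standing in for the pretrained w2v matrix
vocab = ["<pad>", "the", "cat", "sat"]
vectors = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

def nearest_token(vec):
    # cosine similarity against every vocabulary vector
    sims = vectors @ vec / (np.linalg.norm(vectors, axis=1) * np.linalg.norm(vec) + 1e-8)
    return vocab[int(np.argmax(sims))]

fake = np.array([[0.9, 0.1], [0.1, 0.8], [0.8, 0.9]])  # target_len x embedding_size
tokens = [t for t in (nearest_token(v) for v in fake) if t != "<pad>"]
print(" ".join(tokens))  # "the cat sat"
```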
@FileName : JSDGAN_G.py # @Time : Created at 2019/11/17 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class JSDGAN_G(LSTMGenerator): - def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, - gpu=False): - super(JSDGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'jsdgan' + def __init__( + self, + mem_slots, + num_heads, + head_size, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + gpu=False, + ): + super(JSDGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "jsdgan" # RMC @@ -43,8 +55,12 @@ def JSD_loss(self, inp, target): """ batch_size, seq_len = inp.size() hidden = self.init_hidden(batch_size) - pred = self.forward(inp, hidden).view(batch_size, self.max_seq_len, self.vocab_size) - target_onehot = F.one_hot(target, self.vocab_size).float() # batch_size * seq_len * vocab_size + pred = self.forward(inp, hidden).view( + batch_size, self.max_seq_len, self.vocab_size + ) + target_onehot = F.one_hot( + target, self.vocab_size + ).float() # batch_size * seq_len * vocab_size pred = torch.sum(pred * target_onehot, dim=-1) # batch_size * seq_len # calculate probabilities of sentences @@ -55,19 +71,26 @@ def JSD_loss(self, inp, target): prob_data = prob_data.cuda() # calculate the reward - reward = torch.log(1. - torch.div(prob_data, prob_data + prob_gen)) # batch_size + reward = torch.log( + 1.0 - torch.div(prob_data, prob_data + prob_gen) + ) # batch_size # check if nan if torch.isnan(reward).sum() > 0: - print('Reward is nan!!!') + print("Reward is nan!!!") exit(1) - loss = torch.sum((prob_gen * reward).detach() * torch.sum(pred.double(), dim=-1)) + loss = torch.sum( + (prob_gen * reward).detach() * torch.sum(pred.double(), dim=-1) + ) return loss def min_max_normal(self, prob): - return torch.div(prob - torch.min(prob), torch.clamp(torch.max(prob) - torch.min(prob), min=1e-78)) + return torch.div( + prob - torch.min(prob), + torch.clamp(torch.max(prob) - torch.min(prob), min=1e-78), + ) def sigmoid_normal(self, prob): """push prob either close to 0 or 1""" diff --git a/models/LeakGAN_G.py b/models/generators/LeakGAN_G.py similarity index 69% rename from models/LeakGAN_G.py rename to models/generators/LeakGAN_G.py index bdde5bb2..53e69969 100644 --- a/models/LeakGAN_G.py +++ b/models/generators/LeakGAN_G.py @@ -4,7 +4,7 @@ # @FileName : LeakGAN_G.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
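# A minimal, self-contained sketch (names and shapes are illustrative only, not
# part of this module) of the manager objective computed later in this file:
# the change in leaked discriminator features over one step_size window should
# align with the manager's goal vector, measured by cosine similarity after
# L2 normalization.
import torch
import torch.nn.functional as F

batch, seq_len, feat_dim, step_size = 4, 20, 16, 4
features = torch.randn(batch, seq_len + 1, feat_dim)  # leaked features per step
goals = torch.randn(batch, seq_len, feat_dim)          # manager goals per step

t = 0
feat_delta = features[:, t + step_size, :] - features[:, t, :]
cos = F.cosine_similarity(
    F.normalize(feat_delta, p=2, dim=-1),
    F.normalize(goals[:, t, :], p=2, dim=-1),
    dim=-1,
)
manager_loss = -cos.mean()  # maximizing alignment == minimizing this loss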
import math import time @@ -21,10 +21,19 @@ class LeakGAN_G(nn.Module): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, goal_size, - step_size, gpu=False): + def __init__( + self, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + goal_size, + step_size, + gpu=False, + ): super(LeakGAN_G, self).__init__() - self.name = 'leakgan' + self.name = "leakgan" self.hidden_dim = hidden_dim self.embedding_dim = embedding_dim @@ -37,7 +46,9 @@ def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_i self.gpu = gpu self.temperature = 1.5 - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) self.worker = nn.LSTM(embedding_dim, hidden_dim) self.manager = nn.LSTM(goal_out_size, hidden_dim) @@ -49,7 +60,17 @@ def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_i self.init_params() - def forward(self, idx, inp, work_hidden, mana_hidden, feature, real_goal, no_log=False, train=False): + def forward( + self, + idx, + inp, + work_hidden, + mana_hidden, + feature, + real_goal, + no_log=False, + train=False, + ): """ Embeds input and sample on token at a time (seq_len = 1) @@ -69,16 +90,25 @@ def forward(self, idx, inp, work_hidden, mana_hidden, feature, real_goal, no_log emb = self.embeddings(inp).unsqueeze(0) # 1 * batch_size * embed_dim # Manager - mana_out, mana_hidden = self.manager(feature, mana_hidden) # mana_out: 1 * batch_size * hidden_dim - mana_out = self.mana2goal(mana_out.permute([1, 0, 2])) # batch_size * 1 * goal_out_size + mana_out, mana_hidden = self.manager( + feature, mana_hidden + ) # mana_out: 1 * batch_size * hidden_dim + mana_out = self.mana2goal( + mana_out.permute([1, 0, 2]) + ) # batch_size * 1 * goal_out_size cur_goal = F.normalize(mana_out, dim=-1) _real_goal = self.goal2goal(real_goal) # batch_size * goal_size - _real_goal = F.normalize(_real_goal, p=2, dim=-1).unsqueeze(-1) # batch_size * goal_size * 1 + _real_goal = F.normalize(_real_goal, p=2, dim=-1).unsqueeze( + -1 + ) # batch_size * goal_size * 1 # Worker - work_out, work_hidden = self.worker(emb, work_hidden) # work_out: 1 * batch_size * hidden_dim - work_out = self.work2goal(work_out).view(-1, self.vocab_size, - self.goal_size) # batch_size * vocab_size * goal_size + work_out, work_hidden = self.worker( + emb, work_hidden + ) # work_out: 1 * batch_size * hidden_dim + work_out = self.work2goal(work_out).view( + -1, self.vocab_size, self.goal_size + ) # batch_size * vocab_size * goal_size # Sample token out = torch.matmul(work_out, _real_goal).squeeze(-1) # batch_size * vocab_size @@ -101,21 +131,31 @@ def forward(self, idx, inp, work_hidden, mana_hidden, feature, real_goal, no_log return out, cur_goal, work_hidden, mana_hidden - def sample(self, num_samples, batch_size, dis, start_letter=cfg.start_letter, train=False): + def sample( + self, num_samples, batch_size, dis, start_letter=cfg.start_letter, train=False + ): """ Samples the network and returns num_samples samples of length max_seq_len. 
:return: samples: batch_size * max_seq_len """ num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1 - samples = torch.zeros(num_batch * batch_size, self.max_seq_len).long() # larger than num_samples + samples = torch.zeros( + num_batch * batch_size, self.max_seq_len + ).long() # larger than num_samples fake_sentences = torch.zeros((batch_size, self.max_seq_len)) for b in range(num_batch): - leak_sample, _, _, _ = self.forward_leakgan(fake_sentences, dis, if_sample=True, no_log=False - , start_letter=start_letter, train=False) + leak_sample, _, _, _ = self.forward_leakgan( + fake_sentences, + dis, + if_sample=True, + no_log=False, + start_letter=start_letter, + train=False, + ) assert leak_sample.shape == (batch_size, self.max_seq_len) - samples[b * batch_size:(b + 1) * batch_size, :] = leak_sample + samples[b * batch_size : (b + 1) * batch_size, :] = leak_sample samples = samples[:num_samples, :] @@ -130,16 +170,22 @@ def pretrain_loss(self, target, dis, start_letter=cfg.start_letter): """ batch_size, seq_len = target.size() - _, feature_array, goal_array, leak_out_array = self.forward_leakgan(target, dis, if_sample=False, no_log=False, - start_letter=start_letter) + _, feature_array, goal_array, leak_out_array = self.forward_leakgan( + target, dis, if_sample=False, no_log=False, start_letter=start_letter + ) # Manager loss - mana_cos_loss = self.manager_cos_loss(batch_size, feature_array, - goal_array) # batch_size * (seq_len / step_size) - manager_loss = -torch.sum(mana_cos_loss) / (batch_size * (seq_len // self.step_size)) + mana_cos_loss = self.manager_cos_loss( + batch_size, feature_array, goal_array + ) # batch_size * (seq_len / step_size) + manager_loss = -torch.sum(mana_cos_loss) / ( + batch_size * (seq_len // self.step_size) + ) # Worker loss - work_nll_loss = self.worker_nll_loss(target, leak_out_array) # batch_size * seq_len + work_nll_loss = self.worker_nll_loss( + target, leak_out_array + ) # batch_size * seq_len work_loss = torch.sum(work_nll_loss) / (batch_size * seq_len) return manager_loss, work_loss @@ -154,18 +200,31 @@ def adversarial_loss(self, target, rewards, dis, start_letter=cfg.start_letter): - rewards: batch_size * seq_len (discriminator rewards for each token) """ batch_size, seq_len = target.size() - _, feature_array, goal_array, leak_out_array = self.forward_leakgan(target, dis, if_sample=False, no_log=False, - start_letter=start_letter, train=True) + _, feature_array, goal_array, leak_out_array = self.forward_leakgan( + target, + dis, + if_sample=False, + no_log=False, + start_letter=start_letter, + train=True, + ) # Manager Loss t0 = time.time() - mana_cos_loss = self.manager_cos_loss(batch_size, feature_array, - goal_array) # batch_size * (seq_len / step_size) - mana_loss = -torch.sum(rewards * mana_cos_loss) / (batch_size * (seq_len // self.step_size)) + mana_cos_loss = self.manager_cos_loss( + batch_size, feature_array, goal_array + ) # batch_size * (seq_len / step_size) + mana_loss = -torch.sum(rewards * mana_cos_loss) / ( + batch_size * (seq_len // self.step_size) + ) # Worker Loss - work_nll_loss = self.worker_nll_loss(target, leak_out_array) # batch_size * seq_len - work_cos_reward = self.worker_cos_reward(feature_array, goal_array) # batch_size * seq_len + work_nll_loss = self.worker_nll_loss( + target, leak_out_array + ) # batch_size * seq_len + work_cos_reward = self.worker_cos_reward( + feature_array, goal_array + ) # batch_size * seq_len work_loss = -torch.sum(work_nll_loss * work_cos_reward) / (batch_size * seq_len) 
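        # mana_loss above scales the manager's feature-goal cosine alignment by the
        # discriminator rewards, and work_loss pairs the per-token NLL with the
        # worker's cosine reward (alignment of leaked-feature deltas with the manager
        # goals); both are normalized by batch size and the (effective) sequence length.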
return mana_loss, work_loss @@ -194,17 +253,23 @@ def manager_cos_loss(self, batch_size, feature_array, goal_array): # ===LeakGAN origin=== # get sub_feature and real_goal # batch_size, seq_len = sentences.size() - sub_feature = torch.zeros(batch_size, self.max_seq_len // self.step_size, self.goal_out_size) - real_goal = torch.zeros(batch_size, self.max_seq_len // self.step_size, self.goal_out_size) + sub_feature = torch.zeros( + batch_size, self.max_seq_len // self.step_size, self.goal_out_size + ) + real_goal = torch.zeros( + batch_size, self.max_seq_len // self.step_size, self.goal_out_size + ) for i in range(self.max_seq_len // self.step_size): idx = i * self.step_size - sub_feature[:, i, :] = feature_array[:, idx + self.step_size, :] - feature_array[:, idx, :] + sub_feature[:, i, :] = ( + feature_array[:, idx + self.step_size, :] - feature_array[:, idx, :] + ) if i == 0: real_goal[:, i, :] = self.goal_init[:batch_size, :] else: idx = (i - 1) * self.step_size + 1 - real_goal[:, i, :] = torch.sum(goal_array[:, idx:idx + 4, :], dim=1) + real_goal[:, i, :] = torch.sum(goal_array[:, idx : idx + 4, :], dim=1) # L2 noramlization sub_feature = F.normalize(sub_feature, p=2, dim=-1) @@ -220,7 +285,7 @@ def worker_nll_loss(self, target, leak_out_array): :return loss: batch_size * seq_len """ - loss_fn = nn.NLLLoss(reduction='none') + loss_fn = nn.NLLLoss(reduction="none") loss = loss_fn(leak_out_array.permute([0, 2, 1]), target) return loss @@ -232,26 +297,52 @@ def worker_cos_reward(self, feature_array, goal_array): :return: cos_loss: batch_size * seq_len """ for i in range(int(self.max_seq_len / self.step_size)): - real_feature = feature_array[:, i * self.step_size, :].unsqueeze(1).expand((-1, self.step_size, -1)) - feature_array[:, i * self.step_size:(i + 1) * self.step_size, :] = real_feature + real_feature = ( + feature_array[:, i * self.step_size, :] + .unsqueeze(1) + .expand((-1, self.step_size, -1)) + ) + feature_array[ + :, i * self.step_size : (i + 1) * self.step_size, : + ] = real_feature if i > 0: - sum_goal = torch.sum(goal_array[:, (i - 1) * self.step_size:i * self.step_size, :], dim=1, keepdim=True) + sum_goal = torch.sum( + goal_array[:, (i - 1) * self.step_size : i * self.step_size, :], + dim=1, + keepdim=True, + ) else: sum_goal = goal_array[:, 0, :].unsqueeze(1) - goal_array[:, i * self.step_size:(i + 1) * self.step_size, :] = sum_goal.expand((-1, self.step_size, -1)) - - offset_feature = feature_array[:, 1:, :] # f_{t+1}, batch_size * seq_len * goal_out_size - goal_array = goal_array[:, :self.max_seq_len, :] # batch_size * seq_len * goal_out_size + goal_array[ + :, i * self.step_size : (i + 1) * self.step_size, : + ] = sum_goal.expand((-1, self.step_size, -1)) + + offset_feature = feature_array[ + :, 1:, : + ] # f_{t+1}, batch_size * seq_len * goal_out_size + goal_array = goal_array[ + :, : self.max_seq_len, : + ] # batch_size * seq_len * goal_out_size sub_feature = offset_feature - goal_array # L2 normalization sub_feature = F.normalize(sub_feature, p=2, dim=-1) all_goal = F.normalize(goal_array, p=2, dim=-1) - cos_loss = F.cosine_similarity(sub_feature, all_goal, dim=-1) # batch_size * seq_len + cos_loss = F.cosine_similarity( + sub_feature, all_goal, dim=-1 + ) # batch_size * seq_len return cos_loss - def forward_leakgan(self, sentences, dis, if_sample, no_log=False, start_letter=cfg.start_letter, train=False): + def forward_leakgan( + self, + sentences, + dis, + if_sample, + no_log=False, + start_letter=cfg.start_letter, + train=False, + ): """ Get all feature and goals 
according to given sentences :param sentences: batch_size * max_seq_len, not include start token @@ -298,14 +389,24 @@ def forward_leakgan(self, sentences, dis, if_sample, no_log=False, start_letter= if self.gpu: dis_inp = dis_inp.cuda() leak_inp = leak_inp.cuda() - feature = dis.get_feature(dis_inp).unsqueeze(0) # !!!note: 1 * batch_size * total_num_filters + feature = dis.get_feature(dis_inp).unsqueeze( + 0 + ) # !!!note: 1 * batch_size * total_num_filters feature_array[:, i, :] = feature.squeeze(0) # Get output of one token # cur_goal: batch_size * 1 * goal_out_size - out, cur_goal, work_hidden, mana_hidden = self.forward(i, leak_inp, work_hidden, mana_hidden, feature, - real_goal, no_log=no_log, train=train) + out, cur_goal, work_hidden, mana_hidden = self.forward( + i, + leak_inp, + work_hidden, + mana_hidden, + feature, + real_goal, + no_log=no_log, + train=train, + ) leak_out_array[:, i, :] = out # ===My implement according to paper=== @@ -320,14 +421,16 @@ def forward_leakgan(self, sentences, dis, if_sample, no_log=False, start_letter= # Save goal and update real_goal goal_array[:, i, :] = cur_goal.squeeze(1) if i > 0 and i % self.step_size == 0: - real_goal = torch.sum(goal_array[:, i - 3:i + 1, :], dim=1) + real_goal = torch.sum(goal_array[:, i - 3 : i + 1, :], dim=1) if i / self.step_size == 1: real_goal += self.goal_init[:batch_size, :] # Sample one token if not no_log: out = torch.exp(out) - out = torch.multinomial(out, 1).view(-1) # [batch_size] (sampling from each row) + out = torch.multinomial(out, 1).view( + -1 + ) # [batch_size] (sampling from each row) samples[:, i] = out.data leak_inp = out @@ -339,8 +442,9 @@ def forward_leakgan(self, sentences, dis, if_sample, no_log=False, start_letter= def batchNLLLoss(self, target, dis, start_letter=cfg.start_letter): # loss_fn = nn.NLLLoss() # batch_size, seq_len = target.size() - _, _, _, leak_out_array = self.forward_leakgan(target, dis, if_sample=False, no_log=False, - start_letter=start_letter) + _, _, _, leak_out_array = self.forward_leakgan( + target, dis, if_sample=False, no_log=False, start_letter=start_letter + ) nll_loss = torch.mean(self.worker_nll_loss(target, leak_out_array)) @@ -383,9 +487,9 @@ def init_params(self): for param in self.parameters(): if param.requires_grad and len(param.shape) > 0: stddev = 1 / math.sqrt(param.shape[0]) - if cfg.gen_init == 'uniform': + if cfg.gen_init == "uniform": torch.nn.init.uniform_(param, a=-0.05, b=0.05) - elif cfg.gen_init == 'normal': + elif cfg.gen_init == "normal": torch.nn.init.normal_(param, std=stddev) - elif cfg.gen_init == 'truncated_normal': + elif cfg.gen_init == "truncated_normal": truncated_normal_(param, std=stddev) diff --git a/models/MaliGAN_G.py b/models/generators/MaliGAN_G.py similarity index 61% rename from models/MaliGAN_G.py rename to models/generators/MaliGAN_G.py index 7baf0be6..01b9a987 100644 --- a/models/MaliGAN_G.py +++ b/models/generators/MaliGAN_G.py @@ -4,19 +4,23 @@ # @FileName : MaliGAN_G.py # @Time : Created at 2019/10/17 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
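# A minimal, self-contained sketch (shapes illustrative, independent of this
# module) of the reward-weighted objective pattern used in adv_loss below:
# gather each target token's log-probability with a one-hot mask, then weight
# the result by the reward before summing.
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 4, 20, 100
log_probs = torch.log_softmax(torch.randn(batch, seq_len, vocab), dim=-1)
target = torch.randint(0, vocab, (batch, seq_len))
reward = torch.rand(batch, 1)  # e.g. importance weights from the discriminator

target_onehot = F.one_hot(target, vocab).float()        # batch * seq_len * vocab
picked = torch.sum(log_probs * target_onehot, dim=-1)   # batch * seq_len
loss = -torch.sum(picked * reward)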
import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class MaliGAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(MaliGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'maligan' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(MaliGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "maligan" def adv_loss(self, inp, target, reward): """ @@ -31,8 +35,12 @@ def adv_loss(self, inp, target, reward): batch_size, seq_len = inp.size() hidden = self.init_hidden(batch_size) - out = self.forward(inp, hidden).view(batch_size, self.max_seq_len, self.vocab_size) - target_onehot = F.one_hot(target, self.vocab_size).float() # batch_size * seq_len * vocab_size + out = self.forward(inp, hidden).view( + batch_size, self.max_seq_len, self.vocab_size + ) + target_onehot = F.one_hot( + target, self.vocab_size + ).float() # batch_size * seq_len * vocab_size pred = torch.sum(out * target_onehot, dim=-1) # batch_size * seq_len loss = -torch.sum(pred * reward) diff --git a/models/RelGAN_G.py b/models/generators/RelGAN_G.py similarity index 79% rename from models/RelGAN_G.py rename to models/generators/RelGAN_G.py index c8294269..a856a453 100644 --- a/models/RelGAN_G.py +++ b/models/generators/RelGAN_G.py @@ -4,27 +4,41 @@ # @FileName : RelGAN_G.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn as nn import torch.nn.functional as F import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator from models.relational_rnn_general import RelationalMemory class RelGAN_G(LSTMGenerator): - def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, - gpu=False): - super(RelGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'relgan' + def __init__( + self, + mem_slots, + num_heads, + head_size, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + gpu=False, + ): + super(RelGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "relgan" self.temperature = 1.0 # init value is 1.0 - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) - if cfg.model_type == 'LSTM': + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) + if cfg.model_type == "LSTM": # LSTM self.hidden_dim = hidden_dim self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, batch_first=True) @@ -32,15 +46,20 @@ def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, v else: # RMC self.hidden_dim = mem_slots * num_heads * head_size - self.lstm = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=embedding_dim, - num_heads=num_heads, return_all_outputs=True) + self.lstm = RelationalMemory( + mem_slots=mem_slots, + head_size=head_size, + input_size=embedding_dim, + num_heads=num_heads, + return_all_outputs=True, + ) self.lstm2out = nn.Linear(self.hidden_dim, vocab_size) self.init_params() pass def init_hidden(self, batch_size=cfg.batch_size): - if 
cfg.model_type == 'LSTM': + if cfg.model_type == "LSTM": h = torch.zeros(1, batch_size, self.hidden_dim) c = torch.zeros(1, batch_size, self.hidden_dim) @@ -79,7 +98,9 @@ def step(self, inp, hidden): return pred, hidden, next_token, next_token_onehot, next_o - def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter): + def sample( + self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter + ): """ Sample from RelGAN Generator - one_hot: if return pred of RelGAN, used for adversarial training @@ -103,7 +124,7 @@ def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_ for i in range(self.max_seq_len): pred, hidden, next_token, _, _ = self.step(inp, hidden) - samples[b * batch_size:(b + 1) * batch_size, i] = next_token + samples[b * batch_size : (b + 1) * batch_size, i] = next_token if one_hot: all_preds[:, i] = pred inp = next_token diff --git a/models/SentiGAN_G.py b/models/generators/SentiGAN_G.py similarity index 71% rename from models/SentiGAN_G.py rename to models/generators/SentiGAN_G.py index f0fa0e45..145c4750 100644 --- a/models/SentiGAN_G.py +++ b/models/generators/SentiGAN_G.py @@ -4,20 +4,24 @@ # @FileName : SentiGAN_G.py # @Time : Created at 2019-07-26 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class SentiGAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(SentiGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'sentigan' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(SentiGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "sentigan" def forward(self, inp, hidden, need_hidden=False, use_log=True): """ @@ -31,7 +35,9 @@ def forward(self, inp, hidden, need_hidden=False, use_log=True): emb = emb.unsqueeze(1) # batch_size * 1 * embedding_dim out, hidden = self.lstm(emb, hidden) # out: batch_size * seq_len * hidden_dim - out = out.contiguous().view(-1, self.hidden_dim) # out: (batch_size * len) * hidden_dim + out = out.contiguous().view( + -1, self.hidden_dim + ) # out: (batch_size * len) * hidden_dim out = self.lstm2out(out) # batch_size * seq_len * vocab_size # out = self.temperature * out # temperature if use_log: @@ -57,8 +63,12 @@ def batchPGLoss(self, inp, target, reward): batch_size, seq_len = inp.size() hidden = self.init_hidden(batch_size) - out = self.forward(inp, hidden, use_log=False).view(batch_size, self.max_seq_len, self.vocab_size) - target_onehot = F.one_hot(target, self.vocab_size).float() # batch_size * seq_len * vocab_size + out = self.forward(inp, hidden, use_log=False).view( + batch_size, self.max_seq_len, self.vocab_size + ) + target_onehot = F.one_hot( + target, self.vocab_size + ).float() # batch_size * seq_len * vocab_size pred = torch.sum(out * target_onehot, dim=-1) # batch_size * seq_len loss = -torch.sum(pred * (1 - reward)) diff --git a/models/SeqGAN_G.py b/models/generators/SeqGAN_G.py similarity index 62% rename from models/SeqGAN_G.py rename to models/generators/SeqGAN_G.py index 86dd7c86..9859505b 100644 --- a/models/SeqGAN_G.py +++ b/models/generators/SeqGAN_G.py @@ -4,19 +4,23 @@ # @FileName : SeqGAN_G.py # @Time : 
Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class SeqGAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(SeqGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'seqgan' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(SeqGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "seqgan" def batchPGLoss(self, inp, target, reward): """ @@ -31,8 +35,12 @@ def batchPGLoss(self, inp, target, reward): batch_size, seq_len = inp.size() hidden = self.init_hidden(batch_size) - out = self.forward(inp, hidden).view(batch_size, self.max_seq_len, self.vocab_size) - target_onehot = F.one_hot(target, self.vocab_size).float() # batch_size * seq_len * vocab_size + out = self.forward(inp, hidden).view( + batch_size, self.max_seq_len, self.vocab_size + ) + target_onehot = F.one_hot( + target, self.vocab_size + ).float() # batch_size * seq_len * vocab_size pred = torch.sum(out * target_onehot, dim=-1) # batch_size * seq_len loss = -torch.sum(pred * reward) diff --git a/models/generator.py b/models/generators/generator.py similarity index 77% rename from models/generator.py rename to models/generators/generator.py index e1ab3fa8..70401242 100644 --- a/models/generator.py +++ b/models/generators/generator.py @@ -16,10 +16,11 @@ class LSTMGenerator(nn.Module): - - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): super(LSTMGenerator, self).__init__() - self.name = 'vanilla' + self.name = "vanilla" self.hidden_dim = hidden_dim self.embedding_dim = embedding_dim @@ -30,7 +31,9 @@ def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_i self.temperature = 1.0 - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True) self.lstm2out = nn.Linear(hidden_dim, vocab_size) self.softmax = nn.LogSoftmax(dim=-1) @@ -49,7 +52,9 @@ def forward(self, inp, hidden, need_hidden=False): emb = emb.unsqueeze(1) # batch_size * 1 * embedding_dim out, hidden = self.lstm(emb, hidden) # out: batch_size * seq_len * hidden_dim - out = out.contiguous().view(-1, self.hidden_dim) # out: (batch_size * len) * hidden_dim + out = out.contiguous().view( + -1, self.hidden_dim + ) # out: (batch_size * len) * hidden_dim out = self.lstm2out(out) # (batch_size * seq_len) * vocab_size # out = self.temperature * out # temperature pred = self.softmax(out) @@ -75,9 +80,13 @@ def sample(self, num_samples, batch_size, start_letter=cfg.start_letter): inp = inp.cuda() for i in range(self.max_seq_len): - out, hidden = self.forward(inp, hidden, need_hidden=True) # out: batch_size * vocab_size - next_token = torch.multinomial(torch.exp(out), 1) # batch_size * 1 (sampling from each row) - samples[b * batch_size:(b + 1) * batch_size, i] = next_token.view(-1) + out, hidden = self.forward( + inp, hidden, need_hidden=True + ) # out: batch_size * 
vocab_size + next_token = torch.multinomial( + torch.exp(out), 1, replacement=True + ) # batch_size * 1 (sampling from each row) + samples[b * batch_size : (b + 1) * batch_size, i] = next_token.view(-1) inp = next_token.view(-1) samples = samples[:num_samples] @@ -87,18 +96,13 @@ def init_params(self): for param in self.parameters(): if param.requires_grad and len(param.shape) > 0: stddev = 1 / math.sqrt(param.shape[0]) - if cfg.gen_init == 'uniform': + if cfg.gen_init == "uniform": torch.nn.init.uniform_(param, a=-0.05, b=0.05) - elif cfg.gen_init == 'normal': + elif cfg.gen_init == "normal": torch.nn.init.normal_(param, std=stddev) - elif cfg.gen_init == 'truncated_normal': + elif cfg.gen_init == "truncated_normal": truncated_normal_(param, std=stddev) - def init_oracle(self): - for param in self.parameters(): - if param.requires_grad: - torch.nn.init.normal_(param, mean=0, std=1) - def init_hidden(self, batch_size=cfg.batch_size): h = torch.zeros(1, batch_size, self.hidden_dim) c = torch.zeros(1, batch_size, self.hidden_dim) diff --git a/models/relational_rnn_general.py b/models/relational_rnn_general.py index e2d5178c..c12fbfe5 100644 --- a/models/relational_rnn_general.py +++ b/models/relational_rnn_general.py @@ -38,8 +38,20 @@ class RelationalMemory(nn.Module): ValueError: attention_mlp_layers is < 1. """ - def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, forget_bias=1., input_bias=0., - gate_style='unit', attention_mlp_layers=2, key_size=None, return_all_outputs=False): + def __init__( + self, + mem_slots, + head_size, + input_size, + num_heads=1, + num_blocks=1, + forget_bias=1.0, + input_bias=0.0, + gate_style="unit", + attention_mlp_layers=2, + key_size=None, + return_all_outputs=False, + ): super(RelationalMemory, self).__init__() ########## generic parameters for RMC ########## @@ -54,18 +66,22 @@ def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, self.mem_slots_plus_input = self.mem_slots + 1 if num_blocks < 1: - raise ValueError('num_blocks must be >=1. Got: {}.'.format(num_blocks)) + raise ValueError("num_blocks must be >=1. Got: {}.".format(num_blocks)) self.num_blocks = num_blocks - if gate_style not in ['unit', 'memory', None]: + if gate_style not in ["unit", "memory", None]: raise ValueError( - 'gate_style must be one of [\'unit\', \'memory\', None]. got: ' - '{}.'.format(gate_style)) + "gate_style must be one of ['unit', 'memory', None]. got: " + "{}.".format(gate_style) + ) self.gate_style = gate_style if attention_mlp_layers < 1: - raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format( - attention_mlp_layers)) + raise ValueError( + "attention_mlp_layers must be >= 1. 
Got: {}.".format( + attention_mlp_layers + ) + ) self.attention_mlp_layers = attention_mlp_layers self.key_size = key_size if key_size else self.head_size @@ -81,12 +97,20 @@ def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, # just using one big param is more efficient, rather than this line # self.qkv_projector = [nn.Parameter(torch.randn((self.qkv_size, self.qkv_size))) for _ in range(self.num_heads)] self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size) - self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size]) + self.qkv_layernorm = nn.LayerNorm( + [self.mem_slots_plus_input, self.total_qkv_size] + ) # used for attend_over_memory function - self.attention_mlp = nn.ModuleList([nn.Linear(self.mem_size, self.mem_size)] * self.attention_mlp_layers) - self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size]) - self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size]) + self.attention_mlp = nn.ModuleList( + [nn.Linear(self.mem_size, self.mem_size)] * self.attention_mlp_layers + ) + self.attended_memory_layernorm = nn.LayerNorm( + [self.mem_slots_plus_input, self.mem_size] + ) + self.attended_memory_layernorm2 = nn.LayerNorm( + [self.mem_slots_plus_input, self.mem_size] + ) ########## parameters for initial embedded input projection ########## self.input_size = input_size @@ -135,7 +159,7 @@ def initial_state(self, batch_size, trainable=False): # truncation. take the first 'self.mem_size' components elif self.mem_size < self.mem_slots: - init_state = init_state[:, :, :self.mem_size] + init_state = init_state[:, :, : self.mem_size] return init_state @@ -168,10 +192,12 @@ def multihead_attention(self, memory): qkv_transpose = qkv_reshape.permute(0, 2, 1, 3) # [B, H, N, key_size], [B, H, N, key_size], [B, H, N, value_size] - q, k, v = torch.split(qkv_transpose, [self.key_size, self.key_size, self.value_size], -1) + q, k, v = torch.split( + qkv_transpose, [self.key_size, self.key_size, self.value_size], -1 + ) # scale q with d_k, the dimensionality of the key vectors - q *= (self.key_size ** -0.5) + q = q * (self.key_size**-0.5) # make it [B, H, N, N] dot_product = torch.matmul(q, k.permute(0, 1, 3, 2)) @@ -182,7 +208,9 @@ def multihead_attention(self, memory): # [B, H, N, V] => [B, N, H, V] => [B, N, H*V] output_transpose = output.permute(0, 2, 1, 3).contiguous() - new_memory = output_transpose.view((output_transpose.shape[0], output_transpose.shape[1], -1)) + new_memory = output_transpose.view( + (output_transpose.shape[0], output_transpose.shape[1], -1) + ) return new_memory @@ -200,9 +228,9 @@ def calculate_gate_size(self): Returns: The per sample, per head parameter size of each gate. """ - if self.gate_style == 'unit': + if self.gate_style == "unit": return self.mem_size - elif self.gate_style == 'memory': + elif self.gate_style == "memory": return 1 else: # self.gate_style == None return 0 @@ -231,7 +259,8 @@ def create_gates(self, inputs, memory): if len(inputs.shape) == 3: if inputs.shape[1] > 1: raise ValueError( - "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1") + "input seq length is larger than 1. 
create_gate function is meant to be called for each step, with input seq length of 1" + ) inputs = inputs.view(inputs.shape[0], -1) # matmul for equation 4 and 5 # there is no output gate, so equation 6 is not implemented @@ -243,7 +272,9 @@ def create_gates(self, inputs, memory): # this completes the equation 4 and 5 gates = gate_memory + gate_inputs - gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2) + gates = torch.split( + gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2 + ) input_gate, forget_gate = gates assert input_gate.shape[2] == forget_gate.shape[2] @@ -310,7 +341,7 @@ def forward_step(self, inputs, memory, treat_input_as_matrix=False): n = inputs_reshape.shape[1] next_memory = next_memory[:, :-n, :] - if self.gate_style == 'unit' or self.gate_style == 'memory': + if self.gate_style == "unit" or self.gate_style == "memory": # these gates are sigmoid-applied ones for equation 7 input_gate, forget_gate = self.create_gates(inputs_reshape, memory) # equation 7 calculation @@ -345,6 +376,7 @@ def forward(self, inputs, memory, treat_input_as_matrix=False): else: return logit.unsqueeze(1), memory + # ########## DEBUG: unit test code ########## # input_size = 32 # seq_length = 20 diff --git a/run/run_catgan.py b/run/run_catgan.py index 8efef61e..71e4b78e 100644 --- a/run/run_catgan.py +++ b/run/run_catgan.py @@ -4,7 +4,7 @@ # @FileName : run_catgan.py # @Time : Created at 2019-08-04 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys from subprocess import call @@ -15,26 +15,30 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' -rootdir = '../' -scriptname = 'main.py' +executable = "python" +rootdir = "../" +scriptname = "main.py" # ===Program=== # CatGAN: Catgory text generation model # EvoGAN: General text generation model if_test = int(False) -run_model = ['catgan', 'catgan', 'catgan', 'evogan', 'evogan', 'evogan'] +run_model = ["evogan", "catgan", "catgan", "evogan", "evogan", "evogan"] k_label = 2 CUDA = int(True) ora_pretrain = int(True) @@ -43,19 +47,19 @@ MLE_train_epoch = 150 clas_pre_epoch = 5 ADV_train_epoch = 2000 -tips = '{} experiments' +tips = "{} experiments" # ===Oracle or Real=== -if_real_data = [int(False), int(True), int(True), int(False), int(True), int(True)] -dataset = ['oracle', 'mr15', 'amazon_app_book', 'oracle', 'image_coco', 'emnlp_news'] +if_real_data = [int(True), int(False), int(True), int(False), int(True), int(True)] +dataset = ["amazon_app_book", "oracle", "mr15", "oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0, 5000, 0, 0] # ===CatGAN Param=== n_parent = 1 -loss_type = 'ragan' -mu_type = 'ragan rsgan' -eval_type = 'Ra' -temp_adpt = 'exp' +loss_type = "ragan" +mu_type = "ragan rsgan" +eval_type = "Ra" +temp_adpt = "exp" temperature = [1, 100, 100, 1, 100, 100] d_out_mean = int(True) lambda_fq = 1.0 @@ -64,10 +68,9 @@ # === Basic Param === data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'truncated_normal' -dis_init = 'uniform' -samples_num = 10000 +model_type = "vanilla" +gen_init = "truncated_normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -101,72 +104,117 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model[job_id], - '--k_label', k_label, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model[job_id], + "--k_label", + k_label, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', ora_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--clas_pre_epoch', clas_pre_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips.format(run_model[job_id]), - + "--ora_pretrain", + ora_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--clas_pre_epoch", + clas_pre_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips.format(run_model[job_id]), # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # CatGAN Param - '--n_parent', n_parent, - '--loss_type', loss_type, - '--mu_type', mu_type, - '--eval_type', eval_type, - '--temp_adpt', temp_adpt, - '--temperature', temperature[job_id], - '--d_out_mean', d_out_mean, - '--lambda_fq', lambda_fq, - '--lambda_fd', lambda_fd, - '--eval_b_num', eval_b_num, - + "--n_parent", + n_parent, + "--loss_type", + loss_type, + "--mu_type", + mu_type, + "--eval_type", + eval_type, + "--temp_adpt", + temp_adpt, + "--temperature", + temperature[job_id], + "--d_out_mean", + d_out_mean, + "--lambda_fq", + lambda_fq, + "--lambda_fd", + lambda_fd, + "--eval_b_num", + eval_b_num, # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - 
'--gen_lr', gen_lr, - '--gen_adv_lr', gen_adv_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--gen_adv_lr", + gen_adv_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - '--mem_slots', mem_slots, - '--num_heads', num_heads, - '--head_size', head_size[job_id], - + "--adv_g_step", + ADV_g_step, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, + "--mem_slots", + mem_slots, + "--num_heads", + num_heads, + "--head_size", + head_size[job_id], # Discriminator - '--adv_d_step', ADV_d_step, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - '--num_rep', num_rep, - + "--adv_d_step", + ADV_d_step, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, + "--num_rep", + num_rep, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_clas_acc', use_clas_acc, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_clas_acc", + use_clas_acc, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_cot.py b/run/run_cot.py index 92f4f03c..3f0661db 100644 --- a/run/run_cot.py +++ b/run/run_cot.py @@ -4,7 +4,7 @@ # @FileName : run_cot.py # @Time : Created at 2020/4/21 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -16,43 +16,46 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'cot' +run_model = "cot" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 0 ADV_train_epoch = 20000 -tips = 'CoT experiments' +tips = "CoT experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'normal' -samples_num = 10000 +model_type = "vanilla" +gen_init = "normal" +dis_init = "normal" batch_size = 64 max_seq_len = 20 gen_lr = 1e-2 @@ -77,49 +80,74 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--adv_d_step', ADV_d_step, - + "--adv_d_step", + ADV_d_step, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_dgsan.py b/run/run_dgsan.py index b7085283..33bc078a 100644 --- a/run/run_dgsan.py +++ b/run/run_dgsan.py @@ -4,7 +4,7 @@ # @FileName : run_dgsan.py # @Time : Created at 2020/4/21 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
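# All run_*.py launchers, including this one, follow the pattern sketched below
# (flag values illustrative): collect "--flag value" pairs in a flat list,
# stringify every entry, and launch main.py as a subprocess from the repo root.
import os
from subprocess import call

executable = "python"
rootdir = "../"
scriptname = "main.py"
args = ["--run_model", "dgsan", "--cuda", 1, "--mle_epoch", 0, "--adv_epoch", 200]

args = list(map(str, args))
my_env = os.environ.copy()
call([executable, scriptname] + args, env=my_env, cwd=rootdir)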
import sys @@ -16,42 +16,45 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = '/home/zhiwei/.virtualenvs/zhiwei/bin/python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "/home/zhiwei/.virtualenvs/zhiwei/bin/python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'dgsan' +run_model = "dgsan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 0 ADV_train_epoch = 200 -tips = 'DGSAN experiments' +tips = "DGSAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'truncated_normal' -samples_num = 10000 +model_type = "vanilla" +gen_init = "truncated_normal" batch_size = 64 max_seq_len = 20 gen_lr = 1e-2 @@ -72,44 +75,67 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + 
"--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_dpgan.py b/run/run_dpgan.py index 60529c3e..7b7496e1 100644 --- a/run/run_dpgan.py +++ b/run/run_dpgan.py @@ -16,43 +16,46 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 2 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = '/home/zhiwei/.virtualenvs/zhiwei/bin/python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "/home/zhiwei/.virtualenvs/zhiwei/bin/python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'dpgan' +run_model = "dpgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 120 ADV_train_epoch = 200 -tips = 'DPGAN experiments' +tips = "DPGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' -samples_num = 10000 +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -84,56 +87,89 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, - '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, + "--device", + gpu_id, # comment for auto GPU + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # 
Generator - '--adv_g_step', ADV_g_step, - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_fixem.py b/run/run_fixem.py new file mode 100644 index 00000000..9650f93a --- /dev/null +++ b/run/run_fixem.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +# @Author : William +# @Project : TextGAN-william +# @FileName : run_catgan.py +# @Time : Created at 2019-08-04 +# @Blog : http://zhiweil.ml/ +# @Description : +# Copyrights (C) 2018. All Rights Reserved. +import sys +from subprocess import call + +import os + +# Job id and gpu_id +if len(sys.argv) > 2: + job_id = int(sys.argv[1]) + gpu_id = str(sys.argv[2]) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) +elif len(sys.argv) > 1: + job_id = int(sys.argv[1]) + gpu_id = 0 + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) +else: + job_id = 0 + gpu_id = 0 + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) + +# Executables +executable = "python" +rootdir = "../" +scriptname = "main.py" + +# ===Program=== +# FixemGAN: General text generation model +if_test = int(False) +run_model = [ + "fixemgan", + "cat_fixemgan", + "fixemgan", + "cat_fixemgan", + "fixemgan", + "fixemgan", + "fixemgan", + "fixemgan", + "cat_fixemgan", +] +k_label = 2 +CUDA = int(True) +noise_size = 1000 +max_epochs = 20 +batches_per_epoch = 200 +samples_num = 100 # sample for metrics +tips = "{} experiments" + +# ===Oracle or Real=== +if_real_data = [ + int(True), + int(True), + int(True), + int(True), + int(True), + int(True), + int(True), + int(False), + int(False), +] +dataset = [ + "image_coco", + "mr20", + "mr20", + "mr15", + "mr15", + "amazon_app_book", + "emnlp_news", + "oracle", + "oracle", +] +w2v_embedding_size = 512 # low on ram #hyperparam +w2v_window = 5 +w2v_min_count = 30 +w2v_workers = 30 +w2v_samples_num = 5_000_000 +vocab_size = 5000 + +# === Basic Param === +data_shuffle = int(False) +model_type = "fixem" +loss_type = "fixem" +gen_init = "truncated_normal" +dis_init = "uniform" +batch_size = 64 +target_len = [ + 16, + 20, + 20, + 16, + 16, + 40, + 48, + 20, + 20, +] # architechture requires to be divisible by 4 +real_fake_coeff = 1.0 +labels_coeff = 1.0 +diversity_coeff = 1.0 + +# ===Generator=== +generator_complexity = 768 # hyperparam + +# ===Discriminator=== +discriminator_complexity = 512 # hyperparam + +# ===Metrics=== +use_nll_oracle = int(True) +use_nll_gen = int(False) +use_nll_div = int(False) +use_bleu = int(True) +use_self_bleu = int(True) +use_clas_acc = int(True) +use_ppl = int(False) + +args = [ + # Program + "--if_test", + if_test, + "--run_model", + run_model[job_id], + "--k_label", + k_label, + "--cuda", + CUDA, + # '--device', gpu_id, # comment for auto GPU + "--tips", + tips.format(run_model[job_id]), + # Oracle or Real + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size, + # W2V embeddings + "--w2v_embedding_size", + w2v_embedding_size, + "--w2v_window", + w2v_window, + "--w2v_min_count", + w2v_min_count, + "--w2v_workers", + w2v_workers, + "--w2v_samples_num", + w2v_samples_num, + # FixemGAN Param + "--loss_type", + loss_type, + "--max_epochs", + max_epochs, + "--batches_per_epoch", + batches_per_epoch, + "--noise_size", + noise_size, + "--target_len", + target_len[job_id], + "--batch_size", + batch_size, + "--real_fake_coeff", + real_fake_coeff, + "--labels_coeff", + labels_coeff, + "--diversity_coeff", + diversity_coeff, + # Generator + "--generator_complexity", + generator_complexity, + # Discriminator + "--discriminator_complexity", + discriminator_complexity, + # Metrics + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_clas_acc", + use_clas_acc, + "--use_ppl", + use_ppl, +] + +args = list(map(str, args)) +my_env = os.environ.copy() +call([executable, scriptname] + args, env=my_env, cwd=rootdir) diff --git a/run/run_jsdgan.py b/run/run_jsdgan.py index 5a6e2111..60c88964 100644 --- a/run/run_jsdgan.py +++ b/run/run_jsdgan.py @@ -4,7 +4,7 @@ # @FileName : run_jsdgan.py # @Time : Created at 2019/11/29 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import sys @@ -16,41 +16,44 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'jsdgan' +run_model = "jsdgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) MLE_train_epoch = 0 # no pre-training ADV_train_epoch = 500 -tips = 'JSDGAN experiments' +tips = "JSDGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -samples_num = 10000 +model_type = "vanilla" +gen_init = "normal" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -72,44 +75,67 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + 
"--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_leakgan.py b/run/run_leakgan.py index bb45d796..507d4057 100644 --- a/run/run_leakgan.py +++ b/run/run_leakgan.py @@ -4,7 +4,7 @@ # @FileName : run_leakgan.py # @Time : Created at 2019-05-27 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -16,24 +16,28 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'leakgan' +run_model = "leakgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) @@ -41,19 +45,18 @@ MLE_train_epoch = 8 ADV_train_epoch = 200 inter_epoch = 10 -tips = 'LeakGAN experiments' +tips = "LeakGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' -samples_num = 10000 +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.0015 @@ -87,59 +90,94 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--inter_epoch', inter_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--inter_epoch", + inter_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + 
"--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - '--goal_size', goal_size, - '--step_size', step_size, - + "--adv_g_step", + ADV_g_step, + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, + "--goal_size", + goal_size, + "--step_size", + step_size, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_maligan.py b/run/run_maligan.py index 24424213..20b7f49f 100644 --- a/run/run_maligan.py +++ b/run/run_maligan.py @@ -4,7 +4,7 @@ # @FileName : run_maligan.py # @Time : Created at 2019/11/29 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -16,43 +16,46 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'maligan' +run_model = "maligan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 80 ADV_train_epoch = 200 -tips = 'MaliGAN experiments' +tips = "MaliGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' -samples_num = 10000 +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -84,56 +87,88 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step[job_id], - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step[job_id], + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + 
use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_relgan.py b/run/run_relgan.py index 72cff2ba..20f3e649 100644 --- a/run/run_relgan.py +++ b/run/run_relgan.py @@ -4,7 +4,7 @@ # @FileName : run_relgan.py # @Time : Created at 2019-05-28 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -16,45 +16,49 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'relgan' +run_model = "relgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 150 ADV_train_epoch = 3000 -tips = 'RelGAN experiments' +tips = "RelGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] -loss_type = 'rsgan' +dataset = ["oracle", "image_coco", "emnlp_news"] +loss_type = "rsgan" vocab_size = [5000, 0, 0] -temp_adpt = 'exp' +temp_adpt = "exp" temperature = [1, 100, 100] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'truncated_normal' -dis_init = 'uniform' +model_type = "vanilla" +gen_init = "truncated_normal" +dis_init = "uniform" samples_num = 10000 batch_size = 64 max_seq_len = 20 @@ -88,60 +92,98 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--loss_type', loss_type, - '--vocab_size', vocab_size[job_id], - '--temp_adpt', temp_adpt, - '--temperature', temperature[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--loss_type", + loss_type, + "--vocab_size", + vocab_size[job_id], + "--temp_adpt", + temp_adpt, + "--temperature", + temperature[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--gen_adv_lr', 
gen_adv_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--samples_num", + samples_num, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--gen_adv_lr", + gen_adv_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - '--mem_slots', mem_slots, - '--num_heads', num_heads, - '--head_size', head_size, - + "--adv_g_step", + ADV_g_step, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, + "--mem_slots", + mem_slots, + "--num_heads", + num_heads, + "--head_size", + head_size, # Discriminator - '--adv_d_step', ADV_d_step, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - '--num_rep', num_rep, - + "--adv_d_step", + ADV_d_step, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, + "--num_rep", + num_rep, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_sentigan.py b/run/run_sentigan.py index fb5c0e38..8e5f6252 100644 --- a/run/run_sentigan.py +++ b/run/run_sentigan.py @@ -17,24 +17,28 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'sentigan' +run_model = "sentigan" k_label = 2 CUDA = int(True) oracle_pretrain = int(True) @@ -43,19 +47,18 @@ MLE_train_epoch = 120 clas_pre_epoch = 5 ADV_train_epoch = 100 -tips = 'SentiGAN experiments' +tips = "SentiGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'mr15', 'amazon_app_book'] +dataset = ["oracle", "mr15", "amazon_app_book"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' -samples_num = 10000 +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -88,59 +91,94 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--k_label', k_label, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--k_label", + k_label, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--clas_pre_epoch', clas_pre_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--clas_pre_epoch", + clas_pre_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - 
'--use_clas_acc', use_clas_acc, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_clas_acc", + use_clas_acc, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_seqgan.py b/run/run_seqgan.py index 7d98b9cd..701d9364 100644 --- a/run/run_seqgan.py +++ b/run/run_seqgan.py @@ -4,7 +4,7 @@ # @FileName : run_seqgan.py # @Time : Created at 2019-05-27 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -16,43 +16,46 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'seqgan' +run_model = "seqgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 120 ADV_train_epoch = 200 -tips = 'SeqGAN experiments' +tips = "SeqGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' -samples_num = 10000 +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -84,56 +87,88 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + 
model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/sweep.yml b/sweep.yml new file mode 100644 index 00000000..4738dd30 --- /dev/null +++ b/sweep.yml @@ -0,0 +1,29 @@ +program: train.py +method: bayes +metric: + goal: minimize + name: Overal_score +parameters: + discriminator_complexity: + max: 256 # must be multiplied by 4 + min: 64 + distribution: int_uniform + generator_complexity: + max: 256 # must be multiplied by 4 + min: 64 + distribution: int_uniform + w2v_embedding_size: + values: + - 512 + real_fake_coeff: + max: 1.5 + min: 1.0 + distribution: uniform + labels_coeff: + max: 2.5 + min: 2.0 + distribution: uniform + diversity_coeff: + max: 2.5 + min: 2.0 + distribution: uniform diff --git a/utils/cat_data_loader.py b/utils/cat_data_loader.py index 63bd803c..2881fcea 100644 --- a/utils/cat_data_loader.py +++ b/utils/cat_data_loader.py @@ -4,7 +4,7 @@ # @FileName : cat_data_loader.py # @Time : Created at 2019-05-31 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import random @@ -37,18 +37,24 @@ def __init__(self, samples_list, shuffle=None): dataset=GANDataset(self.__read_data__(samples_list)), batch_size=self.batch_size, shuffle=self.shuffle, - drop_last=True) + drop_last=True, + ) - self.input = self._all_data_('input') - self.target = self._all_data_('target') - self.label = self._all_data_('label') # from 0 to k-1, different from Discriminator label + self.input = self._all_data_("input") + self.target = self._all_data_("target") + self.label = self._all_data_( + "label" + ) # from 0 to k-1, different from Discriminator label def __read_data__(self, samples_list): """ input: same as target, but start with start_letter. 
""" inp, target, label = self.prepare(samples_list) - all_data = [{'input': i, 'target': t, 'label': l} for (i, t, l) in zip(inp, target, label)] + all_data = [ + {"input": i, "target": t, "label": l} + for (i, t, l) in zip(inp, target, label) + ] return all_data def random_batch(self): @@ -57,7 +63,9 @@ def random_batch(self): return list(self.loader)[idx] def _all_data_(self, col): - return torch.cat([data[col].unsqueeze(0) for data in self.loader.dataset.data], 0) + return torch.cat( + [data[col].unsqueeze(0) for data in self.loader.dataset.data], 0 + ) def prepare(self, samples_list, gpu=False): """Add start_letter to samples as inp, target same as samples""" @@ -65,12 +73,12 @@ def prepare(self, samples_list, gpu=False): target = all_samples inp = torch.zeros(all_samples.size()).long() inp[:, 0] = self.start_letter - inp[:, 1:] = target[:, :self.max_seq_len - 1] + inp[:, 1:] = target[:, : self.max_seq_len - 1] label = torch.zeros(all_samples.size(0)).long() for idx in range(len(samples_list)): start = sum([samples_list[i].size(0) for i in range(idx)]) - label[start: start + samples_list[idx].size(0)] = idx + label[start : start + samples_list[idx].size(0)] = idx # shuffle perm = torch.randperm(inp.size(0)) @@ -105,14 +113,15 @@ def __init__(self, samples_list, given_target=None, shuffle=None): dataset=GANDataset(self.__read_data__(samples_list, given_target)), batch_size=self.batch_size, shuffle=self.shuffle, - drop_last=True) + drop_last=True, + ) - self.input = self._all_data_('input') - self.target = self._all_data_('target') + self.input = self._all_data_("input") + self.target = self._all_data_("target") def __read_data__(self, samples_list, given_target=None): inp, target = self.prepare(samples_list, given_target) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] + all_data = [{"input": i, "target": t} for (i, t) in zip(inp, target)] return all_data def random_batch(self): @@ -121,7 +130,9 @@ def random_batch(self): # return next(iter(self.loader)) def _all_data_(self, col): - return torch.cat([data[col].unsqueeze(0) for data in self.loader.dataset.data], 0) + return torch.cat( + [data[col].unsqueeze(0) for data in self.loader.dataset.data], 0 + ) @staticmethod def prepare(samples_list, given_target=None, detach=True, gpu=False): @@ -135,7 +146,10 @@ def prepare(samples_list, given_target=None, detach=True, gpu=False): - inp: sentences - target: label index, 0-label_0, 1-label_1, ..., k-label_k """ - if len(samples_list) == 1 and given_target is not None: + if type(samples_list[0][0][0]) == str: # directly generated text + inp = torch.zeros(1) + target = torch.zeros(1) + elif len(samples_list) == 1 and given_target is not None: inp = samples_list[0] if detach: inp = inp.detach() @@ -151,7 +165,7 @@ def prepare(samples_list, given_target=None, detach=True, gpu=False): inp = inp.long() for idx in range(1, len(samples_list)): start = sum([samples_list[i].size(0) for i in range(idx)]) - target[start: start + samples_list[idx].size(0)] = idx + target[start : start + samples_list[idx].size(0)] = idx # shuffle perm = torch.randperm(inp.size(0)) diff --git a/utils/create_embeddings.py b/utils/create_embeddings.py new file mode 100644 index 00000000..2ebe9a39 --- /dev/null +++ b/utils/create_embeddings.py @@ -0,0 +1,37 @@ +from gensim.models import Word2Vec +from tqdm import tqdm +from pathlib import Path + +import config as cfg +from utils.text_process import text_file_iterator + + +class MultipleFilesEmbeddingIterator: + def __init__(self, files): + self.files = 
files + + def __iter__(self): + for file in tqdm(self.files, desc="iterating files"): + for tokens in text_file_iterator(file): + yield [cfg.padding_token] * 5 + tokens + + +class EmbeddingsTrainer: + def __init__(self, sources, save_filename): + self.sources = sources + self.save_filename = save_filename + + def make_embeddings(self): + w2v = Word2Vec( + sentences=MultipleFilesEmbeddingIterator(self.sources), + size=cfg.w2v_embedding_size, + window=cfg.w2v_window, + min_count=cfg.w2v_min_count, + workers=cfg.w2v_workers, + ) + Path(self.save_filename).parents[0].mkdir(parents=True, exist_ok=True) + w2v.save(self.save_filename) + + +def load_embedding(path): + return Word2Vec.load(path) diff --git a/utils/data_loader.py b/utils/data_loader.py index 3d7ed791..d67fd926 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -4,13 +4,24 @@ # @FileName : data_loader.py # @Time : Created at 2019-05-31 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import random + +import numpy as np +import torch from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm, trange -from utils.text_process import * +import config as cfg +from utils.text_process import ( + tokens_to_tensor, + get_tokenlized, + load_dict, + load_test_dict, + vectorize_sentence, +) class GANDataset(Dataset): @@ -26,37 +37,44 @@ def __len__(self): class GenDataIter: def __init__(self, samples, if_test_data=False, shuffle=None): - self.batch_size = cfg.batch_size - self.max_seq_len = cfg.max_seq_len - self.start_letter = cfg.start_letter + self.samples = samples + if type(self.samples) == str: # we received filename + self.samples = get_tokenlized(self.samples) + self.shuffle = cfg.data_shuffle if not shuffle else shuffle + if cfg.if_real_data: self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset) if if_test_data: # used for the classifier self.word2idx_dict, self.idx2word_dict = load_test_dict(cfg.dataset) self.loader = DataLoader( - dataset=GANDataset(self.__read_data__(samples)), - batch_size=self.batch_size, + dataset=GANDataset(self.__read_data__(self.samples)), + batch_size=cfg.batch_size, shuffle=self.shuffle, - drop_last=True) + drop_last=True, + ) - self.input = self._all_data_('input') - self.target = self._all_data_('target') + self.input = self._all_data_("input") + self.target = self._all_data_("target") def __read_data__(self, samples): """ input: same as target, but start with start_letter. 
""" - # global all_data - if isinstance(samples, torch.Tensor): # Tensor - inp, target = self.prepare(samples) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] - elif isinstance(samples, str): # filename - inp, target = self.load_data(samples) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] - else: - all_data = None + if isinstance(samples[0], str) or isinstance( + samples[0][0], str + ): # list of strings + # we directly generated string, skip NLL + return [ + {"input": i, "target": t} + for i, t in zip(torch.zeros(2), torch.zeros(2)) + ] + if isinstance(samples[0], list): + # need to transform to indexes + samples = tokens_to_tensor(samples, self.word2idx_dict) + inp, target = self.prepare_for_NLL(samples) + all_data = [{"input": i, "target": t} for (i, t) in zip(inp, target)] return all_data def random_batch(self): @@ -65,46 +83,45 @@ def random_batch(self): return list(self.loader)[idx] def _all_data_(self, col): - return torch.cat([data[col].unsqueeze(0) for data in self.loader.dataset.data], 0) + return torch.cat( + [data[col].unsqueeze(0) for data in self.loader.dataset.data], 0 + ) + + @property + def tokens(self): + """Returns samples in form of list of tensors, if input tensor, + or list of tokens in case if input string.""" + if type(self.samples[0]) == str: # we have list of strings + return [smpl.split() for smpl in self.samples] + return list(self.samples) @staticmethod - def prepare(samples, gpu=False): + def prepare_for_NLL(samples, gpu=False): """Add start_letter to samples as inp, target same as samples""" inp = torch.zeros(samples.size()).long() target = samples inp[:, 0] = cfg.start_letter - inp[:, 1:] = target[:, :cfg.max_seq_len - 1] + inp[:, 1:] = target[:, : cfg.max_seq_len - 1] if gpu: return inp.cuda(), target.cuda() return inp, target - def load_data(self, filename): - """Load real data from local file""" - self.tokens = get_tokenlized(filename) - samples_index = tokens_to_tensor(self.tokens, self.word2idx_dict) - return self.prepare(samples_index) - class DisDataIter: def __init__(self, pos_samples, neg_samples, shuffle=None): - self.batch_size = cfg.batch_size - self.max_seq_len = cfg.max_seq_len - self.start_letter = cfg.start_letter self.shuffle = cfg.data_shuffle if not shuffle else shuffle self.loader = DataLoader( dataset=GANDataset(self.__read_data__(pos_samples, neg_samples)), - batch_size=self.batch_size, + batch_size=cfg.batch_size, shuffle=self.shuffle, - drop_last=True) + drop_last=True, + ) def __read_data__(self, pos_samples, neg_samples): - """ - input: same as target, but start with start_letter. 
- """ inp, target = self.prepare(pos_samples, neg_samples) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] + all_data = [{"input": i, "target": t} for (i, t) in zip(inp, target)] return all_data def random_batch(self): @@ -113,9 +130,11 @@ def random_batch(self): def prepare(self, pos_samples, neg_samples, gpu=False): """Build inp and target""" - inp = torch.cat((pos_samples, neg_samples), dim=0).long().detach() # !!!need .detach() + inp = ( + torch.cat((pos_samples, neg_samples), dim=0).long().detach() + ) # !!!need .detach() target = torch.ones(inp.size(0)).long() - target[pos_samples.size(0):] = 0 + target[pos_samples.size(0) :] = 0 # shuffle perm = torch.randperm(inp.size(0)) @@ -125,3 +144,76 @@ def prepare(self, pos_samples, neg_samples, gpu=False): if gpu: return inp.cuda(), target.cuda() return inp, target + + +class DataSupplier: + def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): + labels, tokenized = zip( + *[ + (label, tokens) + for label, tokens in zip(labels, tokenized) + if all(token in w2v.wv for token in tokens) + ] + ) + self.labels = torch.tensor(labels, dtype=int) + self.tokenized = np.array(tokenized) + self.batches_per_epoch = batches_per_epoch + self.batch_size = batch_size + self.w2v = w2v + self.texts = set(" ".join(tokens[-cfg.target_len :]) for tokens in tokenized) + print( + "dataset random texts examples\n", + "\n".join([txt for txt in self.texts][:5]), + ) + + def vectorize_batch(self, tokenized): + vectors = [ + vectorize_sentence( + tokens, + self.w2v, + target_len=cfg.target_len, + padding_token=cfg.padding_token, + ) + for tokens in tokenized + ] + vectors = np.stack(vectors, axis=0) + vectors = torch.tensor(vectors, dtype=torch.float32) + return vectors + + def __iter__(self): + permutation = torch.randperm(len(self)) + self.tokenized = self.tokenized[permutation] + self.labels = self.labels[permutation] + + for _ in range(self.batches_per_epoch): + index = 0 + index += self.batch_size + if index > len(self): + # concatenating beginning of self.vectors + yield ( + torch.cat( + ( + self.labels[index - self.batch_size : index], + self.labels[: index - len(self)], + ) + ), + torch.cat( + ( + self.vectorize_batch( + self.tokenized[index - self.batch_size : index] + ), + self.vectorize_batch(self.tokenized[: index - len(self)]), + ) + ), + ) + index = index % len(self) + else: + yield self.labels[ + index - self.batch_size : index + ], self.vectorize_batch(self.tokenized[index - self.batch_size : index]) + + def __len__(self): + return self.batches_per_epoch + + def is_message_in_dataset(self, text): + return text in self.texts diff --git a/utils/data_utils.py b/utils/data_utils.py index 67b4bda3..164199fb 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -4,7 +4,7 @@ # @FileName : data_utils.py # @Time : Created at 2019-03-16 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. from time import strftime, localtime @@ -18,29 +18,45 @@ def create_multi_oracle(number): for i in range(number): - print('Creating Oracle %d...' % i) - oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, - cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + print("Creating Oracle %d..." 
% i) + oracle = Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) if cfg.CUDA: oracle = oracle.cuda() large_samples = oracle.sample(cfg.samples_num, 4 * cfg.batch_size) small_samples = oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size) torch.save(oracle.state_dict(), cfg.multi_oracle_state_dict_path.format(i)) - torch.save(large_samples, cfg.multi_oracle_samples_path.format(i, cfg.samples_num)) - torch.save(small_samples, cfg.multi_oracle_samples_path.format(i, cfg.samples_num // 2)) + torch.save( + large_samples, cfg.multi_oracle_samples_path.format(i, cfg.samples_num) + ) + torch.save( + small_samples, cfg.multi_oracle_samples_path.format(i, cfg.samples_num // 2) + ) oracle_data = GenDataIter(large_samples) mle_criterion = nn.NLLLoss() groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) - print('Oracle %d Groud Truth: %.4f' % (i, groud_truth)) + print("Oracle %d Groud Truth: %.4f" % (i, groud_truth)) -def create_specific_oracle(from_a, to_b, num=1, save_path='../pretrain/'): +def create_specific_oracle(from_a, to_b, num=1, save_path="../pretrain/"): for i in range(num): while True: - oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, - cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + oracle = Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) if cfg.CUDA: oracle = oracle.cuda() @@ -52,24 +68,36 @@ def create_specific_oracle(from_a, to_b, num=1, save_path='../pretrain/'): groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) if from_a <= groud_truth <= to_b: - dir_path = save_path + 'oracle_data_gt{:.2f}_{}'.format(groud_truth, - strftime("%m%d_%H%M%S", localtime())) + dir_path = save_path + "oracle_data_gt{:.2f}_{}".format( + groud_truth, strftime("%m%d_%H%M%S", localtime()) + ) if not os.path.exists(dir_path): os.mkdir(dir_path) - print('save ground truth: ', groud_truth) + print("save ground truth: ", groud_truth) # prefix = 'oracle{}_lstm_gt{:.2f}_{}'.format(i, groud_truth, strftime("%m%d", localtime())) - prefix = dir_path + '/oracle_lstm' - torch.save(oracle.state_dict(), '{}.pt'.format(prefix)) - torch.save(big_samples, '{}_samples_{}.pt'.format(prefix, cfg.samples_num)) - torch.save(small_samples, '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2)) + prefix = dir_path + "/oracle_lstm" + torch.save(oracle.state_dict(), "{}.pt".format(prefix)) + torch.save( + big_samples, "{}_samples_{}.pt".format(prefix, cfg.samples_num) + ) + torch.save( + small_samples, + "{}_samples_{}.pt".format(prefix, cfg.samples_num // 2), + ) break -def create_many_oracle(from_a, to_b, num=1, save_path='../pretrain/'): +def create_many_oracle(from_a, to_b, num=1, save_path="../pretrain/"): for i in range(num): while True: - oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, - cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + oracle = Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) if cfg.CUDA: oracle = oracle.cuda() @@ -81,33 +109,39 @@ def create_many_oracle(from_a, to_b, num=1, save_path='../pretrain/'): groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) if from_a <= groud_truth <= to_b: - print('save ground truth: ', groud_truth) - prefix = 'oracle_lstm' - torch.save(oracle.state_dict(), save_path + '{}.pt'.format(prefix)) - torch.save(big_samples, save_path + 
'{}_samples_{}.pt'.format(prefix, cfg.samples_num)) - torch.save(small_samples, save_path + '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2)) + print("save ground truth: ", groud_truth) + prefix = "oracle_lstm" + torch.save(oracle.state_dict(), save_path + "{}.pt".format(prefix)) + torch.save( + big_samples, + save_path + "{}_samples_{}.pt".format(prefix, cfg.samples_num), + ) + torch.save( + small_samples, + save_path + "{}_samples_{}.pt".format(prefix, cfg.samples_num // 2), + ) break def _save(data, filename): - with open(filename, 'w') as fout: + with open(filename, "w") as fout: for d in data: - fout.write(d['reviewText'] + '\n') - fout.write(str(d['overall']) + '\n') + fout.write(d["reviewText"] + "\n") + fout.write(str(d["overall"]) + "\n") def _count(filename): - with open(filename, 'r') as fin: - data = fin.read().strip().split('\n') + with open(filename, "r") as fin: + data = fin.read().strip().split("\n") return len(data) / 2 def clean_amazon_long_sentence(): - data_root = '/home/sysu2018/Documents/william/amazon_dataset/' + data_root = "/home/sysu2018/Documents/william/amazon_dataset/" all_files = os.listdir(data_root) - print('|\ttype\t|\torigin\t|\tclean_40\t|\tclean_20\t|\tfinal_40\t|\tfinal_20\t|') - print('|----------|----------|----------|----------|----------|----------|') + print("|\ttype\t|\torigin\t|\tclean_40\t|\tclean_20\t|\tfinal_40\t|\tfinal_20\t|") + print("|----------|----------|----------|----------|----------|----------|") for file in all_files: filename = data_root + file if os.path.isdir(filename): @@ -117,37 +151,44 @@ def clean_amazon_long_sentence(): clean_save_20 = [] final_save_40 = [] final_save_20 = [] - with open(filename, 'r') as fin: - raw_data = fin.read().strip().split('\n') + with open(filename, "r") as fin: + raw_data = fin.read().strip().split("\n") for line in raw_data: - review = eval(line)['reviewText'] + review = eval(line)["reviewText"] if len(review.split()) <= 40: clean_save_40.append(eval(line)) - if len(review.split('.')) <= 2: # one sentence + if len(review.split(".")) <= 2: # one sentence final_save_40.append(eval(line)) if len(review.split()) <= 20: clean_save_20.append(eval(line)) - if len(review.split('.')) <= 2: # one sentence + if len(review.split(".")) <= 2: # one sentence final_save_20.append(eval(line)) - save_filename = data_root + 'clean_40/' + file.lower().split('_5')[0] + '.txt' + save_filename = data_root + "clean_40/" + file.lower().split("_5")[0] + ".txt" _save(clean_save_40, save_filename) # a = _count(save_filename) - save_filename = data_root + 'clean_20/' + file.lower().split('_5')[0] + '.txt' + save_filename = data_root + "clean_20/" + file.lower().split("_5")[0] + ".txt" _save(clean_save_20, save_filename) # b = _count(save_filename) - save_filename = data_root + 'final_40/' + file.lower().split('_5')[0] + '.txt' + save_filename = data_root + "final_40/" + file.lower().split("_5")[0] + ".txt" _save(final_save_40, save_filename) # c = _count(save_filename) - save_filename = data_root + 'final_20/' + file.lower().split('_5')[0] + '.txt' + save_filename = data_root + "final_20/" + file.lower().split("_5")[0] + ".txt" _save(final_save_20, save_filename) # d = _count(save_filename) - print('|\t%s\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|' % ( - file.lower().split('_5')[0], len(raw_data), - len(clean_save_40), len(clean_save_20), - len(final_save_40), len(final_save_20))) + print( + "|\t%s\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|" + % ( + file.lower().split("_5")[0], + len(raw_data), + len(clean_save_40), + 
len(clean_save_20), + len(final_save_40), + len(final_save_20), + ) + ) # print('|\t%s\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|' % ( # file.lower().split('_5')[0], len(raw_data), a, b, c, d)) @@ -163,5 +204,5 @@ def mean_list(x, y): return res -if __name__ == '__main__': +if __name__ == "__main__": pass diff --git a/utils/gan_loss.py b/utils/gan_loss.py index 32411660..01603bbb 100644 --- a/utils/gan_loss.py +++ b/utils/gan_loss.py @@ -4,13 +4,14 @@ # @FileName : gan_loss.py # @Time : Created at 2019-07-11 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn as nn import config as cfg +from utils.nn_helpers import DiversityLoss class GANLoss(nn.Module): @@ -20,8 +21,16 @@ class GANLoss(nn.Module): that has the same size as the input. """ - def __init__(self, loss_mode, which_net, which_D, target_real_label=1.0, target_fake_label=0.0, CUDA=False): - """ Initialize the GAN's Discriminator Loss class. + def __init__( + self, + loss_mode, + which_net, + which_D, + target_real_label=1.0, + target_fake_label=0.0, + CUDA=False, + ): + """Initialize the GAN's Discriminator Loss class. Parameters: loss_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. @@ -32,21 +41,25 @@ def __init__(self, loss_mode, which_net, which_D, target_real_label=1.0, target_ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. """ super(GANLoss, self).__init__() - self.register_buffer('real_label', torch.tensor(target_real_label)) - self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.register_buffer("real_label", torch.tensor(target_real_label)) + self.register_buffer("fake_label", torch.tensor(target_fake_label)) self.loss_mode = loss_mode self.which_net = which_net self.which_D = which_D self.gpu = CUDA - if loss_mode == 'lsgan': + if loss_mode == "lsgan": self.loss = nn.MSELoss() - elif loss_mode in ['vanilla', 'ragan', 'rsgan']: + elif loss_mode in ["vanilla", "ragan", "rsgan"]: self.loss = nn.BCEWithLogitsLoss() - elif loss_mode in ['wgan', 'hinge']: + elif loss_mode in ["wgan", "hinge"]: self.loss = None + elif loss_mode == "fixem": + self.real_fake_criterion = nn.BCEWithLogitsLoss() + self.label_criterion = nn.CrossEntropyLoss(label_smoothing=0.1) + self.diversity_criterion = DiversityLoss() else: - raise NotImplementedError('gan mode %s not implemented' % loss_mode) + raise NotImplementedError("gan mode %s not implemented" % loss_mode) def get_target_tensor(self, prediction, target_is_real): """Create label tensors with the same size as the input. 
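# Illustrative sketch (not part of the repository): exercising the new "fixem"
# loss mode of GANLoss with random tensors standing in for discriminator
# outputs. It assumes utils.nn_helpers.DiversityLoss (imported above) is
# available from this changeset, that the coefficient values set on cfg are
# purely illustrative, and that the first half of real_fake_predicts scores
# real samples and the second half fakes, matching the chunk(2) split in
# D_loss_fixem further down in this file's diff.
import torch
import config as cfg
from utils.gan_loss import GANLoss

cfg.run_model = "cat_fixemgan"                    # label loss is added only for the cat_ variant
cfg.real_fake_coeff, cfg.labels_coeff = 1.0, 2.0  # illustrative weights

criterion = GANLoss("fixem", which_net="D", which_D="S")

batch, k_label = 64, 2
real_fake_predicts = torch.randn(2 * batch)            # [real logits ; fake logits]
label_predicts = torch.randn(2 * batch, k_label)       # class logits for both halves
target_labels = torch.randint(0, k_label, (2 * batch,))

d_loss = criterion.D_loss_fixem(real_fake_predicts, label_predicts, target_labels)
print(d_loss.item())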
@@ -65,84 +78,127 @@ def get_target_tensor(self, prediction, target_is_real): return target_tensor.expand_as(prediction) def G_loss(self, Dreal, Dfake): - if self.loss_mode != 'rsgan' and cfg.d_out_mean: + if self.loss_mode != "rsgan" and cfg.d_out_mean: Dfake = torch.mean(Dfake.view(cfg.batch_size, -1), dim=-1) Dreal = torch.mean(Dreal.view(cfg.batch_size, -1), dim=-1) real_tensor = self.get_target_tensor(Dreal, True) fake_tensor = self.get_target_tensor(Dreal, False) - if self.which_D == 'S': + if self.which_D == "S": prediction_fake = Dfake - prediction_real = real_tensor if self.loss_mode in ['vanilla'] else fake_tensor - elif self.which_D == 'Ra': + prediction_real = ( + real_tensor if self.loss_mode in ["vanilla"] else fake_tensor + ) + elif self.which_D == "Ra": prediction_fake = Dfake - torch.mean(Dreal) prediction_real = Dreal - torch.mean(Dfake) else: - raise NotImplementedError('which_D name [%s] is not recognized' % self.which_D) + raise NotImplementedError( + "which_D name [%s] is not recognized" % self.which_D + ) - if self.loss_mode in ['lsgan', 'ragan']: + if self.loss_mode in ["lsgan", "ragan"]: loss_fake = self.loss(prediction_fake, real_tensor) loss_real = self.loss(prediction_real, fake_tensor) g_loss = loss_fake + loss_real - elif self.loss_mode == 'vanilla': + elif self.loss_mode == "vanilla": loss_fake = -self.loss(prediction_fake, fake_tensor) g_loss = loss_fake - elif self.loss_mode in ['wgan', 'hinge'] and self.which_D == 'S': + elif self.loss_mode in ["wgan", "hinge"] and self.which_D == "S": loss_fake = -prediction_fake.mean() loss_real = prediction_real.mean() g_loss = loss_fake + loss_real - elif self.loss_mode == 'hinge' and self.which_D == 'Ra': + elif self.loss_mode == "hinge" and self.which_D == "Ra": loss_fake = nn.ReLU()(1.0 - prediction_fake).mean() loss_real = nn.ReLU()(1.0 + prediction_real).mean() g_loss = loss_fake + loss_real - elif self.loss_mode == 'rsgan': + elif self.loss_mode == "rsgan": loss_fake = self.loss(Dfake - Dreal, real_tensor) g_loss = loss_fake else: - raise NotImplementedError('loss_mode name [%s] is not recognized' % self.loss_mode) + raise NotImplementedError( + "loss_mode name [%s] is not recognized" % self.loss_mode + ) return g_loss def D_loss(self, Dreal, Dfake): - if self.loss_mode != 'rsgan' and cfg.d_out_mean: + if self.loss_mode != "rsgan" and cfg.d_out_mean: Dfake = torch.mean(Dfake.view(cfg.batch_size, -1), dim=-1) Dreal = torch.mean(Dreal.view(cfg.batch_size, -1), dim=-1) real_tensor = self.get_target_tensor(Dreal, True) fake_tensor = self.get_target_tensor(Dreal, False) - if self.which_D == 'S': + if self.which_D == "S": prediction_fake = Dfake prediction_real = Dreal - elif self.which_D == 'Ra': + elif self.which_D == "Ra": prediction_fake = Dfake - torch.mean(Dreal) prediction_real = Dreal - torch.mean(Dfake) else: - raise NotImplementedError('which_D name [%s] is not recognized' % self.which_D) + raise NotImplementedError( + "which_D name [%s] is not recognized" % self.which_D + ) - if self.loss_mode in ['lsgan', 'ragan', 'vanilla']: + if self.loss_mode in ["lsgan", "ragan", "vanilla"]: loss_fake = self.loss(prediction_fake, fake_tensor) loss_real = self.loss(prediction_real, real_tensor) - elif self.loss_mode == 'wgan': + elif self.loss_mode == "wgan": loss_fake = prediction_fake.mean() loss_real = -prediction_real.mean() - elif self.loss_mode == 'hinge': + elif self.loss_mode == "hinge": loss_fake = nn.ReLU()(1.0 + prediction_fake).mean() loss_real = nn.ReLU()(1.0 - prediction_real).mean() - elif self.loss_mode 
== 'rsgan': - loss_fake = 0. + elif self.loss_mode == "rsgan": + loss_fake = 0.0 loss_real = self.loss(Dreal - Dfake, real_tensor) else: - raise NotImplementedError('loss_mode name [%s] is not recognized' % self.loss_mode) + raise NotImplementedError( + "loss_mode name [%s] is not recognized" % self.loss_mode + ) return loss_fake + loss_real + def G_loss_fixem(self, real_fake_predicts, label_predicts, target_labels, fakes): + target_fake = self.get_target_tensor(real_fake_predicts, target_is_real=True) + real_fake_loss = cfg.real_fake_coeff * self.real_fake_criterion( + real_fake_predicts, target_fake + ) + labels_loss = cfg.labels_coeff * self.label_criterion( + label_predicts, target_labels + ) + diversity_loss = cfg.diversity_coeff * self.diversity_criterion(fakes) + loss = real_fake_loss + diversity_loss + loss = loss + labels_loss if cfg.run_model == "cat_fixemgan" else loss + return loss + + def D_loss_fixem(self, real_fake_predicts, label_predicts, target_labels): + target_real = self.get_target_tensor( + real_fake_predicts.chunk(2)[0], target_is_real=True + ) + target_fake = self.get_target_tensor( + real_fake_predicts.chunk(2)[1], target_is_real=False + ) + target_real_fake = torch.cat((target_real, target_fake)) + real_fake_loss = cfg.real_fake_coeff * self.real_fake_criterion( + real_fake_predicts, target_real_fake + ) + labels_loss = cfg.labels_coeff * self.label_criterion( + label_predicts, target_labels + ) + loss = real_fake_loss + loss = loss + labels_loss if cfg.run_model == "cat_fixemgan" else loss + return loss + def __call__(self, Dreal, Dfake): """Calculate loss given Discriminator's output and grount truth labels.""" - if self.which_net == 'G': + if self.which_net == "G": return self.G_loss(Dreal, Dfake) - elif self.which_net == 'D': + elif self.which_net == "D": return self.D_loss(Dreal, Dfake) else: - raise NotImplementedError('which_net name [%s] is not recognized' % self.which_net) + raise NotImplementedError( + "which_net name [%s] is not recognized" % self.which_net + ) diff --git a/utils/helpers.py b/utils/helpers.py index a7abfe10..42e4b5ef 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -5,6 +5,7 @@ import numpy as np import torch import torch.nn as nn +from tqdm import tqdm from metrics.nll import NLL from utils.data_loader import GenDataIter @@ -22,11 +23,11 @@ def __init__(self, signal_file): def update(self): signal_dict = self.read_signal() - self.pre_sig = signal_dict['pre_sig'] - self.adv_sig = signal_dict['adv_sig'] + self.pre_sig = signal_dict["pre_sig"] + self.adv_sig = signal_dict["adv_sig"] def read_signal(self): - with open(self.signal_file, 'r') as fin: + with open(self.signal_file, "r") as fin: return eval(fin.read()) @@ -36,22 +37,26 @@ def create_logger(name, silent=False, to_disk=False, log_file=None): log = logging.getLogger(name) log.setLevel(logging.DEBUG) log.propagate = False - formatter = logging.Formatter(fmt='%(message)s', datefmt='%Y/%m/%d %I:%M:%S') + formatter = logging.Formatter(fmt="%(message)s", datefmt="%Y/%m/%d %I:%M:%S") if not silent: ch = logging.StreamHandler(sys.stdout) ch.setLevel(logging.DEBUG) ch.setFormatter(formatter) log.addHandler(ch) if to_disk: - log_file = log_file if log_file is not None else strftime("log/log_%m%d_%H%M.txt", gmtime()) + log_file = ( + log_file + if log_file is not None + else strftime("log/log_%m%d_%H%M.txt", gmtime()) + ) if type(log_file) == list: for filename in log_file: - fh = logging.FileHandler(filename, mode='w') + fh = logging.FileHandler(filename, mode="w") 
fh.setLevel(logging.INFO) fh.setFormatter(formatter) log.addHandler(fh) if type(log_file) == str: - fh = logging.FileHandler(log_file, mode='w') + fh = logging.FileHandler(log_file, mode="w") fh.setLevel(logging.INFO) fh.setFormatter(formatter) log.addHandler(fh) @@ -63,9 +68,15 @@ def create_oracle(): import config as cfg from models.Oracle import Oracle - print('Creating Oracle...') - oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, - cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + print("Creating Oracle...") + oracle = Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) if cfg.CUDA: oracle = oracle.cuda() @@ -75,32 +86,45 @@ def create_oracle(): # large torch.save(big_samples, cfg.oracle_samples_path.format(cfg.samples_num)) # small - torch.save(oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size), - cfg.oracle_samples_path.format(cfg.samples_num // 2)) + torch.save( + oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size), + cfg.oracle_samples_path.format(cfg.samples_num // 2), + ) + + # giant for W2V + giant_samples = oracle.sample(cfg.w2v_samples_num, 4 * cfg.batch_size) + with open(cfg.oracle_samples_path.format(cfg.w2v_samples_num), "w") as f: + for sample in tqdm(giant_samples): + f.write(" ".join(str(int(idx)) for idx in sample)) + f.write("\n") oracle_data = GenDataIter(big_samples) mle_criterion = nn.NLLLoss() groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) - print('NLL_Oracle Groud Truth: %.4f' % groud_truth) + print("NLL_Oracle Groud Truth: %.4f" % groud_truth) def get_fixed_temperature(temper, i, N, adapt): """A function to set up different temperature control policies""" N = 5000 - if adapt == 'no': + if adapt == "no": temper_var_np = 1.0 # no increase, origin: temper - elif adapt == 'lin': + elif adapt == "lin": temper_var_np = 1 + i / (N - 1) * (temper - 1) # linear increase - elif adapt == 'exp': + elif adapt == "exp": temper_var_np = temper ** (i / N) # exponential increase - elif adapt == 'log': - temper_var_np = 1 + (temper - 1) / np.log(N) * np.log(i + 1) # logarithm increase - elif adapt == 'sigmoid': - temper_var_np = (temper - 1) * 1 / (1 + np.exp((N / 2 - i) * 20 / N)) + 1 # sigmoid increase - elif adapt == 'quad': - temper_var_np = (temper - 1) / (N - 1) ** 2 * i ** 2 + 1 - elif adapt == 'sqrt': + elif adapt == "log": + temper_var_np = 1 + (temper - 1) / np.log(N) * np.log( + i + 1 + ) # logarithm increase + elif adapt == "sigmoid": + temper_var_np = (temper - 1) * 1 / ( + 1 + np.exp((N / 2 - i) * 20 / N) + ) + 1 # sigmoid increase + elif adapt == "quad": + temper_var_np = (temper - 1) / (N - 1) ** 2 * i**2 + 1 + elif adapt == "sqrt": temper_var_np = (temper - 1) / np.sqrt(N - 1) * np.sqrt(i) + 1 else: raise Exception("Unknown adapt type!") @@ -108,43 +132,43 @@ def get_fixed_temperature(temper, i, N, adapt): return temper_var_np -def get_losses(d_out_real, d_out_fake, loss_type='JS'): +def get_losses(d_out_real, d_out_fake, loss_type="JS"): """Get different adversarial losses according to given loss_type""" bce_loss = nn.BCEWithLogitsLoss() - if loss_type == 'standard': # the non-satuating GAN loss + if loss_type == "standard": # the non-satuating GAN loss d_loss_real = bce_loss(d_out_real, torch.ones_like(d_out_real)) d_loss_fake = bce_loss(d_out_fake, torch.zeros_like(d_out_fake)) d_loss = d_loss_real + d_loss_fake g_loss = bce_loss(d_out_fake, torch.ones_like(d_out_fake)) - elif loss_type == 'JS': # the vanilla GAN loss + elif loss_type 
== "JS": # the vanilla GAN loss d_loss_real = bce_loss(d_out_real, torch.ones_like(d_out_real)) d_loss_fake = bce_loss(d_out_fake, torch.zeros_like(d_out_fake)) d_loss = d_loss_real + d_loss_fake g_loss = -d_loss_fake - elif loss_type == 'KL': # the GAN loss implicitly minimizing KL-divergence + elif loss_type == "KL": # the GAN loss implicitly minimizing KL-divergence d_loss_real = bce_loss(d_out_real, torch.ones_like(d_out_real)) d_loss_fake = bce_loss(d_out_fake, torch.zeros_like(d_out_fake)) d_loss = d_loss_real + d_loss_fake g_loss = torch.mean(-d_out_fake) - elif loss_type == 'hinge': # the hinge loss + elif loss_type == "hinge": # the hinge loss d_loss_real = torch.mean(nn.ReLU(1.0 - d_out_real)) d_loss_fake = torch.mean(nn.ReLU(1.0 + d_out_fake)) d_loss = d_loss_real + d_loss_fake g_loss = -torch.mean(d_out_fake) - elif loss_type == 'tv': # the total variation distance + elif loss_type == "tv": # the total variation distance d_loss = torch.mean(nn.Tanh(d_out_fake) - nn.Tanh(d_out_real)) g_loss = torch.mean(-nn.Tanh(d_out_fake)) - elif loss_type == 'rsgan': # relativistic standard GAN + elif loss_type == "rsgan": # relativistic standard GAN d_loss = bce_loss(d_out_real - d_out_fake, torch.ones_like(d_out_real)) g_loss = bce_loss(d_out_fake - d_out_real, torch.ones_like(d_out_fake)) diff --git a/utils/nn_helpers.py b/utils/nn_helpers.py new file mode 100644 index 00000000..6dde8240 --- /dev/null +++ b/utils/nn_helpers.py @@ -0,0 +1,283 @@ +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor +import torch.optim as optim +from torch.distributions.uniform import Uniform + +import math + +import matplotlib.pyplot as plt + + +def create_noise(sample_size, noise_size, k_label): + return ( + torch.randn(sample_size, noise_size), + torch.randint(0, k_label, (sample_size,)), + ) + + +def multiply_shape(shape): + if len(shape) == 1: + return shape[0] + return shape[0] * multiply_shape(shape[1:]) + + +def number_of_parameters(parameters): + nb_of_vars = 0 + for parameter in parameters: + nb_of_vars += multiply_shape(tuple(parameter.shape)) + return nb_of_vars + + +def get_optimizer(parameters, lr=0.0001, betas=(0.5, 0.999)): + return optim.Adam(parameters, lr=lr, betas=betas) + + +class PositionalEncoding(nn.Module): + def __init__(self, dim_pe: int, max_len: int, concatenate_pe=False): + super().__init__() + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, dim_pe, 2) * (-math.log(dim_pe // 4) / dim_pe) + ) + pe = torch.zeros(max_len, 1, dim_pe) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + pe = torch.transpose(pe, 0, 1) + pe = torch.transpose(pe, 1, 2) + # plt.imshow(pe[0], cmap="hot", interpolation="nearest") + # plt.show() + self.register_buffer("pe", pe) + self.concatenate_pe = concatenate_pe + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [seq_len, batch_size, embedding_dim] + """ + pe = self.pe.repeat(x.size(0), 1, 1) + # input (N,C,L) - N bathc size, C - channels, L - length + return torch.cat((x, pe), 1) if self.concatenate_pe else x + pe + + +class Reshape(nn.Module): + def __init__(self, *out_shape): + super(Reshape, self).__init__() + self.out_shape = out_shape + + def forward(self, input_batch): + """Turch batched flat vector to out_shape""" + return torch.reshape(input_batch, (-1, *self.out_shape)) + + +class Concatenate(nn.Module): + def __init__(self, dim): + super(Concatenate, self).__init__() + self.dim = dim + + def 
forward(self, input_batch): + return torch.cat(input_batch, dim=1) + + +class Dummy(nn.Module): + """For shadowing some layers.""" + + def __init__(self): + super(Dummy, self).__init__() + + def forward(self, x): + return x + + +class MyTransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead=1, n_layers=1): + super(MyTransformerEncoderLayer, self).__init__() + self.tranformer_layers = nn.Sequential( + *tuple( + nn.TransformerEncoderLayer( + d_model=d_model, nhead=nhead, batch_first=True + ) + for _ in range(n_layers) + ) + ) + + def forward(self, x): + x = self.tranformer_layers(torch.transpose(x, 1, 2)) + return torch.transpose(x, 1, 2) + + +class MyConvTransposeLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + stride=1, + output_padding=0, + alpha=0.2, + include_batch_norm=True, + padding=1, + kernel_size=3, + ): + super(MyConvTransposeLayer, self).__init__() + self.conv_layer = nn.Sequential( + nn.ConvTranspose1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ), + nn.LeakyReLU(alpha), + nn.Conv1d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + ), + nn.BatchNorm1d(out_channels) if include_batch_norm else Dummy(), + nn.LeakyReLU(alpha), + ) + + def forward(self, x): + return self.conv_layer(x) + + +class MyConvLayerNorm(nn.Module): + def __init__( + self, + in_channels, + out_channels, + stride=1, + output_padding=0, + alpha=0.2, + include_batch_norm=True, + padding=1, + kernel_size=3, + ): + super(MyConvLayerNorm, self).__init__() + self.conv_layer = nn.Sequential( + nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + nn.BatchNorm1d(out_channels) if include_batch_norm else Dummy(), + nn.LeakyReLU(alpha), + ) + + def forward(self, x): + return self.conv_layer(x) + + +class MyLSTMLayerNorm(nn.Module): + def __init__( + self, + in_channels, + out_channels, + alpha=0.2, + ): + super(MyLSTMLayerNorm, self).__init__() + self.lstm = nn.LSTM( + batch_first=True, + bidirectional=True, + input_size=in_channels, + hidden_size=out_channels, + ) + self.layers = nn.Sequential( + nn.BatchNorm1d(2 * out_channels), + nn.LeakyReLU(alpha), + ) + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x, (hn, cn) = self.lstm(x) + x = torch.transpose(x, 1, 2) + x = self.layers(x) + return x + + +class Flatten(nn.Module): + def __init__(self): + super(Flatten, self).__init__() + self.start_dim = 1 + + def forward(self, input_tensor): + return torch.flatten(input_tensor, start_dim=1) + + +class MyConvLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding="same", + alpha=0.2, + drop_rate=0.2, + ): + super(MyConvLayer, self).__init__() + self.conv_layer = nn.Sequential( + nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + nn.LeakyReLU(alpha), + nn.Dropout(drop_rate), + ) + + def forward(self, x): + return self.conv_layer(x) + + +class MyLSTMLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + alpha=0.2, + drop_rate=0.2, + ): + super(MyLSTMLayer, self).__init__() + self.lstm = nn.LSTM( + batch_first=True, + bidirectional=True, + input_size=in_channels, + hidden_size=out_channels, + ) + self.layers = nn.Sequential( + nn.LeakyReLU(alpha), + nn.Dropout(drop_rate), + ) + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x, (hn, cn) = self.lstm(x) 
+ x = torch.transpose(x, 1, 2) + x = self.layers(x) + return x + + +def DiversityLoss(): + cs2 = torch.nn.CosineSimilarity(dim=2) + + def cos_sim_loss(generated): + batch_size = generated.shape[0] + generated = generated.repeat(batch_size, 1, 1, 1) + generatedTranspose = torch.transpose(generated, 0, 1) + loss = cs2(generated, generatedTranspose) + ind = np.diag_indices(loss.shape[0]) + loss[ind[0], ind[1], :] = 0 # set 0 to similarity of message to itself + loss = loss.mean(axis=2).max(axis=0).values.mean() + return loss + + return cos_sim_loss diff --git a/utils/rollout.py b/utils/rollout.py index 6e165b21..96c46e19 100644 --- a/utils/rollout.py +++ b/utils/rollout.py @@ -4,7 +4,7 @@ # @FileName : rollout.py # @Time : Created at 2019-03-15 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import copy @@ -18,8 +18,8 @@ def __init__(self, gen, gpu=True): self.old_model = copy.deepcopy(gen) self.max_seq_len = gen.max_seq_len self.vocab_size = gen.vocab_size - self.step_size = gen.step_size if gen.name == 'leakgan' else 0 - self.goal_out_size = gen.goal_out_size if gen.name == 'leakgan' else 0 + self.step_size = gen.step_size if gen.name == "leakgan" else 0 + self.goal_out_size = gen.goal_out_size if gen.name == "leakgan" else 0 self.gpu = gpu def rollout_mc_search(self, sentences, given_num): @@ -73,7 +73,7 @@ def rollout_mc_search_leakgan(self, sentences, dis, given_num): for i in range(given_num): # Get feature. dis_inp = torch.zeros(batch_size, seq_len).long() - dis_inp[:, :i + 1] = sentences[:, :i + 1] # cut sentences + dis_inp[:, : i + 1] = sentences[:, : i + 1] # cut sentences leak_inp = sentences[:, i] if self.gpu: dis_inp = dis_inp.cuda() @@ -82,13 +82,14 @@ def rollout_mc_search_leakgan(self, sentences, dis, given_num): # Get output of one token # cur_goal: batch_size * 1 * goal_out_size - out, cur_goal, work_hidden, mana_hidden = self.gen(i, leak_inp, work_hidden, mana_hidden, - feature, real_goal, train=True) + out, cur_goal, work_hidden, mana_hidden = self.gen( + i, leak_inp, work_hidden, mana_hidden, feature, real_goal, train=True + ) # Save goal and update last_goal goal_array[:, i, :] = cur_goal.squeeze(1) if i > 0 and i % self.step_size == 0: - real_goal = torch.sum(goal_array[:, i - 3:i + 1, :], dim=1) + real_goal = torch.sum(goal_array[:, i - 3 : i + 1, :], dim=1) if i / self.step_size == 1: real_goal += self.gen.goal_init[:batch_size, :] @@ -98,7 +99,9 @@ def rollout_mc_search_leakgan(self, sentences, dis, given_num): # MC search for i in range(given_num, self.max_seq_len): # Sample one token - out = torch.multinomial(torch.exp(out), 1).view(-1) # [num_samples] (sampling from each row) + out = torch.multinomial(torch.exp(out), 1).view( + -1 + ) # [num_samples] (sampling from each row) samples[:, i] = out.data # Get feature @@ -110,13 +113,14 @@ def rollout_mc_search_leakgan(self, sentences, dis, given_num): # Get output of one token # cur_goal: batch_size * 1 * goal_out_size - out, cur_goal, work_hidden, mana_hidden = self.gen(i, leak_inp, work_hidden, mana_hidden, - feature, real_goal, train=True) + out, cur_goal, work_hidden, mana_hidden = self.gen( + i, leak_inp, work_hidden, mana_hidden, feature, real_goal, train=True + ) # Save goal and update last_goal goal_array[:, i, :] = cur_goal.squeeze(1) if i > 0 and i % self.step_size == 0: - real_goal = torch.sum(goal_array[:, i - 3:i + 1, :], dim=1) + real_goal = torch.sum(goal_array[:, i - 3 : i + 1, :], dim=1) if i / self.step_size == 1: real_goal += 
self.gen.goal_init[:batch_size, :] @@ -150,7 +154,9 @@ def get_reward(self, sentences, rollout_num, dis, current_k=0): idx += 1 # rewards = torch.mean(rewards, dim=0) - rewards = torch.mean(rewards.view(batch_size, self.max_seq_len, rollout_num), dim=-1) + rewards = torch.mean( + rewards.view(batch_size, self.max_seq_len, rollout_num), dim=-1 + ) return rewards def get_reward_leakgan(self, sentences, rollout_num, dis, current_k): @@ -165,7 +171,9 @@ def get_reward_leakgan(self, sentences, rollout_num, dis, current_k): """ with torch.no_grad(): batch_size = sentences.size(0) - rewards = torch.zeros([rollout_num * (self.max_seq_len // self.step_size), batch_size]).float() + rewards = torch.zeros( + [rollout_num * (self.max_seq_len // self.step_size), batch_size] + ).float() if self.gpu: rewards = rewards.cuda() idx = 0 @@ -179,7 +187,9 @@ def get_reward_leakgan(self, sentences, rollout_num, dis, current_k): rewards[idx] = reward idx += 1 - rewards = rewards.view(batch_size, self.max_seq_len // self.step_size, rollout_num) + rewards = rewards.view( + batch_size, self.max_seq_len // self.step_size, rollout_num + ) rewards = torch.mean(rewards, dim=-1) return rewards diff --git a/utils/text_process.py b/utils/text_process.py index 7bdb9daf..889872fe 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -4,23 +4,30 @@ # @FileName : text_process.py # @Time : Created at 2019-05-14 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. +from typing import Dict, Optional -import nltk import numpy as np import os import torch +from tqdm import tqdm import config as cfg +def text_file_iterator(file): + with open(file) as raw: + for line in raw.readlines(): + yield line.strip("\n").split() + + def get_tokenlized(file): """tokenlize the file""" tokenlized = list() with open(file) as raw: for text in raw: - text = nltk.word_tokenize(text.lower()) + text = text.strip("\n").lower().split() tokenlized.append(text) return tokenlized @@ -65,7 +72,9 @@ def text_process(train_text_loc, test_text_loc=None): if test_text_loc is None: sequence_len = len(max(train_tokens, key=len)) else: - sequence_len = max(len(max(train_tokens, key=len)), len(max(test_tokens, key=len))) + sequence_len = max( + len(max(train_tokens, key=len)), len(max(test_tokens, key=len)) + ) return sequence_len, len(word2idx_dict) @@ -76,29 +85,31 @@ def init_dict(dataset): Initialize dictionaries of dataset, please note that '0': padding_idx, '1': start_letter. Finally save dictionary files locally. 
""" - tokens = get_tokenlized('dataset/{}.txt'.format(dataset)) + tokens = get_tokenlized("dataset/{}.txt".format(dataset)) word_set = get_word_list(tokens) word2idx_dict, idx2word_dict = get_dict(word_set) - with open('dataset/{}_wi_dict.txt'.format(dataset), 'w') as dictout: + with open("dataset/{}_wi_dict.txt".format(dataset), "w") as dictout: dictout.write(str(word2idx_dict)) - with open('dataset/{}_iw_dict.txt'.format(dataset), 'w') as dictout: + with open("dataset/{}_iw_dict.txt".format(dataset), "w") as dictout: dictout.write(str(idx2word_dict)) - print('total tokens: ', len(word2idx_dict)) + print("total tokens: ", len(word2idx_dict)) def load_dict(dataset): """Load dictionary from local files""" - iw_path = 'dataset/{}_iw_dict.txt'.format(dataset) - wi_path = 'dataset/{}_wi_dict.txt'.format(dataset) + iw_path = "dataset/{}_iw_dict.txt".format(dataset) + wi_path = "dataset/{}_wi_dict.txt".format(dataset) - if not os.path.exists(iw_path) or not os.path.exists(iw_path): # initialize dictionaries + if not os.path.exists(iw_path) or not os.path.exists( + iw_path + ): # initialize dictionaries init_dict(dataset) - with open(iw_path, 'r') as dictin: + with open(iw_path, "r") as dictin: idx2word_dict = eval(dictin.read().strip()) - with open(wi_path, 'r') as dictin: + with open(wi_path, "r") as dictin: word2idx_dict = eval(dictin.read().strip()) return word2idx_dict, idx2word_dict @@ -108,7 +119,7 @@ def load_test_dict(dataset): """Build test data dictionary, extend from train data. For the classifier.""" word2idx_dict, idx2word_dict = load_dict(dataset) # train dict # tokens = get_tokenlized('dataset/testdata/{}_clas_test.txt'.format(dataset)) - tokens = get_tokenlized('dataset/testdata/{}_test.txt'.format(dataset)) + tokens = get_tokenlized("dataset/testdata/{}_test.txt".format(dataset)) word_set = get_word_list(tokens) index = len(word2idx_dict) # current index @@ -121,7 +132,7 @@ def load_test_dict(dataset): return word2idx_dict, idx2word_dict -def tensor_to_tokens(tensor, dictionary): +def tensor_to_tokens(tensor, dictionary: Optional[Dict[torch.Tensor, str]] = None): """transform Tensor to word tokens""" tokens = [] for sent in tensor: @@ -129,7 +140,10 @@ def tensor_to_tokens(tensor, dictionary): for word in sent.tolist(): if word == cfg.padding_idx: break - sent_token.append(dictionary[str(word)]) + word = str(word) + if dictionary: + word = dictionary[word] + sent_token.append(word) tokens.append(sent_token) return tokens @@ -147,7 +161,7 @@ def tokens_to_tensor(tokens, dictionary): while i < cfg.max_seq_len - 1: sent_ten.append(cfg.padding_idx) i += 1 - tensor.append(sent_ten[:cfg.max_seq_len]) + tensor.append(sent_ten[: cfg.max_seq_len]) return torch.LongTensor(tensor) @@ -170,32 +184,32 @@ def padding_token(tokens): def write_tokens(filename, tokens): """Write word tokens to a local file (For Real data)""" - with open(filename, 'w') as fout: + with open(filename, "w") as fout: for sent in tokens: - fout.write(' '.join(sent)) - fout.write('\n') + fout.write(" ".join(sent)) + fout.write("\n") def write_tensor(filename, tensor): """Write Tensor to a local file (For Oracle data)""" - with open(filename, 'w') as fout: + with open(filename, "w") as fout: for sent in tensor: - fout.write(' '.join([str(i) for i in sent.tolist()])) - fout.write('\n') + fout.write(" ".join([str(i) for i in sent.tolist()])) + fout.write("\n") def process_cat_text(): import random - dataset = 'mr' + dataset = "mr" test_ratio = 0.3 seq_len = 15 - pos_file = 'dataset/{}/{}{}_cat1.txt'.format(dataset, 
dataset, seq_len) - neg_file = 'dataset/{}/{}{}_cat0.txt'.format(dataset, dataset, seq_len) - pos_sent = open(pos_file, 'r').readlines() - neg_sent = open(neg_file, 'r').readlines() + pos_file = "dataset/{}/{}{}_cat1.txt".format(dataset, dataset, seq_len) + neg_file = "dataset/{}/{}{}_cat0.txt".format(dataset, dataset, seq_len) + pos_sent = open(pos_file, "r").readlines() + neg_sent = open(neg_file, "r").readlines() pos_len = int(test_ratio * len(pos_sent)) neg_len = int(test_ratio * len(neg_sent)) @@ -208,10 +222,14 @@ def process_cat_text(): random.shuffle(all_sent_test) random.shuffle(all_sent_train) - f_pos_train = open('dataset/{}{}_cat1.txt'.format(dataset, seq_len), 'w') - f_neg_train = open('dataset/{}{}_cat0.txt'.format(dataset, seq_len), 'w') - f_pos_test = open('dataset/testdata/{}{}_cat1_test.txt'.format(dataset, seq_len), 'w') - f_neg_test = open('dataset/testdata/{}{}_cat0_test.txt'.format(dataset, seq_len), 'w') + f_pos_train = open("dataset/{}{}_cat1.txt".format(dataset, seq_len), "w") + f_neg_train = open("dataset/{}{}_cat0.txt".format(dataset, seq_len), "w") + f_pos_test = open( + "dataset/testdata/{}{}_cat1_test.txt".format(dataset, seq_len), "w" + ) + f_neg_test = open( + "dataset/testdata/{}{}_cat0_test.txt".format(dataset, seq_len), "w" + ) for p_s in pos_sent[:pos_len]: f_pos_test.write(p_s) @@ -222,10 +240,10 @@ def process_cat_text(): for n_s in neg_sent[neg_len:]: f_neg_train.write(n_s) - with open('dataset/testdata/{}{}_test.txt'.format(dataset, seq_len), 'w') as fout: + with open("dataset/testdata/{}{}_test.txt".format(dataset, seq_len), "w") as fout: for sent in all_sent_test: fout.write(sent) - with open('dataset/{}{}.txt'.format(dataset, seq_len), 'w') as fout: + with open("dataset/{}{}.txt".format(dataset, seq_len), "w") as fout: for sent in all_sent_train: fout.write(sent) @@ -236,20 +254,22 @@ def process_cat_text(): def combine_amazon_text(): - cat0_name = 'app' - cat1_name = 'book' - root_path = 'dataset/' - cat0_train = open(root_path + cat0_name + '.txt', 'r').readlines() - cat0_test = open(root_path + cat0_name + '_test.txt', 'r').readlines() - cat1_train = open(root_path + cat1_name + '.txt', 'r').readlines() - cat1_test = open(root_path + cat1_name + '_test.txt', 'r').readlines() - - with open(root_path + 'amazon_{}_{}.txt'.format(cat0_name, cat1_name), 'w') as fout: + cat0_name = "app" + cat1_name = "book" + root_path = "dataset/" + cat0_train = open(root_path + cat0_name + ".txt", "r").readlines() + cat0_test = open(root_path + cat0_name + "_test.txt", "r").readlines() + cat1_train = open(root_path + cat1_name + ".txt", "r").readlines() + cat1_test = open(root_path + cat1_name + "_test.txt", "r").readlines() + + with open(root_path + "amazon_{}_{}.txt".format(cat0_name, cat1_name), "w") as fout: for sent in cat0_train: fout.write(sent) for sent in cat1_train: fout.write(sent) - with open(root_path + 'testdata/amazon_{}_{}_test.txt'.format(cat0_name, cat1_name), 'w') as fout: + with open( + root_path + "testdata/amazon_{}_{}_test.txt".format(cat0_name, cat1_name), "w" + ) as fout: for sent in cat0_test: fout.write(sent) for sent in cat1_test: @@ -257,21 +277,23 @@ def combine_amazon_text(): def extend_clas_train_data(): - data_name = 'mr' - dataset = 'mr20' - neg_filter_file = 'dataset/{}/{}_cat0.txt'.format(data_name, dataset) # include train and test for generator - pos_filter_file = 'dataset/{}/{}_cat1.txt'.format(data_name, dataset) - neg_test_file = 'dataset/testdata/{}_cat0_test.txt'.format(dataset) - pos_test_file = 
'dataset/testdata/{}_cat1_test.txt'.format(dataset) - neg_all_file = 'dataset/{}/{}_cat0.txt'.format(data_name, data_name) - pos_all_file = 'dataset/{}/{}_cat1.txt'.format(data_name, data_name) - - neg_filter = open(neg_filter_file, 'r').readlines() - pos_filter = open(pos_filter_file, 'r').readlines() - neg_test = open(neg_test_file, 'r').readlines() - pos_test = open(pos_test_file, 'r').readlines() - neg_all = open(neg_all_file, 'r').readlines() - pos_all = open(pos_all_file, 'r').readlines() + data_name = "mr" + dataset = "mr20" + neg_filter_file = "dataset/{}/{}_cat0.txt".format( + data_name, dataset + ) # include train and test for generator + pos_filter_file = "dataset/{}/{}_cat1.txt".format(data_name, dataset) + neg_test_file = "dataset/testdata/{}_cat0_test.txt".format(dataset) + pos_test_file = "dataset/testdata/{}_cat1_test.txt".format(dataset) + neg_all_file = "dataset/{}/{}_cat0.txt".format(data_name, data_name) + pos_all_file = "dataset/{}/{}_cat1.txt".format(data_name, data_name) + + neg_filter = open(neg_filter_file, "r").readlines() + pos_filter = open(pos_filter_file, "r").readlines() + neg_test = open(neg_test_file, "r").readlines() + pos_test = open(pos_test_file, "r").readlines() + neg_all = open(neg_all_file, "r").readlines() + pos_all = open(pos_all_file, "r").readlines() # print('neg filter:', len(neg_filter)) # print('neg test:', len(neg_test)) @@ -280,62 +302,67 @@ def extend_clas_train_data(): # print('pos test:', len(pos_test)) # print('pos all:', len(pos_all)) - print('neg before:', len(neg_test)) + print("neg before:", len(neg_test)) for line in neg_all: if line not in neg_filter: neg_test.append(line) - print('neg after:', len(neg_test)) + print("neg after:", len(neg_test)) - print('pos before:', len(pos_test)) + print("pos before:", len(pos_test)) for line in pos_all: if line not in pos_filter: pos_test.append(line) - print('pos after:', len(pos_test)) + print("pos after:", len(pos_test)) - with open('dataset/testdata/{}_cat0_clas_test.txt'.format(dataset), 'w') as fout: + with open("dataset/testdata/{}_cat0_clas_test.txt".format(dataset), "w") as fout: for line in neg_test: fout.write(line) - with open('dataset/testdata/{}_cat1_clas_test.txt'.format(dataset), 'w') as fout: + with open("dataset/testdata/{}_cat1_clas_test.txt".format(dataset), "w") as fout: for line in pos_test: fout.write(line) - with open('dataset/testdata/{}_clas_test.txt'.format(dataset), 'w') as fout: + with open("dataset/testdata/{}_clas_test.txt".format(dataset), "w") as fout: for line in neg_test: fout.write(line) for line in pos_test: fout.write(line) -def load_word_vec(path, word2idx_dict=None, type='glove'): +def load_word_vec(path, word2idx_dict=None, type="glove"): """Load word embedding from local file""" - fin = open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') - if type == 'glove': + fin = open(path, "r", encoding="utf-8", newline="\n", errors="ignore") + if type == "glove": word2vec_dict = {} for line in fin: tokens = line.rstrip().split() if word2idx_dict is None or tokens[0] in word2idx_dict.keys(): - word2vec_dict[tokens[0]] = np.asarray(tokens[1:], dtype='float32') - elif type == 'word2vec': + word2vec_dict[tokens[0]] = np.asarray(tokens[1:], dtype="float32") + elif type == "word2vec": import gensim - word2vec_dict = gensim.models.KeyedVectors.load_word2vec_format(path, binary=True) + + word2vec_dict = gensim.models.KeyedVectors.load_word2vec_format( + path, binary=True + ) else: - raise NotImplementedError('No such type: %s' % type) + raise 
NotImplementedError("No such type: %s" % type) return word2vec_dict def build_embedding_matrix(dataset): """Load or build Glove embedding matrix.""" - embed_filename = 'dataset/glove_embedding_300d_{}.pt'.format(dataset) + embed_filename = "dataset/glove_embedding_300d_{}.pt".format(dataset) if os.path.exists(embed_filename): - print('Loading embedding:', embed_filename) + print("Loading embedding:", embed_filename) embedding_matrix = torch.load(embed_filename) else: - print('Loading Glove word vectors...') + print("Loading Glove word vectors...") word2idx_dict, _ = load_dict(dataset) - embedding_matrix = np.random.random((len(word2idx_dict) + 2, 300)) # 2 for padding token and start token - fname = '../glove.42B.300d.txt' # Glove file + embedding_matrix = np.random.random( + (len(word2idx_dict) + 2, 300) + ) # 2 for padding token and start token + fname = "../glove.42B.300d.txt" # Glove file # fname = '../GoogleNews-vectors-negative300.bin' # Google Word2Vec file - word2vec_dict = load_word_vec(fname, word2idx_dict=word2idx_dict, type='glove') - print('Building embedding matrix:', embed_filename) + word2vec_dict = load_word_vec(fname, word2idx_dict=word2idx_dict, type="glove") + print("Building embedding matrix:", embed_filename) for word, i in word2idx_dict.items(): if word in word2vec_dict: # words not found in embedding index will be randomly initialized. @@ -345,9 +372,87 @@ def build_embedding_matrix(dataset): return embedding_matrix -if __name__ == '__main__': - os.chdir('../') - # process_cat_text() - # load_test_dict('mr15') - # extend_clas_train_data() - pass +def pad_sequences( + sequence, w2v, target_len: int = 52, embedding_size: int = 300, padding_token=None +) -> np.array: + sequence = np.array(sequence) + current_length = sequence.shape[0] + + if current_length >= target_len: + return sequence[-target_len:] + + padding = ( + np.repeat( + np.array([w2v.wv[padding_token]]), target_len - current_length, axis=0 + ) + if padding_token + else np.zeros((target_len - current_length, embedding_size)) + ) + return np.concatenate((padding, sequence), axis=0) + + +def vectorize_sentence( + tokens, w2v, target_len: int = 52, embedding_size: int = 300, padding_token=None +): + vectorized = pad_sequences( + [w2v.wv[token] for token in tokens], + w2v, + target_len=target_len, + embedding_size=embedding_size, + padding_token=padding_token, + ) + vectorized = vectorized.T # required for pytorch + return vectorized + + +# if __name__ == '__main__': +# os.chdir('../') +# # process_cat_text() +# # load_test_dict('mr15') +# # extend_clas_train_data() + +# # dataset preprocess and saving +# import torchtext +# import os +# import nltk +# nltk.download('punkt') +# from tqdm.notebook import tqdm +# from pathlib import Path + +# def tokenize_and_save(source, path, filename): +# with open(Path(path) / filename, 'w') as f: +# for _, line in tqdm(source, desc=filename): +# line = line.strip().lower() +# line = ' '.join(nltk.tokenize.word_tokenize(line)) +# line = ' '.join(line.split('\n')) +# line = ' '.join(line.split('\\n')) +# line = ' '.join(line.split('\\')) +# f.write(line) +# f.write('\n') + +# AGNEWS_train, AGNEWS_test = torchtext.datasets.AG_NEWS( +# root="./data", split=("train", "test") +# ) +# DBpedia_train, DBpedia_test = torchtext.datasets.DBpedia( +# root="./data", split=("train", "test") +# ) +# WikiText103_train, WikiText103_valid, WikiText103_test = torchtext.datasets.WikiText103( +# root="./data", split=("train", "valid", "test") +# ) +# YahooAnswers_train, YahooAnswers_test = 
torchtext.datasets.YahooAnswers( +# root="./data", split=("train", "test") +# ) +# YelpReviewFull_train, YelpReviewFull_test = torchtext.datasets.YelpReviewFull( +# root="./data", split=("train", "test") +# ) +# tokenize_and_save(AGNEWS_train, './dataset/', 'agnews_train.txt') +# tokenize_and_save(AGNEWS_test, './dataset/testdata/', 'agnews_test.txt') +# tokenize_and_save(DBpedia_train, './dataset/', 'dbpedia_train.txt') +# tokenize_and_save(DBpedia_test, './dataset/testdata/', 'dbpedia_test.txt') +# tokenize_and_save(enumerate(WikiText103_train), './dataset/', 'wikitext103_train.txt') +# tokenize_and_save(enumerate(WikiText103_valid), './dataset/', 'wikitext103_valid.txt') +# tokenize_and_save(enumerate(WikiText103_test), './dataset/testdata/', 'wikitext103_test.txt') +# tokenize_and_save(YahooAnswers_train, './dataset/', 'yahooanswers_train.txt') +# tokenize_and_save(YahooAnswers_test, './dataset/testdata/', 'yahooanswers_test.txt') +# tokenize_and_save(YelpReviewFull_train, './dataset/', 'yelpreviewfull_train.txt') +# tokenize_and_save(YelpReviewFull_test, './dataset/testdata/', 'yelpreviewfull_test.txt') diff --git a/utils/visualization.py b/utils/visualization.py index 10a05a83..b32ba562 100644 --- a/utils/visualization.py +++ b/utils/visualization.py @@ -4,40 +4,57 @@ # @FileName : visualization.py # @Time : Created at 2019-03-19 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import matplotlib.pyplot as plt title_dict = { - 'gen_pre_loss': 'pre_loss', - 'gen_adv_loss': 'g_loss', - 'gen_mana_loss': 'mana_loss', - 'gen_work_loss': 'work_loss', - 'dis_loss': 'd_loss', - 'dis_train_acc': 'train_acc', - 'dis_eval_acc': 'eval_acc', - 'NLL_oracle': 'NLL_oracle', - 'NLL_gen': 'NLL_gen', - 'BLEU-3': 'BLEU-3', + "gen_pre_loss": "pre_loss", + "gen_adv_loss": "g_loss", + "gen_mana_loss": "mana_loss", + "gen_work_loss": "work_loss", + "dis_loss": "d_loss", + "dis_train_acc": "train_acc", + "dis_eval_acc": "eval_acc", + "NLL_oracle": "NLL_oracle", + "NLL_gen": "NLL_gen", + "BLEU-3": "BLEU-3", } -color_list = ['#e74c3c', '#e67e22', '#f1c40f', '#8e44ad', '#2980b9', '#27ae60', '#16a085'] +color_list = [ + "#e74c3c", + "#e67e22", + "#f1c40f", + "#8e44ad", + "#2980b9", + "#27ae60", + "#16a085", +] def plt_data(data, step, title, c_id, savefig=False): x = [i for i in range(step)] plt.plot(x, data, color=color_list[c_id], label=title) if savefig: - plt.savefig('savefig/' + title + '.png') + plt.savefig("savefig/" + title + ".png") def get_log_data(filename): - with open(filename, 'r') as fin: - all_lines = fin.read().strip().split('\n') - data_dict = {'pre_loss': [], 'g_loss': [], 'mana_loss': [], 'work_loss': [], - 'd_loss': [], 'train_acc': [], 'eval_acc': [], 'NLL_oracle': [], - 'NLL_gen': [], 'BLEU-3': []} + with open(filename, "r") as fin: + all_lines = fin.read().strip().split("\n") + data_dict = { + "pre_loss": [], + "g_loss": [], + "mana_loss": [], + "work_loss": [], + "d_loss": [], + "train_acc": [], + "eval_acc": [], + "NLL_oracle": [], + "NLL_gen": [], + "BLEU-3": [], + } for line in all_lines: items = line.split() @@ -51,28 +68,33 @@ def get_log_data(filename): return data_dict -if __name__ == '__main__': - log_file_root = '../log/' +if __name__ == "__main__": + log_file_root = "../log/" # Custom your log files in lists, no more than len(color_list) - log_file_list = ['log_0604_2233', 'log_0605_0120', 'log_0531_1507'] - legend_text = ['SeqGAN', 'LeakGAN', 'RelGAN'] + log_file_list = ["log_0604_2233", "log_0605_0120", 
"log_0531_1507"] + legend_text = ["SeqGAN", "LeakGAN", "RelGAN"] color_id = 0 - data_name = 'NLL_oracle' + data_name = "NLL_oracle" if_save = False # legend_text = log_file_list - assert data_name in title_dict.keys(), 'Error data name' + assert data_name in title_dict.keys(), "Error data name" plt.clf() plt.title(data_name) all_data_list = [] for idx, item in enumerate(log_file_list): - log_file = log_file_root + item + '.txt' + log_file = log_file_root + item + ".txt" # save log file all_data = get_log_data(log_file) - plt_data(all_data[title_dict[data_name]], len(all_data[title_dict[data_name]]), - legend_text[idx], color_id, if_save) + plt_data( + all_data[title_dict[data_name]], + len(all_data[title_dict[data_name]]), + legend_text[idx], + color_id, + if_save, + ) color_id += 1 plt.legend() diff --git a/visual/visual_human.py b/visual/visual_human.py index 96bc6c43..73d1f8a8 100644 --- a/visual/visual_human.py +++ b/visual/visual_human.py @@ -25,48 +25,120 @@ bar_width = 0.5 opacity = 1.0 -error_config = {'ecolor': '0'} +error_config = {"ecolor": "0"} -rects1 = ax.bar(0, CSGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#8e44ad', error_kw=error_config, - label='CSGAN') +rects1 = ax.bar( + 0, + CSGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#8e44ad", + error_kw=error_config, + label="CSGAN", +) -rects2 = ax.bar(bar_width, SentiGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#27ae60', error_kw=error_config, - label='SentiGAN') +rects2 = ax.bar( + bar_width, + SentiGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#27ae60", + error_kw=error_config, + label="SentiGAN", +) -rects3 = ax.bar(0 + 2 * bar_width, CatGAN_m, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#d35400', error_kw=error_config, - label='CatGAN ($k=2$)') +rects3 = ax.bar( + 0 + 2 * bar_width, + CatGAN_m, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#d35400", + error_kw=error_config, + label="CatGAN ($k=2$)", +) gap = 1.2 -rects4 = ax.bar(3 * bar_width + gap, SeqGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#fd79a8', error_kw=error_config, - label='SeqGAN') +rects4 = ax.bar( + 3 * bar_width + gap, + SeqGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#fd79a8", + error_kw=error_config, + label="SeqGAN", +) -rects5 = ax.bar(4 * bar_width + gap, RankGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#34495e', error_kw=error_config, - label='RankGAN') +rects5 = ax.bar( + 4 * bar_width + gap, + RankGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#34495e", + error_kw=error_config, + label="RankGAN", +) -rects6 = ax.bar(0 + 5 * bar_width + gap, LeakGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#f1c40f', error_kw=error_config, - label='LeakGAN') -rects7 = ax.bar(6 * bar_width + gap, RelGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#2980b9', error_kw=error_config, - label='RelGAN') -rects8 = ax.bar(7 * bar_width + gap, CatGAN_s, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#c0392b', error_kw=error_config, - label='CatGAN ($k=1$)') +rects6 = ax.bar( + 0 + 5 * bar_width + gap, + LeakGAN, + 
bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#f1c40f", + error_kw=error_config, + label="LeakGAN", +) +rects7 = ax.bar( + 6 * bar_width + gap, + RelGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#2980b9", + error_kw=error_config, + label="RelGAN", +) +rects8 = ax.bar( + 7 * bar_width + gap, + CatGAN_s, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#c0392b", + error_kw=error_config, + label="CatGAN ($k=1$)", +) -ax.set_xlabel('Dataset') -ax.set_ylabel('Human Score') +ax.set_xlabel("Dataset") +ax.set_ylabel("Human Score") # ax.set_title('Scores by group and gender') len = ((0 + 3 * bar_width) / 3, 3 * bar_width + gap + 2 * bar_width) ax.set_xticks(len) -ax.set_xticklabels(('AR', 'EN')) +ax.set_xticklabels(("AR", "EN")) ax.legend(bbox_to_anchor=(1, 0), loc=3, borderaxespad=0.2) # plt.legend() fig.tight_layout() -plt.savefig('savefig/human.pdf') +plt.savefig("savefig/human.pdf") plt.show() # plt.savefig('C:/1123.pdf') diff --git a/visual/visual_metric.py b/visual/visual_metric.py index c4ca3eed..08107741 100644 --- a/visual/visual_metric.py +++ b/visual/visual_metric.py @@ -4,13 +4,13 @@ # @FileName : visual_metric.py # @Time : Created at 2019-11-26 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import matplotlib.pyplot as plt import numpy as np -color_list = ['#2980b9', '#e74c3c', '#1abc9c', '#9b59b6'] +color_list = ["#2980b9", "#e74c3c", "#1abc9c", "#9b59b6"] def plt_x_y_data(x, y, title, c_id): @@ -18,9 +18,9 @@ def plt_x_y_data(x, y, title, c_id): def get_log_data(filename): - with open(filename, 'r') as fin: - all_lines = fin.read().strip().split('\n') - data_dict = {'NLL_oracle': [], 'NLL_gen': [], 'NLL_div': []} + with open(filename, "r") as fin: + all_lines = fin.read().strip().split("\n") + data_dict = {"NLL_oracle": [], "NLL_gen": [], "NLL_div": []} for line in all_lines: items = line.split() @@ -34,14 +34,14 @@ def get_log_data(filename): return data_dict -if __name__ == '__main__': - log_file_root = 'log/' +if __name__ == "__main__": + log_file_root = "log/" # Custom your log files in lists, no more than len(color_list) - log_file_list = ['jsdgan_vanilla_oracle', 'catgan_vanilla_oracle'] - legend_text = ['JSDGAN', 'CatGAN'] + log_file_list = ["jsdgan_vanilla_oracle", "catgan_vanilla_oracle"] + legend_text = ["JSDGAN", "CatGAN"] color_id = 0 - title = 'Synthetic data' + title = "Synthetic data" if_save = True length = 100 @@ -49,19 +49,23 @@ def get_log_data(filename): plt.title(title) all_data_list = [] for idx, item in enumerate(log_file_list): - log_file = log_file_root + item + '.txt' + log_file = log_file_root + item + ".txt" # save log file all_data = get_log_data(log_file) - idxs = np.argsort(-np.array(all_data['NLL_oracle'])) - plt_x_y_data(np.array(all_data['NLL_oracle'])[idxs][:length], np.array(all_data['NLL_div'])[idxs][:length], - legend_text[idx], color_id) + idxs = np.argsort(-np.array(all_data["NLL_oracle"])) + plt_x_y_data( + np.array(all_data["NLL_oracle"])[idxs][:length], + np.array(all_data["NLL_div"])[idxs][:length], + legend_text[idx], + color_id, + ) color_id += 1 plt.legend() # plt.tight_layout() - plt.xlabel(r'${\rm NLL_{\rm oracle}}$') - plt.ylabel(r'${\rm NLL_{\rm div}}$') + plt.xlabel(r"${\rm NLL_{\rm oracle}}$") + plt.ylabel(r"${\rm NLL_{\rm div}}$") if if_save: - plt.savefig('../savefig/synthetic_oracle_div.png') + 
plt.savefig("../savefig/synthetic_oracle_div.png") plt.show() diff --git a/visual/visual_temp_appendix.py b/visual/visual_temp_appendix.py index bea9d1ae..deb553eb 100644 --- a/visual/visual_temp_appendix.py +++ b/visual/visual_temp_appendix.py @@ -5,39 +5,39 @@ import os title_dict = { - 'NLL_oracle': 'NLL_oracle', - 'NLL_gen': 'NLL_gen', - 'NLL_div': 'NLL_div', - 'nll_oracle': 'nll_oracle', - 'nll_div': 'nll_div', - 'temp': 'temp', + "NLL_oracle": "NLL_oracle", + "NLL_gen": "NLL_gen", + "NLL_div": "NLL_div", + "nll_oracle": "nll_oracle", + "nll_div": "nll_div", + "temp": "temp", } -color_list = ['#2980b9', '#e74c3c', '#1abc9c', '#9b59b6'] -ls_list = ['--', '-'] +color_list = ["#2980b9", "#e74c3c", "#1abc9c", "#9b59b6"] +ls_list = ["--", "-"] marker_list = [None, None] def plt_data(data, length, title, c_id, ls, marker, start=0): x = np.arange(start, start + length, 1) - data = data[start:start + length] + data = data[start : start + length] plt.plot(x, data, color=color_list[c_id], label=title, lw=1.0, ls=ls, marker=marker) if length < 100: plt.xticks(np.arange(start, start + length + 1, 5)) def get_log_data(filename): - with open(filename, 'r') as fin: - all_lines = fin.read().strip().split('\n') - data_dict = {'NLL_oracle': [], 'NLL_gen': [], 'NLL_div': [], 'temp': []} + with open(filename, "r") as fin: + all_lines = fin.read().strip().split("\n") + data_dict = {"NLL_oracle": [], "NLL_gen": [], "NLL_div": [], "temp": []} for line in all_lines: items = line.split() try: for key in data_dict.keys(): - if '>>>' not in items and key in items: + if ">>>" not in items and key in items: target = items[items.index(key) + 2] - if ',' in target: + if "," in target: target = target[:-1] data_dict[key].append(float(target)) except: @@ -46,13 +46,13 @@ def get_log_data(filename): return data_dict -if __name__ == '__main__': - os.chdir('..') - log_file_root = 'savefig/figure_log/' - log_file_list = ['exp_temp5', 'evo_temp5_nll'] - legend_text = ['Exponential temperature', 'Evolutionary temperature'] +if __name__ == "__main__": + os.chdir("..") + log_file_root = "savefig/figure_log/" + log_file_list = ["exp_temp5", "evo_temp5_nll"] + legend_text = ["Exponential temperature", "Evolutionary temperature"] - data_name = 'temp' + data_name = "temp" if_save = True color_id = 0 all_data_list = [] @@ -62,23 +62,30 @@ def get_log_data(filename): plt.clf() if length < 100: plt.figure(figsize=(4, 3)) - assert data_name in title_dict.keys(), 'Error data name' + assert data_name in title_dict.keys(), "Error data name" plt.xticks(fontsize=7) plt.yticks(fontsize=7) for idx, item in enumerate(log_file_list): - log_file = log_file_root + item + '.txt' + log_file = log_file_root + item + ".txt" # save log file all_data = get_log_data(log_file) - plt_data(all_data[title_dict[data_name]], length, legend_text[idx], color_id, start=start, ls=ls_list[idx], - marker=marker_list[idx]) + plt_data( + all_data[title_dict[data_name]], + length, + legend_text[idx], + color_id, + start=start, + ls=ls_list[idx], + marker=marker_list[idx], + ) color_id += 1 if length > 100: - plt.legend(prop={'size': 7}) + plt.legend(prop={"size": 7}) plt.xlabel(r"training iterations", fontsize=7) plt.ylabel(r"temperature", fontsize=7) plt.tight_layout() if if_save: - plt.savefig('savefig/temp_curve_{}.pdf'.format(length)) + plt.savefig("savefig/temp_curve_{}.pdf".format(length)) plt.show() diff --git a/visual/visual_temp_compare.py b/visual/visual_temp_compare.py index a96ca64a..83b2171a 100644 --- a/visual/visual_temp_compare.py +++ 
b/visual/visual_temp_compare.py @@ -4,15 +4,15 @@ import numpy as np title_dict = { - 'NLL_oracle': 'NLL_oracle', - 'NLL_gen': 'NLL_gen', - 'NLL_div': 'NLL_div', - 'nll_oracle': 'nll_oracle', - 'nll_div': 'nll_div', - 'temp': 'temp', + "NLL_oracle": "NLL_oracle", + "NLL_gen": "NLL_gen", + "NLL_div": "NLL_div", + "nll_oracle": "nll_oracle", + "nll_div": "nll_div", + "temp": "temp", } -color_list = ['#e74c3c', '#f1c40f', '#1abc9c', '#9b59b6'] +color_list = ["#e74c3c", "#f1c40f", "#1abc9c", "#9b59b6"] def plt_data(data, title, c_id): @@ -25,17 +25,17 @@ def plt_data(data, title, c_id): def get_log_data(filename): - with open(filename, 'r') as fin: - all_lines = fin.read().strip().split('\n') - data_dict = {'NLL_oracle': [], 'NLL_gen': [], 'NLL_div': [], 'temp': []} + with open(filename, "r") as fin: + all_lines = fin.read().strip().split("\n") + data_dict = {"NLL_oracle": [], "NLL_gen": [], "NLL_div": [], "temp": []} for line in all_lines: items = line.split() try: for key in data_dict.keys(): - if '>>>' not in items and key in items: + if ">>>" not in items and key in items: target = items[items.index(key) + 2] - if ',' in target: + if "," in target: target = target[:-1] data_dict[key].append(float(target)) except: @@ -44,40 +44,49 @@ def get_log_data(filename): return data_dict -if __name__ == '__main__': +if __name__ == "__main__": # log_file_root = '../log/' - log_file_root = 'savefig/figure_log/' - log_file_list = ['catgan_temp1_final', 'catgan_temp5_final', 'relgan_temp1_final', 'relgan_temp5_final'] - legend_text = [r'CatGAN ($\tau_{\rm{tar}}$=1)', r'CatGAN ($\tau_{\rm{tar}}$=5)', r'RelGAN ($\tau_{\rm{tar}}$=1)', - r'RelGAN ($\tau_{\rm{tar}}$=5)'] - - data_name_list = ['NLL_oracle', 'NLL_div'] + log_file_root = "savefig/figure_log/" + log_file_list = [ + "catgan_temp1_final", + "catgan_temp5_final", + "relgan_temp1_final", + "relgan_temp5_final", + ] + legend_text = [ + r"CatGAN ($\tau_{\rm{tar}}$=1)", + r"CatGAN ($\tau_{\rm{tar}}$=5)", + r"RelGAN ($\tau_{\rm{tar}}$=1)", + r"RelGAN ($\tau_{\rm{tar}}$=5)", + ] + + data_name_list = ["NLL_oracle", "NLL_div"] if_save = True plt.clf() plt.figure(figsize=(8, 3.5)) for cur_id, data_name in enumerate(data_name_list): - assert data_name in title_dict.keys(), 'Error data name' + assert data_name in title_dict.keys(), "Error data name" plt.subplot(12 * 10 + cur_id + 1) if cur_id == 0: # plt.title(r"$\rm{NLL}_{\rm{oracle}}$") plt.ylabel(r"$\rm{NLL}_{\rm{oracle}}$", fontsize=12) - plt.plot([150, 150], [8.3, 9.4], 'k--') + plt.plot([150, 150], [8.3, 9.4], "k--") else: # plt.title(r"$\rm{NLL}_{\rm{div}}$") plt.ylabel(r"$\rm{NLL}_{\rm{div}}$", fontsize=12) - plt.plot([150, 150], [3.3, 5], 'k--') + plt.plot([150, 150], [3.3, 5], "k--") plt.xlabel("training iterations", fontsize=12) color_id = 0 all_data_list = [] for idx, item in enumerate(log_file_list): - log_file = log_file_root + item + '.txt' + log_file = log_file_root + item + ".txt" # save log file all_data = get_log_data(log_file) - if 'catgan' in log_file or 'relgan' in log_file: + if "catgan" in log_file or "relgan" in log_file: temp = all_data[title_dict[data_name]] last = list(np.array(temp)[range(15, 108, 2)]) res = temp[:15] + last @@ -89,5 +98,5 @@ def get_log_data(filename): plt.legend() plt.tight_layout() if if_save: - plt.savefig('savefig/temp_figure.pdf') + plt.savefig("savefig/temp_figure.pdf") plt.show()
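
For quick orientation on the new `utils/nn_helpers.py` module introduced by this patch, here is a minimal usage sketch of its factory helpers. The sizes used below are illustrative assumptions rather than values taken from the repository's `config.py`, and the sketch assumes it is run from the repository root so that `utils.nn_helpers` is importable.

```python
# Illustrative sketch only: batch/embedding sizes are assumptions, not the repo's config values.
import torch

from utils.nn_helpers import DiversityLoss, MyConvLayer, create_noise, number_of_parameters

batch_size, noise_size, k_label = 8, 100, 2

# create_noise returns a (noise, labels) pair: Gaussian noise plus random class ids in [0, k_label).
noise, labels = create_noise(batch_size, noise_size, k_label)
assert noise.shape == (batch_size, noise_size) and labels.shape == (batch_size,)

# DiversityLoss() builds a closure over a CosineSimilarity; as I read it, the input is a batch of
# generated embedding matrices shaped (batch, embedding_dim, seq_len), and the loss grows when
# samples in the batch collapse onto each other.
diversity_criterion = DiversityLoss()
fakes = torch.randn(batch_size, 300, 52)
print(diversity_criterion(fakes))  # scalar tensor

# number_of_parameters counts trainable weights, e.g. for logging the size of a block.
conv_block = MyConvLayer(in_channels=300, out_channels=64)
print(number_of_parameters(conv_block.parameters()))
```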
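
Along the same lines, a small sketch of the new `pad_sequences`/`vectorize_sentence` helpers added to `utils/text_process.py`. The toy gensim 4.x `Word2Vec` model below stands in for the pretrained embeddings the pipeline expects; the two-sentence corpus and the 300-dimensional vectors are placeholders.

```python
# Toy example: the corpus and vector size are placeholders, not the real pretrained embeddings.
from gensim.models import Word2Vec

from utils.text_process import vectorize_sentence

corpus = [["a", "cat", "sat"], ["a", "dog", "ran"]]
w2v = Word2Vec(corpus, vector_size=300, min_count=1)  # gensim 4.x API

vec = vectorize_sentence(["a", "cat", "sat"], w2v, target_len=52, embedding_size=300)
# Shorter sentences are left-padded with zeros (or with a padding token's vector when
# padding_token is given), then transposed to (embedding_dim, seq_len), i.e. the (C, L)
# layout that the Conv1d / LSTM blocks in utils/nn_helpers.py consume.
print(vec.shape)  # (300, 52)
```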
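
Finally, a note on the batch layout that the new `D_loss_fixem` method assumes, since the two `chunk(2)` calls are easy to miss: the real/fake predictions arrive as a single tensor with the scores for the real batch stacked before the scores for an equally sized fake batch. A plain-torch rendering of the target it builds follows; the BCE criterion here is purely illustrative, since the actual `real_fake_criterion` is configured elsewhere in the class.

```python
import torch

d_real = torch.rand(16, 1)  # discriminator scores for 16 real samples (illustrative)
d_fake = torch.rand(16, 1)  # scores for the matching 16 generated samples
real_fake_predicts = torch.cat((d_real, d_fake))  # reals first, fakes second

# chunk(2)[0] is the real half and chunk(2)[1] the fake half, so the combined target is ones
# followed by zeros, mirroring what get_target_tensor(..., target_is_real=True/False) produces
# for each half (assuming the default 1.0 / 0.0 labels).
target_real_fake = torch.cat((torch.ones_like(d_real), torch.zeros_like(d_fake)))
loss = torch.nn.BCEWithLogitsLoss()(real_fake_predicts, target_real_fake)
```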