From cbf01de2aadd9a5e27267e9c133030ab608af59e Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 26 Jan 2023 11:00:53 +0000 Subject: [PATCH 01/81] error with inplace modification --- models/relational_rnn_general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/relational_rnn_general.py b/models/relational_rnn_general.py index e2d5178c..b4b02b21 100644 --- a/models/relational_rnn_general.py +++ b/models/relational_rnn_general.py @@ -171,7 +171,7 @@ def multihead_attention(self, memory): q, k, v = torch.split(qkv_transpose, [self.key_size, self.key_size, self.value_size], -1) # scale q with d_k, the dimensionality of the key vectors - q *= (self.key_size ** -0.5) + q = q * (self.key_size ** -0.5) # make it [B, H, N, N] dot_product = torch.matmul(q, k.permute(0, 1, 3, 2)) From 8fa838ea06cd6373b690ef66a432627f8a40bea6 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 09:47:27 +0000 Subject: [PATCH 02/81] implementation WIP --- config.py | 2 + instructor/real_data/fixem_instructor.py | 77 +++++++ main.py | 2 + models/FixemGAN_D.py | 88 +++++++ models/FixemGAN_G.py | 134 +++++++++++ models/Oracle.py | 2 +- run/run_fixem.py | 171 ++++++++++++++ utils/create_embeddings.py | 22 ++ utils/data_loader.py | 53 ++++- utils/nn_helpers.py | 279 +++++++++++++++++++++++ utils/text_process.py | 26 ++- 11 files changed, 853 insertions(+), 3 deletions(-) create mode 100644 instructor/real_data/fixem_instructor.py create mode 100644 models/FixemGAN_D.py create mode 100644 models/FixemGAN_G.py create mode 100644 run/run_fixem.py create mode 100644 utils/create_embeddings.py create mode 100644 utils/nn_helpers.py diff --git a/config.py b/config.py index 63b9d035..14a7f4df 100644 --- a/config.py +++ b/config.py @@ -311,6 +311,8 @@ def init_param(opt): cat_train_data = 'dataset/' + dataset + '_cat{}.txt' cat_test_data = 'dataset/testdata/' + dataset + '_cat{}_test.txt' + texts_data = 'dataset/' # do not include testdata + if max_seq_len == 40: oracle_samples_path = 'pretrain/oracle_data/oracle_lstm_samples_{}_sl40.pt' multi_oracle_samples_path = 'pretrain/oracle_data/oracle{}_lstm_samples_{}_sl40.pt' diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py new file mode 100644 index 00000000..1f0e6a45 --- /dev/null +++ b/instructor/real_data/fixem_instructor.py @@ -0,0 +1,77 @@ +import os +import torch +import torchtext + +from instructor.real_data.instructor import BasicInstructor +from utils.text_process import tokenize +from utils.data_loader import DataSupplier + + +# TO DO: +# 1. train embedding if not exists (if oracle, then always retrain ) +# 2. create data generator (categorical and non categorical) (based on given dataset) +# 3. create disc and gen +# 4. train epochs and each 10 epochs print metrics +# 5. show metrics +# 6. save? 
or save each 10 epochs + +# chack target real/fake to be right (Uniform or const) + + +class FixemGANInstructor(BasicInstructor): + def __init__(self, cfg): + super(FixemGANInstructor, self).__init__(cfg) + # check if embeddings already exist for current oracle + if os.path.exists(f'oracle path/{cfg.embedding_file_name}'): + # train embedding with oracle + w2v = load_embedding(f'oracle path/{cfg.embedding_file_name}') + + print(self.train_data) + + print(self.train_data_list) + + # data_generator = DataSupplier + + DataLoader( + dataset=GANDataset(), + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_last=True + ) + + + # try: + # self.train_data = GenDataIter(cfg.train_data) + # self.test_data = GenDataIter(cfg.test_data, if_test_data=True) + # except: + # pass + + # try: + # self.train_data_list = [GenDataIter(cfg.cat_train_data.format(i)) for i in range(cfg.k_label)] + # self.test_data_list = [GenDataIter(cfg.cat_test_data.format(i), if_test_data=True) for i in + # range(cfg.k_label)] + # self.clas_data_list = [GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True) for i in + # range(cfg.k_label)] + + # self.train_samples_list = [self.train_data_list[i].target for i in range(cfg.k_label)] + # self.clas_samples_list = [self.clas_data_list[i].target for i in range(cfg.k_label)] + # except: + # pass + + + + def one_more_batch_for_generator( + self, generator_acc, leave_in_generator_min=0.1, leave_in_generator_max=0.9 + ): + generator_acc = min(leave_in_generator_max, generator_acc) + generator_acc = max(leave_in_generator_min, generator_acc) + if random.random() > generator_acc: + return True + return False + + def write_txt_file(self, source, save_path, save_filename): + with open(os.path.join(save_path, save_filename), 'w') as f: + for _, text in source: + line = ' '.join(tokenize(text)) + f.write(line) + f.write('\n') diff --git a/main.py b/main.py index 3ad65402..3c7aef72 100644 --- a/main.py +++ b/main.py @@ -136,6 +136,7 @@ def program_config(parser): from instructor.real_data.catgan_instructor import CatGANInstructor from instructor.real_data.dgsan_instructor import DGSANInstructor from instructor.real_data.cot_instructor import CoTInstructor + from instructor.real_data.fixem_instructor import FixemGANInstructor else: from instructor.oracle_data.seqgan_instructor import SeqGANInstructor @@ -162,6 +163,7 @@ def program_config(parser): 'catgan': CatGANInstructor, 'dgsan': DGSANInstructor, 'cot': CoTInstructor, + 'fixemgan': FixemGANInstructor } inst = instruction_dict[cfg.run_model](opt) diff --git a/models/FixemGAN_D.py b/models/FixemGAN_D.py new file mode 100644 index 00000000..f713e206 --- /dev/null +++ b/models/FixemGAN_D.py @@ -0,0 +1,88 @@ +import torch.nn as nn + +from utils.nn_helpers import get_optimizer, MyConvLayer, MyTransformerEncoderLayer, Flatten + +@dataclass +class DiscriminatorParameters: + complexity: int = 512 + alpha: float = 0.2 + drop_rate: float = 0.0 + transformer: bool = False + transformer_layers: int = 6 + + +class Discriminator(nn.Module): + def __init__(self, parameters: DiscriminatorParameters, verbose=False): + super(Discriminator, self).__init__() + complexity = parameters.complexity + alpha = parameters.alpha + drop_rate = parameters.drop_rate + include_transformer = parameters.transformer + + self.main = nn.Sequential( + # 1 layer + MyConvLayer(EMBEDDING_SIZE, complexity, alpha=alpha, drop_rate=drop_rate), + # 2 layer + MyConvLayer( + complexity, + complexity, + alpha=alpha, + drop_rate=drop_rate, + ), + # 3 layer + 
MyConvLayer(complexity, complexity, alpha=alpha, drop_rate=drop_rate), + # MyLSTMLayer(complexity, complexity//2), + # 4 layer + MyConvLayer(complexity, complexity, alpha=alpha, drop_rate=drop_rate), + # 5 layer + + MyTransformerEncoderLayer( + d_model=complexity, + n_layers=parameters.transformer_layers, + ) + if include_transformer + else Dummy(), + + # 6 layer + MyConvLayer(complexity, complexity, alpha=alpha, drop_rate=drop_rate), + # MyLSTMLayer(complexity, complexity//2), + # 7 layer + MyConvLayer( + complexity, + complexity, + stride=2, + padding=1, + alpha=alpha, + drop_rate=drop_rate, + ), + + MyConvLayer( + complexity, + complexity, + stride=2, + padding=1, + alpha=alpha, + drop_rate=drop_rate, + ), + # 8 layer + Flatten(), + nn.Linear(complexity * TARGET_LEN // 2 // 2, complexity), + nn.LeakyReLU(alpha), + nn.Dropout(drop_rate), + ) + + self.real_fake = nn.Sequential( + nn.Linear(complexity, 1), + ) + self.labels = nn.Sequential( + nn.Linear(complexity, DEPTH), + ) + self.optimizer = get_optimizer() + + @property + def nb_of_parameters(self): + return number_of_parameters(self.parameters()) + + def forward(self, x): + x = self.main(x) + return self.real_fake(x), self.labels(x) diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py new file mode 100644 index 00000000..b0c59306 --- /dev/null +++ b/models/FixemGAN_G.py @@ -0,0 +1,134 @@ +from dataclasses import dataclass + +import torch.nn as nn +from utils.nn_helpers import get_optimizer, Concatenate, Reshape, MyConvLayerNorm, MyConvTransposeLayer, PositionalEncoding, MyLSTMLayerNorm + + +@dataclass +class GeneratorParameters: + complexity: int = 512 + concatenate_pe: bool = False + leacky_ReLU_alpha: float = 0.2 + batch_norm: bool = True + transformer: bool = False + lstm: bool = True + transformer_layers: int = 3 + + +class Generator(nn.Module): + def __init__(self, parameters: GeneratorParameters, embedding_size: int, verbose=False): + super(Generator, self).__init__() + complexity = parameters.complexity + alpha = parameters.leacky_ReLU_alpha + added_dim_pe = parameters.complexity if parameters.concatenate_pe else 0 + include_batch_norm = parameters.batch_norm + include_transformer = parameters.transformer + include_lstm = parameters.lstm + + self.embedding_size = embedding_size + self.real_fake_criterion = nn.BCELoss() + self.label_criterion = nn.CrossEntropyLoss(label_smoothing=0.1) + + self.main = nn.Sequential( + # 1 layer + Concatenate(1), + nn.Linear(NOISE_SIZE + DEPTH, TARGET_LEN // 2 // 2 * complexity), + nn.BatchNorm1d(TARGET_LEN // 2 // 2 * complexity), + nn.LeakyReLU(alpha), + Reshape(complexity, TARGET_LEN // 2 // 2), + # 2 layer + MyConvLayerNorm(complexity, complexity, alpha=alpha), + # 3 layer + MyConvTransposeLayer( + complexity, + complexity, + stride=2, + output_padding=1, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # 4 layer + MyConvTransposeLayer( + complexity, + complexity, + stride=2, + output_padding=1, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # adding/concatenating positional encoding + PositionalEncoding( + dim_pe=parameters.complexity, + max_len=TARGET_LEN, + concatenate_pe=parameters.concatenate_pe, + ), + # 5 layer + MyConvLayerNorm( + complexity + added_dim_pe, + complexity, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # adding/concatenating positional encoding + PositionalEncoding( + dim_pe=complexity, + max_len=TARGET_LEN, + concatenate_pe=parameters.concatenate_pe, + ), + # 6 layer + MyTransformerEncoderLayer( + d_model=complexity + 
added_dim_pe, + n_layers=parameters.transformer_layers, + ) + if include_transformer + else Dummy(), + # 7 layer + MyConvTransposeLayer( + complexity + added_dim_pe, + complexity, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # 8 layer + MyLSTMLayerNorm( + complexity, + complexity//2, + ) if include_lstm else Dummy(), + + # 9 layer + MyConvTransposeLayer( + complexity, + complexity, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # 10 layer + MyLSTMLayerNorm( + complexity, + complexity//2, + ) if include_lstm else Dummy(),, + # 11 layer + MyConvTransposeLayer( + complexity, + complexity, + alpha=alpha, + include_batch_norm=include_batch_norm, + ), + # 12 layer + nn.Conv1d( + complexity, + EMBEDDING_SIZE, + kernel_size=1, + stride=1, + padding=0, + ), + ) + self.optimizer = get_optimizer() + self.to(device) + if verbose: + print("total parameters:", number_of_parameters(self.parameters())) + + def forward(self, noise, target_labels): + target_labels = torch.nn.functional.one_hot(target_labels, num_classes=DEPTH) + x = self.main([noise, target_labels]) + return x diff --git a/models/Oracle.py b/models/Oracle.py index 0ea46c8e..89c4c0e5 100644 --- a/models/Oracle.py +++ b/models/Oracle.py @@ -4,7 +4,7 @@ # @FileName : Oracle.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. from models.generator import LSTMGenerator diff --git a/run/run_fixem.py b/run/run_fixem.py new file mode 100644 index 00000000..db443b6a --- /dev/null +++ b/run/run_fixem.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +# @Author : William +# @Project : TextGAN-william +# @FileName : run_catgan.py +# @Time : Created at 2019-08-04 +# @Blog : http://zhiweil.ml/ +# @Description : +# Copyrights (C) 2018. All Rights Reserved. +import sys +from subprocess import call + +import os + +# Job id and gpu_id +if len(sys.argv) > 2: + job_id = int(sys.argv[1]) + gpu_id = str(sys.argv[2]) + print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) +elif len(sys.argv) > 1: + job_id = int(sys.argv[1]) + gpu_id = 0 + print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) +else: + job_id = 0 + gpu_id = 0 + print('Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + +# Executables +executable = 'python' +rootdir = '../' +scriptname = 'main.py' + +# ===Program=== +# CatGAN: Catgory text generation model +# EvoGAN: General text generation model +if_test = int(False) +run_model = ['fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan', 'cat_fixemgan', 'cat_fixemgan'] +k_label = 2 +CUDA = int(True) +batch_size = 32 + + +# ===Oracle or Real=== +if_real_data = [int(False), int(True), int(True), int(False), int(True), int(True)] +dataset = ['oracle', 'mr15', 'amazon_app_book', 'oracle', 'image_coco', 'emnlp_news'] +vocab_size = [5000, 0, 0, 5000, 0, 0] +target_len = [20, 20, 40, 20, 16, 52] + +# ===CatGAN Param=== +n_parent = 1 +loss_type = 'fixem' +mu_type = 'ragan rsgan' +eval_type = 'Ra' +temp_adpt = 'exp' +d_out_mean = int(False) +embedding_size = [100, 512, 512, 100, 512, 512] +embedding_filename = 'w2v_{}.model'.format(embedding_size[job_id]) +w2v_window = 5 +w2v_min_count = 50 +w2v_workers = 1 + +# === Basic Param === +data_shuffle = int(False) +model_type = 'fixem' +gen_init = 'truncated_normal' +dis_init = 'uniform' +samples_num = 10000 +batch_size = 64 +max_seq_len = 20 +gen_lr = 0.01 +gen_adv_lr = 1e-4 +dis_lr = 1e-4 +pre_log_step = 10 +adv_log_step = 20 + +# ===Generator=== +ADV_g_step = 1 +gen_embed_dim = 32 +gen_hidden_dim = 32 +mem_slots = 1 +num_heads = 2 +head_size = [512, 512, 512, 256, 256, 256] + +# ===Discriminator=== +ADV_d_step = 3 +dis_embed_dim = 64 +dis_hidden_dim = 64 +num_rep = 64 + +# ===Metrics=== +use_nll_oracle = int(True) +use_nll_gen = int(True) +use_nll_div = int(True) +use_bleu = int(True) +use_self_bleu = int(True) +use_clas_acc = int(True) +use_ppl = int(False) + +args = [ + # Program + '--if_test', if_test, + '--run_model', run_model[job_id], + '--k_label', k_label, + '--cuda', CUDA, + # '--device', gpu_id, # comment for auto GPU + '--ora_pretrain', ora_pretrain, + '--gen_pretrain', gen_pretrain, + '--dis_pretrain', dis_pretrain, + '--mle_epoch', MLE_train_epoch, + '--clas_pre_epoch', clas_pre_epoch, + '--adv_epoch', ADV_train_epoch, + '--tips', tips.format(run_model[job_id]), + + # Oracle or Real + '--if_real_data', if_real_data[job_id], + '--dataset', dataset[job_id], + '--vocab_size', vocab_size[job_id], + + # CatGAN Param + '--n_parent', n_parent, + '--loss_type', loss_type, + '--mu_type', mu_type, + '--eval_type', eval_type, + '--temp_adpt', temp_adpt, + '--temperature', temperature[job_id], + '--d_out_mean', d_out_mean, + '--lambda_fq', lambda_fq, + '--lambda_fd', lambda_fd, + '--eval_b_num', eval_b_num, + + # Basic Param + '--shuffle', data_shuffle, + '--model_type', model_type, + '--gen_init', gen_init, + '--dis_init', dis_init, + '--samples_num', samples_num, + '--batch_size', batch_size, + '--max_seq_len', max_seq_len, + '--gen_lr', gen_lr, + '--gen_adv_lr', gen_adv_lr, + '--dis_lr', dis_lr, + '--pre_log_step', pre_log_step, + '--adv_log_step', adv_log_step, + + # Generator + '--adv_g_step', ADV_g_step, + '--gen_embed_dim', gen_embed_dim, + '--gen_hidden_dim', gen_hidden_dim, + '--mem_slots', mem_slots, + '--num_heads', num_heads, + '--head_size', head_size[job_id], + + # Discriminator + '--adv_d_step', ADV_d_step, + '--dis_embed_dim', dis_embed_dim, + '--dis_hidden_dim', dis_hidden_dim, + '--num_rep', num_rep, + + # Metrics + '--use_nll_oracle', use_nll_oracle, + '--use_nll_gen', use_nll_gen, + '--use_nll_div', use_nll_div, + '--use_bleu', use_bleu, + '--use_self_bleu', use_self_bleu, + '--use_clas_acc', use_clas_acc, + '--use_ppl', use_ppl, +] + +args 
= list(map(str, args)) +my_env = os.environ.copy() +call([executable, scriptname] + args, env=my_env, cwd=rootdir) diff --git a/utils/create_embeddings.py b/utils/create_embeddings.py new file mode 100644 index 00000000..810e41e3 --- /dev/null +++ b/utils/create_embeddings.py @@ -0,0 +1,22 @@ +from gensim.models import Word2Vec + +import config as cfg +from utils.text_process import get_tokenized_from_file + + +class EmbeddingsTrainer: + def __init__(self, *files, size=512, save_filename=cfg.word2vec_model_name): + self.files = files + self.size = size + self.save_filename = save_filename + + def make_embeddings(self, verbose=True): + tokenized = get_tokenized_from_file(self.files) + W2V = Word2Vec( + sentences=tokenized, + size=self.size, + window=cfg.w2v_window, + min_count=cfg.w2v_min_count, + workers=cfg.w2v_workers, + ) + W2V.save(self.save_filename) diff --git a/utils/data_loader.py b/utils/data_loader.py index 3d7ed791..62e7e9f3 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -4,7 +4,7 @@ # @FileName : data_loader.py # @Time : Created at 2019-05-31 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import random @@ -125,3 +125,54 @@ def prepare(self, pos_samples, neg_samples, gpu=False): if gpu: return inp.cuda(), target.cuda() return inp, target + + +class DataSupplier: + def __init__(self, tokenized, labels, verbose, batch_size, batches_per_epoch): + self.verbose=verbose + + self.labels = torch.tensor(labels, dtype=int) + self.vectors = [vectorize_sentence(txt, padding_token = PADDING) for txt in tokenized] + self.vectors = np.stack(vectors, axis=0) + self.vectors = torch.tensor(vectors, dtype=torch.float32) + + self.batches_per_epoch = batches_per_epoch + self.batch_size = batch_size + + self.texts = set(" ".join(text[-TARGET_LEN:]) for text in texts) + if self.verbose: + print('texts examples', [txt for txt in self.texts][:3]) + + + def __iter__(self): + batch_iterator = trange(self.batches_per_epoch) if self.verbose else range(self.batches_per_epoch) + permutation = torch.randperm(len(self)) + self.vectors = self.vectors[permutation] + self.labels = self.labels[permutation] + + # permutation = torch.randint(low=0, high=len(self), size=(self.batch_size,)) + # yield self.labels[permutation].to(device), self.vectors[permutation].to(device) + + for _ in batch_iterator: + if len(self) < self.batch_size: + # we need to repat self vectors several times + repeats = self.batch_size // len(self.vectors) + yield self.labels.repeat(repeats).to(device), self.vectors.repeat(repeats, 1, 1).to(device) + + else: + index = 0 + index += self.batch_size + if index > len(self): + # concatenating beginning of self.vectors + yield (torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])).to(device), + torch.cat((self.vectors[index - self.batch_size: index], self.vectors[:index-len(self)])).to(device)) + index = index % len(self) + else: + yield self.labels[index - self.batch_size: index].to(device), self.vectors[index - self.batch_size: index].to(device) + + + def __len__(self): + return len(self.vectors) + + def is_this_message_in_dataset(self, text): + return text in self.texts diff --git a/utils/nn_helpers.py b/utils/nn_helpers.py new file mode 100644 index 00000000..4e9123a1 --- /dev/null +++ b/utils/nn_helpers.py @@ -0,0 +1,279 @@ +import torch +import torch.nn as nn +from torch import Tensor +import torch.optim as optim +from torch.distributions.uniform import Uniform + +import math + 
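+# matplotlib is only used by the commented-out positional-encoding plot in
+# PositionalEncoding below; device, NOISE_SIZE, DEPTH and TARGET_LEN are
+# expected to exist as module-level globals (they are neither defined nor
+# imported in this file).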
+import matplotlib.pyplot as plt + + +def create_noise(sample_size, noise_size=NOISE_SIZE): + return ( + torch.randn(sample_size, noise_size).to(device), + torch.randint(0, DEPTH, (sample_size,)).sort().values.to(device), + ) + + +def multiply_shape(shape): + if len(shape) == 1: + return shape[0] + return shape[0] * multiply_shape(shape[1:]) + + +def number_of_parameters(parameters): + nb_of_vars = 0 + for parameter in parameters: + nb_of_vars += multiply_shape(tuple(parameter.shape)) + return nb_of_vars + + +def get_optimizer(parameters, lr=0.0001, betas=(0.5, 0.999)): + return optim.Adam(parameters, lr=lr, betas=betas) + + +class PositionalEncoding(nn.Module): + def __init__(self, dim_pe: int, max_len: int = TARGET_LEN, concatenate_pe=False): + super().__init__() + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, dim_pe, 2) * (-math.log(dim_pe // 4) / dim_pe) + ) + pe = torch.zeros(max_len, 1, dim_pe) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + pe = torch.transpose(pe, 0, 1) + pe = torch.transpose(pe, 1, 2) + # plt.imshow(pe[0], cmap="hot", interpolation="nearest") + # plt.show() + self.register_buffer("pe", pe) + self.concatenate_pe = concatenate_pe + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [seq_len, batch_size, embedding_dim] + """ + pe = self.pe.repeat(x.size(0), 1, 1) + # input (N,C,L) - N bathc size, C - channels, L - length + return torch.cat((x, pe), 1) if self.concatenate_pe else x + pe + + +class Reshape(nn.Module): + def __init__(self, *out_shape): + super(Reshape, self).__init__() + self.out_shape = out_shape + + def forward(self, input_batch): + """Turch batched flat vector to out_shape""" + return torch.reshape(input_batch, (-1, *self.out_shape)) + + +class Concatenate(nn.Module): + def __init__(self, dim): + super(Concatenate, self).__init__() + self.dim = dim + + def forward(self, input_batch): + return torch.cat(input_batch, dim=1) + + +class Dummy(nn.Module): + """For shadowing some layers.""" + + def __init__(self): + super(Dummy, self).__init__() + + def forward(self, x): + return x + + +class MyTransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead=1, n_layers=1): + super(MyTransformerEncoderLayer, self).__init__() + self.tranformer_layers = nn.Sequential( + *tuple( + nn.TransformerEncoderLayer( + d_model=d_model, nhead=nhead, batch_first=True + ) + for _ in range(n_layers) + ) + ) + + def forward(self, x): + x = self.tranformer_layers(torch.transpose(x, 1, 2)) + return torch.transpose(x, 1, 2) + + +class MyConvTransposeLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + stride=1, + output_padding=0, + alpha=0.2, + include_batch_norm=True, + padding=1, + kernel_size=3, + ): + super(MyConvTransposeLayer, self).__init__() + self.conv_layer = nn.Sequential( + nn.ConvTranspose1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ), + nn.LeakyReLU(alpha), + nn.Conv1d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + ), + nn.BatchNorm1d(out_channels) if include_batch_norm else Dummy(), + nn.LeakyReLU(alpha), + ) + + def forward(self, x): + return self.conv_layer(x) + + +class MyConvLayerNorm(nn.Module): + def __init__( + self, + in_channels, + out_channels, + stride=1, + output_padding=0, + alpha=0.2, + include_batch_norm=True, + padding=1, + kernel_size=3, + ): + 
super(MyConvLayerNorm, self).__init__() + self.conv_layer = nn.Sequential( + nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + nn.BatchNorm1d(out_channels) if include_batch_norm else Dummy(), + nn.LeakyReLU(alpha), + ) + + def forward(self, x): + return self.conv_layer(x) + + +class MyLSTMLayerNorm(nn.Module): + def __init__( + self, + in_channels, + out_channels, + alpha=0.2, + ): + super(MyLSTMLayerNorm, self).__init__() + self.lstm = nn.LSTM( + batch_first=True, + bidirectional=True, + input_size=in_channels, + hidden_size=out_channels, + ) + self.layers = nn.Sequential( + nn.BatchNorm1d(2 * out_channels), + nn.LeakyReLU(alpha), + ) + def forward(self, x): + x = torch.transpose(x, 1, 2) + x, (hn, cn) = self.lstm(x) + x = torch.transpose(x, 1,2) + x = self.layers(x) + return x + + +class Flatten(nn.Module): + def __init__(self): + super(Flatten, self).__init__() + self.start_dim = 1 + + def forward(self, input_tensor): + return torch.flatten(input_tensor, start_dim=1) + + +class MyConvLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding="same", + alpha=0.2, + drop_rate=0.2, + ): + super(MyConvLayer, self).__init__() + self.conv_layer = nn.Sequential( + nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + nn.LeakyReLU(alpha), + nn.Dropout(drop_rate), + ) + + def forward(self, x): + return self.conv_layer(x) + + +class MyLSTMLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + alpha=0.2, + drop_rate=0.2, + ): + super(MyLSTMLayer, self).__init__() + self.lstm = nn.LSTM( + batch_first=True, + bidirectional=True, + input_size=in_channels, + hidden_size=out_channels, + ) + self.layers = nn.Sequential( + nn.LeakyReLU(alpha), + nn.Dropout(drop_rate), + ) + def forward(self, x): + x = torch.transpose(x, 1, 2) + x, (hn, cn) = self.lstm(x) + x = torch.transpose(x, 1,2) + x = self.layers(x) + return x + + + +def DiversityLoss(): + cs2 = torch.nn.CosineSimilarity(dim=2) + def cos_sim_loss(generated): + batch_size = generated.shape[0] + generated = generated.repeat(batch_size, 1, 1, 1) + generatedTranspose = torch.transpose(generated, 0, 1) + loss = cs2(generated, generatedTranspose) + ind = np.diag_indices(loss.shape[0]) + loss[ind[0], ind[1], :] = 0 # set 0 to similarity of message to itself + loss = loss.mean(axis=2).max(axis=0).values.mean() + return loss + return cos_sim_loss diff --git a/utils/text_process.py b/utils/text_process.py index 7bdb9daf..d4f3342a 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -4,7 +4,7 @@ # @FileName : text_process.py # @Time : Created at 2019-05-14 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import nltk @@ -344,6 +344,30 @@ def build_embedding_matrix(dataset): torch.save(embedding_matrix, embed_filename) return embedding_matrix +def pad_sequences( + sequence, target_len: int = 52, embedding_size: int = 300, padding_token = None +) -> np.array: + sequence = np.array(sequence) + current_length = sequence.shape[0] + + if current_length >= target_len: + return sequence[-target_len:] + + padding = np.repeat(np.array([w2v.wv[padding_token]]), target_len - current_length,axis=0) if padding_token else np.zeros((target_len - current_length, embedding_size)) + return np.concatenate((padding, sequence), axis=0) + + +def vectorize_sentence(tokens, target_len: int = 52, embedding_size: int = 300, padding_token=None): + tokens = tokenizer_func(text) if type(text) == str else text + target_sentence = pad_sequences( + [w2v.wv[token] for token in tokens], + target_len=target_len, + embedding_size=embedding_size, + padding_token=padding_token, + ) + target_sentence = target_sentence.T # required for pytorch + return target_sentence + if __name__ == '__main__': os.chdir('../') From d39f5162d60021303b57e1dd96b2dfa88ff6bf76 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 13:30:28 +0000 Subject: [PATCH 03/81] WIP --- config.py | 7 +- instructor/real_data/fixem_instructor.py | 152 ++++++++++++++++++----- main.py | 6 + models/FixemGAN_D.py | 22 ++-- models/FixemGAN_G.py | 52 ++++---- run/run_fixem.py | 11 +- utils/create_embeddings.py | 30 +++-- utils/data_loader.py | 38 +++--- utils/gan_loss.py | 28 ++++- utils/nn_helpers.py | 2 +- utils/text_process.py | 63 +++++++++- 11 files changed, 300 insertions(+), 111 deletions(-) diff --git a/config.py b/config.py index 14a7f4df..11558eab 100644 --- a/config.py +++ b/config.py @@ -268,6 +268,11 @@ def init_param(opt): dis_hidden_dim = opt.dis_hidden_dim num_rep = opt.num_rep + w2v_embedding_size = opt.embeddging_size + w2v_window = opt.w2v_window + w2v_min_count = opt.w2v_min_count + w2v_workers = opt.w2v_workers + use_nll_oracle = True if opt.use_nll_oracle == 1 else False use_nll_gen = True if opt.use_nll_gen == 1 else False use_nll_div = True if opt.use_nll_div == 1 else False @@ -324,7 +329,7 @@ def init_param(opt): samples_num) pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) - + pretrain_embeddgin_path = pretrain_root + 'w2v_embedding_size{}.model'.format(opt.embedding_size) # Assertion assert k_label >= 2, 'Error: k_label = {}, which should be >=2!'.format(k_label) assert eval_b_num >= n_parent * ADV_d_step, 'Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!'.format( diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 1f0e6a45..315417ba 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -1,10 +1,21 @@ import os + +from pathlib import Path +import numpy as np import torch +from torch.utils.data import Dataset, DataLoader import torchtext +from tqdm import trange + + from instructor.real_data.instructor import BasicInstructor -from utils.text_process import tokenize -from utils.data_loader import DataSupplier +from utils.text_process import text_file_iterator +from utils.data_loader import DataSupplier, GANDataset +from utils.nn_helpers import create_noise, number_of_parameters +from utils.create_embedding import EmbeddingsTrainer, load_embedding +from models.FixemGAN_G import Generator +from models.FixemGAN_D import 
Discriminator # TO DO: @@ -22,42 +33,99 @@ class FixemGANInstructor(BasicInstructor): def __init__(self, cfg): super(FixemGANInstructor, self).__init__(cfg) # check if embeddings already exist for current oracle - if os.path.exists(f'oracle path/{cfg.embedding_file_name}'): - # train embedding with oracle - w2v = load_embedding(f'oracle path/{cfg.embedding_file_name}') + if not os.path.exists(cfg.pretrain_embeddgin_path): + # train embedding on available dataset or oracle + sources = list(Path(texts_data).glob('*.txt')) + EmbeddingsTrainer(sources, cfg.pretrain_embeddgin_path).make_embeddings() + + w2v = load_embedding(cfg.pretrain_embeddgin_path) + + if cfg.run_model == 'fixemgan': + labels, train_data = zip(*[(0, line) for line in text_file_iterator(cfg.train_data)]) + + if cfg.run_model == 'cat_fixemgan': + labels, train_data = zip( + *chain( + *[[(i, line) for line in text_file_iterator(cfg.cat_train_data.format(i))] + for i in range(cfg.k_label)] + ) + ) + + self.train_data_supplier = DataSupplier(train_data, labels, w2v, True, cfg.batch_size, cfg.batches_per_epoch) + + self.discriminator = Discriminator(cfg.discriminator_complexity) + print( + "discriminator total tranable parameters:", + number_of_parameters(self.discriminator.parameters()) + ) + self.generator = Generator(cfg.generator_complexity, cgf.noise_size, w2v) + print( + "generator total tranable parameters:", + number_of_parameters(self.generator.parameters()) + ) + + self.G_criterion = GANLoss(cfg.run_model, which_net=None, which_D=None, ) + self.D_criterion = GANLoss(cfg.run_model, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2) - print(self.train_data) + self.all_metrics = [self.bleu, self.self_bleu] - print(self.train_data_list) + def generator_train_one_batch(self): + self.generator.optimizer.zero_grad() + noise = create_noise(cfg.batch_size, cfg.noise_size) + ones = label_ones(cfg.batch_size) + fakes = self.generator(*noise) - # data_generator = DataSupplier + real_fake_predicts, label_predicts = self.discriminator(fakes) + loss = self.G_criterion.G_loss_fixem(real_fake_predicts, label_predicts, fakes) + loss.backward() + self.generator.optimizer.step() - DataLoader( - dataset=GANDataset(), - batch_size=self.batch_size, - shuffle=self.shuffle, - drop_last=True + generator_acc = float( + np.array(real_fake_predicts.detach().cpu().numpy() > 0.5, dtype=int).mean() ) + return generator_acc + + def discriminator_train_one_batch(self, real_vector, labels): + # important to have equal batch size for fake and real vectors + this_batch_size = real_vector.shape[0] + + # create input + noise = create_noise(this_batch_size, cfg.noise_size) + fake = self.generator(*noise).detach() + text_input_vectors = torch.cat((real_vector, fake)) + + # optmizer step + discriminator.optimizer.zero_grad() + real_fake_predicts, label_predicts = self.discriminator(text_input_vectors) + loss = self.D_criterion.D_loss_fixem(real_fake_predicts, label_predicts[:this_batch_size], labels) + loss.backward() + discriminator.optimizer.step() + + discriminator_acc = torch.cat( + ( + real_fake_predicts.chunk(2)[0] > 0.5, + real_fake_predicts.chunk(2)[1] < 0.5 + ) + ) + return discriminator_acc - # try: - # self.train_data = GenDataIter(cfg.train_data) - # self.test_data = GenDataIter(cfg.test_data, if_test_data=True) - # except: - # pass + def _run(self): + for i in trange(cfg.max_epochs): + for labels, text_vector in self.train_data_supplier: + discriminator_acc = self.discriminator_train_one_batch(text_vector, labels) - # try: 
- # self.train_data_list = [GenDataIter(cfg.cat_train_data.format(i)) for i in range(cfg.k_label)] - # self.test_data_list = [GenDataIter(cfg.cat_test_data.format(i), if_test_data=True) for i in - # range(cfg.k_label)] - # self.clas_data_list = [GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True) for i in - # range(cfg.k_label)] + generator_acc = 1 - 2 * (discriminator_acc - 0.5) + # run the generator until generator acc not get high enought + while self.one_more_batch_for_generator(generator_acc): + generator_acc = self.generator_train_one_batch() - # self.train_samples_list = [self.train_data_list[i].target for i in range(cfg.k_label)] - # self.clas_samples_list = [self.clas_data_list[i].target for i in range(cfg.k_label)] - # except: - # pass + if cfg.run_model = 'fixemgan': + scores = self.cal_metrics(fmt_str=True) + if cfg.run_model = 'cat_fixemgan': + scores = self.cal_metrics_with_label(fmt_str=True) + print('epoch:', i, scores) def one_more_batch_for_generator( @@ -69,9 +137,25 @@ def one_more_batch_for_generator( return True return False - def write_txt_file(self, source, save_path, save_filename): - with open(os.path.join(save_path, save_filename), 'w') as f: - for _, text in source: - line = ' '.join(tokenize(text)) - f.write(line) - f.write('\n') + + def cal_metrics(self, fmt_str=False): + """ + Calculate metrics + :param fmt_str: if return format string for logging + """ + with torch.no_grad(): + # Prepare data for evaluation + gen_tokens = self.generator.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_tokens_s = self.generator.sample(200, 200) + + # Reset metrics + self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) + # self.nll_gen.reset(self.gen, self.train_data.loader) + # self.nll_div.reset(self.gen, gen_data.loader) + self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) + # self.ppl.reset(gen_tokens) + + if fmt_str: + return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + else: + return [metric.get_score() for metric in self.all_metrics] diff --git a/main.py b/main.py index 3c7aef72..4f73be19 100644 --- a/main.py +++ b/main.py @@ -90,6 +90,12 @@ def program_config(parser): parser.add_argument('--dis_hidden_dim', default=cfg.dis_hidden_dim, type=int) parser.add_argument('--num_rep', default=cfg.num_rep, type=int) + # W2V embeddings + parser.add_argument('--w2v_embedding_size', default=cfg.w2v_embedding_size, type=int) + parser.add_argument('--w2v_window', default=cfg.w2v_window, type=int) + parser.add_argument('--w2v_min_count', default=cfg.w2v_min_count, type=int) + parser.add_argument('--w2v_workers', default=cfg.w2v_workers, type=int) + # Metrics parser.add_argument('--use_nll_oracle', default=cfg.use_nll_oracle, type=int) parser.add_argument('--use_nll_gen', default=cfg.use_nll_gen, type=int) diff --git a/models/FixemGAN_D.py b/models/FixemGAN_D.py index f713e206..e36d8110 100644 --- a/models/FixemGAN_D.py +++ b/models/FixemGAN_D.py @@ -2,22 +2,14 @@ from utils.nn_helpers import get_optimizer, MyConvLayer, MyTransformerEncoderLayer, Flatten -@dataclass -class DiscriminatorParameters: - complexity: int = 512 - alpha: float = 0.2 - drop_rate: float = 0.0 - transformer: bool = False - transformer_layers: int = 6 +from models.discriminator import CNNDiscriminator - -class Discriminator(nn.Module): - def __init__(self, parameters: DiscriminatorParameters, verbose=False): +class Discriminator(CNNDiscriminator): + def __init__(self, complexity): super(Discriminator, 
self).__init__() - complexity = parameters.complexity - alpha = parameters.alpha - drop_rate = parameters.drop_rate - include_transformer = parameters.transformer + alpha = 0.2 + drop_rate = 0.0 + include_transformer = False self.main = nn.Sequential( # 1 layer @@ -78,6 +70,8 @@ def __init__(self, parameters: DiscriminatorParameters, verbose=False): nn.Linear(complexity, DEPTH), ) self.optimizer = get_optimizer() + # maybe it will help! + # self.init_params() @property def nb_of_parameters(self): diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py index b0c59306..b3f3054b 100644 --- a/models/FixemGAN_G.py +++ b/models/FixemGAN_G.py @@ -1,33 +1,24 @@ from dataclasses import dataclass import torch.nn as nn -from utils.nn_helpers import get_optimizer, Concatenate, Reshape, MyConvLayerNorm, MyConvTransposeLayer, PositionalEncoding, MyLSTMLayerNorm +from utils.nn_helpers import get_optimizer, create_noise, Concatenate, Reshape, MyConvLayerNorm, MyConvTransposeLayer, PositionalEncoding, MyLSTMLayerNorm +import config as cfg +from models.generator import LSTMGenerator -@dataclass -class GeneratorParameters: - complexity: int = 512 - concatenate_pe: bool = False - leacky_ReLU_alpha: float = 0.2 - batch_norm: bool = True - transformer: bool = False - lstm: bool = True - transformer_layers: int = 3 -class Generator(nn.Module): - def __init__(self, parameters: GeneratorParameters, embedding_size: int, verbose=False): +class Generator(LSTMGenerator): + def __init__(self, complexity, noise_size, w2v): super(Generator, self).__init__() - complexity = parameters.complexity - alpha = parameters.leacky_ReLU_alpha - added_dim_pe = parameters.complexity if parameters.concatenate_pe else 0 - include_batch_norm = parameters.batch_norm - include_transformer = parameters.transformer - include_lstm = parameters.lstm - + alpha = 0.2 + added_dim_pe = 0 + include_batch_norm = True + include_transformer = False + include_lstm = True + self.noise_size = noise_size + self.w2v = w2v self.embedding_size = embedding_size - self.real_fake_criterion = nn.BCELoss() - self.label_criterion = nn.CrossEntropyLoss(label_smoothing=0.1) self.main = nn.Sequential( # 1 layer @@ -125,10 +116,25 @@ def __init__(self, parameters: GeneratorParameters, embedding_size: int, verbose ) self.optimizer = get_optimizer() self.to(device) - if verbose: - print("total parameters:", number_of_parameters(self.parameters())) def forward(self, noise, target_labels): target_labels = torch.nn.functional.one_hot(target_labels, num_classes=DEPTH) x = self.main([noise, target_labels]) return x + + def sample(self, num_samples, batch_size, start_letter=cfg.start_letter): + noise = create_noise(num_samples, self.noise_size) + fakes = self.forward(*noise) + fakes = fakes.detach().cpu().numpy() + assert len(fakes.shape) == 3 + return [recover_sentence(fake) for fake in fakes] + + def recover_sentence(self, fake): + fake = fake.T + tokens = [] + for token_vector in fake: + token = self.w2v.wv.most_similar([token_vec])[0][0] + if token == cfg.padding_token: + continue + tokens.append(token) + return " ".join(tokens).strip() diff --git a/run/run_fixem.py b/run/run_fixem.py index db443b6a..202e5ffe 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -38,13 +38,20 @@ k_label = 2 CUDA = int(True) batch_size = 32 +noise_size = 1000 +max_epochs = 20 +batches_per_epoch = 200 # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True), int(False), int(True), int(True)] -dataset = ['oracle', 'mr15', 'amazon_app_book', 'oracle', 'image_coco', 
'emnlp_news'] +dataset = ['mr15', 'amazon_app_book', 'image_coco', 'emnlp_news'] +w2v_embedding_size = [100, 100, 100, 100, 100, 100] +w2v_window = 5 +w2v_min_count = 30 +.w2v_workers = 30 vocab_size = [5000, 0, 0, 5000, 0, 0] -target_len = [20, 20, 40, 20, 16, 52] +target_len = [16, 40, 20, 16, 52] # ===CatGAN Param=== n_parent = 1 diff --git a/utils/create_embeddings.py b/utils/create_embeddings.py index 810e41e3..866a65fb 100644 --- a/utils/create_embeddings.py +++ b/utils/create_embeddings.py @@ -1,22 +1,34 @@ from gensim.models import Word2Vec import config as cfg -from utils.text_process import get_tokenized_from_file +from utils.text_process import text_file_iterator, get_tokenized_from_file -class EmbeddingsTrainer: - def __init__(self, *files, size=512, save_filename=cfg.word2vec_model_name): +class MultipleFilesIterator: + def __init__(self, files): self.files = files + + def __iter__(self): + for file in self.files: + yield from [cfg.padding_token] * 5 + text_file_iterator(file) + + +class EmbeddingsTrainer: + def __init__(self, sources, save_filename): + self.sources = sources self.size = size self.save_filename = save_filename - def make_embeddings(self, verbose=True): - tokenized = get_tokenized_from_file(self.files) - W2V = Word2Vec( - sentences=tokenized, - size=self.size, + def make_embeddings(self): + w2v = Word2Vec( + sentences=MultipleFilesIterator(self.sources), + size=cfg.w2v_embedding_size, window=cfg.w2v_window, min_count=cfg.w2v_min_count, workers=cfg.w2v_workers, ) - W2V.save(self.save_filename) + w2v.save(self.save_filename) + + +def load_embedding(path): + return Word2Vec.load(path) diff --git a/utils/data_loader.py b/utils/data_loader.py index 62e7e9f3..90161c9e 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -10,6 +10,7 @@ import random from torch.utils.data import Dataset, DataLoader +import config as cfg from utils.text_process import * @@ -128,11 +129,19 @@ def prepare(self, pos_samples, neg_samples, gpu=False): class DataSupplier: - def __init__(self, tokenized, labels, verbose, batch_size, batches_per_epoch): + def __init__(self, tokenized, labels, w2v, verbose, batch_size, batches_per_epoch): self.verbose=verbose + + labels, tokenized = zip(*[ + (label, tokens) + for label, tokens in zip(labels, tokenized) + if all(token in w2v.wv for token in tokens) + ]) + self.labels = torch.tensor(labels, dtype=int) - self.vectors = [vectorize_sentence(txt, padding_token = PADDING) for txt in tokenized] + + self.vectors = [vectorize_sentence(tokens, w2v, padding_token = cfg.padding_token) for tokens in tokenized] self.vectors = np.stack(vectors, axis=0) self.vectors = torch.tensor(vectors, dtype=torch.float32) @@ -150,25 +159,16 @@ def __iter__(self): self.vectors = self.vectors[permutation] self.labels = self.labels[permutation] - # permutation = torch.randint(low=0, high=len(self), size=(self.batch_size,)) - # yield self.labels[permutation].to(device), self.vectors[permutation].to(device) - for _ in batch_iterator: - if len(self) < self.batch_size: - # we need to repat self vectors several times - repeats = self.batch_size // len(self.vectors) - yield self.labels.repeat(repeats).to(device), self.vectors.repeat(repeats, 1, 1).to(device) - + index = 0 + index += self.batch_size + if index > len(self): + # concatenating beginning of self.vectors + yield (torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])).to(device), + torch.cat((self.vectors[index - self.batch_size: index], 
self.vectors[:index-len(self)])).to(device)) + index = index % len(self) else: - index = 0 - index += self.batch_size - if index > len(self): - # concatenating beginning of self.vectors - yield (torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])).to(device), - torch.cat((self.vectors[index - self.batch_size: index], self.vectors[:index-len(self)])).to(device)) - index = index % len(self) - else: - yield self.labels[index - self.batch_size: index].to(device), self.vectors[index - self.batch_size: index].to(device) + yield self.labels[index - self.batch_size: index].to(device), self.vectors[index - self.batch_size: index].to(device) def __len__(self): diff --git a/utils/gan_loss.py b/utils/gan_loss.py index 32411660..bb373839 100644 --- a/utils/gan_loss.py +++ b/utils/gan_loss.py @@ -4,14 +4,14 @@ # @FileName : gan_loss.py # @Time : Created at 2019-07-11 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn as nn import config as cfg - +from utils.nn_helpers import DiversityLoss class GANLoss(nn.Module): """Define different GAN Discriminator's objectives. @@ -45,6 +45,10 @@ def __init__(self, loss_mode, which_net, which_D, target_real_label=1.0, target_ self.loss = nn.BCEWithLogitsLoss() elif loss_mode in ['wgan', 'hinge']: self.loss = None + elif loss_mode == 'fixem': + self.real_fake_criterion = nn.BCEWithLogitsLoss() + self.label_criterion = nn.CrossEntropyLoss(label_smoothing=0.1) + self.diversity_criterion = DiversityLoss() else: raise NotImplementedError('gan mode %s not implemented' % loss_mode) @@ -138,6 +142,26 @@ def D_loss(self, Dreal, Dfake): return loss_fake + loss_real + + def G_loss_fixem(self, real_fake_predicts, label_predicts, target_labels, fakes): + target_fake = self.get_target_tensor(real_fake_predicts, target_is_real=True) + real_fake_loss = self.real_fake_criterion(real_fake_predicts, target_fake) + labels_loss = self.label_criterion(label_predicts, target_labels) + diversity_loss = self.diversity_criterion(fakes) + loss = real_fake_loss + diversity_loss + loss = loss + labels_loss if cfg.run_model == 'cat_fixemgan' else loss + return loss + + def D_loss_fixem(self, real_fake_predicts, label_predicts, target_labels): + target_real = self.get_target_tensor(real_fake_predicts.chunk(2)[0], target_is_real=True) + target_fake = self.get_target_tensor(real_fake_predicts.chunk(2)[1], target_is_real=False) + target_real_fake = torch.cat((target_real, target_fake)) + real_fake_loss = self.real_fake_criterion(real_fake_predicts, target_real_fake) + labels_loss = self.label_criterion(label_predicts, target_labels) + loss = real_fake_loss + loss = loss + labels_loss if cfg.run_model == 'cat_fixemgan' else loss + return loss + def __call__(self, Dreal, Dfake): """Calculate loss given Discriminator's output and grount truth labels.""" if self.which_net == 'G': diff --git a/utils/nn_helpers.py b/utils/nn_helpers.py index 4e9123a1..b5327e42 100644 --- a/utils/nn_helpers.py +++ b/utils/nn_helpers.py @@ -9,7 +9,7 @@ import matplotlib.pyplot as plt -def create_noise(sample_size, noise_size=NOISE_SIZE): +def create_noise(sample_size, noise_size): return ( torch.randn(sample_size, noise_size).to(device), torch.randint(0, DEPTH, (sample_size,)).sort().values.to(device), diff --git a/utils/text_process.py b/utils/text_process.py index d4f3342a..454c772e 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -11,10 +11,17 @@ import numpy as np import os import 
torch +from tqdm import tqdm import config as cfg +def text_file_iterator(file): + with open(file) as raw: + for line in raw.readlines(): + yield line.strip().split() + + def get_tokenlized(file): """tokenlize the file""" tokenlized = list() @@ -357,16 +364,30 @@ def pad_sequences( return np.concatenate((padding, sequence), axis=0) -def vectorize_sentence(tokens, target_len: int = 52, embedding_size: int = 300, padding_token=None): - tokens = tokenizer_func(text) if type(text) == str else text - target_sentence = pad_sequences( +def vectorize_sentence(tokens, w2v, target_len: int = 52, embedding_size: int = 300, padding_token=None): + vectorized = pad_sequences( [w2v.wv[token] for token in tokens], target_len=target_len, embedding_size=embedding_size, padding_token=padding_token, ) - target_sentence = target_sentence.T # required for pytorch - return target_sentence + vectorized = vectorized.T # required for pytorch + return vectorized + +import os +import nltk +nltk.download('punkt') +from tqdm.notebook import tqdm +from pathlib import Path + + +def tokenize_and_save(source, path, filename): + with open(Path(path) / filename, 'w') as f: + for _, line in tqdm(source, desc=filename): + line = line.strip().lower() + line = ' '.join(nltk.tokenize.word_tokenize(line)) + f.write(line) + f.write('\n') if __name__ == '__main__': @@ -374,4 +395,34 @@ def vectorize_sentence(tokens, target_len: int = 52, embedding_size: int = 300, # process_cat_text() # load_test_dict('mr15') # extend_clas_train_data() - pass + + # dataset preprocess and saving + import torchtext + + AGNEWS_train, AGNEWS_test = torchtext.datasets.AG_NEWS( + root="./data", split=("train", "test") + ) + DBpedia_train, DBpedia_test = torchtext.datasets.DBpedia( + root="./data", split=("train", "test") + ) + + WikiText103_train, WikiText103_valid, WikiText103_test = torchtext.datasets.WikiText103( + root="./data", split=("train", "valid", "test") + ) + YahooAnswers_train, YahooAnswers_test = torchtext.datasets.YahooAnswers( + root="./data", split=("train", "test") + ) + YelpReviewFull_train, YelpReviewFull_test = torchtext.datasets.YelpReviewFull( + root="./data", split=("train", "test") + ) + tokenize_and_save(AGNEWS_train, './dataset/', 'agnews_train.txt') + tokenize_and_save(AGNEWS_test, './dataset/testdata/', 'agnews_test.txt') + tokenize_and_save(DBpedia_train, './dataset/', 'dbpedia_train.txt') + tokenize_and_save(DBpedia_test, './dataset/testdata/', 'dbpedia_test.txt') + tokenize_and_save(enumerate(WikiText103_train), './dataset/', 'wikitext103_train.txt') + tokenize_and_save(enumerate(WikiText103_valid), './dataset/', 'wikitext103_valid.txt') + tokenize_and_save(enumerate(WikiText103_test), './dataset/testdata/', 'wikitext103_test.txt') + tokenize_and_save(YahooAnswers_train, './dataset/', 'yahooanswers_train.txt') + tokenize_and_save(YahooAnswers_test, './dataset/testdata/', 'yahooanswers_test.txt') + tokenize_and_save(YelpReviewFull_train, './dataset/', 'yelpreviewfull_train.txt') + tokenize_and_save(YelpReviewFull_test, './dataset/testdata/', 'yelpreviewfull_test.txt') From 5947ccb12cea3d224304cf62d9830880f7fbb55f Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 14:02:53 +0000 Subject: [PATCH 04/81] WIP --- config.py | 13 +++++++--- instructor/real_data/fixem_instructor.py | 4 ++-- models/FixemGAN_D.py | 4 ++-- models/FixemGAN_G.py | 14 +++++------ run/run_fixem.py | 30 ++++++++---------------- utils/data_loader.py | 4 ++-- utils/nn_helpers.py | 6 ++--- 7 files changed, 36 insertions(+), 39 deletions(-) 
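The commit below drops the module-level constants FixemGAN relied on (NOISE_SIZE, DEPTH, TARGET_LEN) in favour of cfg values passed explicitly to helpers such as create_noise. A minimal, self-contained sketch of how the generator input is assembled under that convention; the sizes 1000 and 2 mirror the noise_size and k_label values in run_fixem.py, while the device handling and the explicit float cast of the one-hot labels are assumptions of this sketch, not code from the patch.

import torch

def create_noise(sample_size, noise_size, k_label, device="cpu"):
    # Random latent vectors plus sorted random class labels, mirroring
    # utils/nn_helpers.create_noise after this commit.
    noise = torch.randn(sample_size, noise_size, device=device)
    labels = torch.randint(0, k_label, (sample_size,), device=device).sort().values
    return noise, labels

noise, labels = create_noise(sample_size=4, noise_size=1000, k_label=2)
# Generator.forward one-hot encodes the labels and Concatenate(1) joins them
# with the noise before the first Linear layer.
one_hot = torch.nn.functional.one_hot(labels, num_classes=2).float()
gen_input = torch.cat((noise, one_hot), dim=1)
print(gen_input.shape)  # torch.Size([4, 1002])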
diff --git a/config.py b/config.py index 11558eab..41e1086a 100644 --- a/config.py +++ b/config.py @@ -41,6 +41,12 @@ use_all_real_fake = False use_population = False +# ===Embedding=== +w2v_embedding_size = 100 +w2v_window = 5 +w2v_min_count = 1 +w2v_workers = 1 + # ===Oracle or Real, type=== if_real_data = False # if use real data dataset = 'oracle' # oracle, image_coco, emnlp_news, amazon_app_book, amazon_app_movie, mr15 @@ -198,7 +204,8 @@ def init_param(opt): pretrained_clas_path, n_parent, mu_type, eval_type, d_type, eval_b_num, lambda_fd, d_out_mean, \ lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, \ multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ - use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl + use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ + w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -268,7 +275,7 @@ def init_param(opt): dis_hidden_dim = opt.dis_hidden_dim num_rep = opt.num_rep - w2v_embedding_size = opt.embeddging_size + w2v_embedding_size = opt.w2v_embeddging_size w2v_window = opt.w2v_window w2v_min_count = opt.w2v_min_count w2v_workers = opt.w2v_workers @@ -329,7 +336,7 @@ def init_param(opt): samples_num) pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) - pretrain_embeddgin_path = pretrain_root + 'w2v_embedding_size{}.model'.format(opt.embedding_size) + pretrain_embeddgin_path = pretrain_root + 'w2v_embedding_size{}.model'.format(opt.w2v_embedding_size) # Assertion assert k_label >= 2, 'Error: k_label = {}, which should be >=2!'.format(k_label) assert eval_b_num >= n_parent * ADV_d_step, 'Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!'.format( diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 315417ba..c61b6fc3 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -71,7 +71,7 @@ def __init__(self, cfg): def generator_train_one_batch(self): self.generator.optimizer.zero_grad() - noise = create_noise(cfg.batch_size, cfg.noise_size) + noise = create_noise(cfg.batch_size, cfg.noise_size. cfg.k_label) ones = label_ones(cfg.batch_size) fakes = self.generator(*noise) @@ -90,7 +90,7 @@ def discriminator_train_one_batch(self, real_vector, labels): this_batch_size = real_vector.shape[0] # create input - noise = create_noise(this_batch_size, cfg.noise_size) + noise = create_noise(cfg.batch_size, cfg.noise_size. cfg.k_label) fake = self.generator(*noise).detach() text_input_vectors = torch.cat((real_vector, fake)) diff --git a/models/FixemGAN_D.py b/models/FixemGAN_D.py index e36d8110..3351c501 100644 --- a/models/FixemGAN_D.py +++ b/models/FixemGAN_D.py @@ -58,7 +58,7 @@ def __init__(self, complexity): ), # 8 layer Flatten(), - nn.Linear(complexity * TARGET_LEN // 2 // 2, complexity), + nn.Linear(complexity * cfg.max_seq_len // 2 // 2, complexity), nn.LeakyReLU(alpha), nn.Dropout(drop_rate), ) @@ -67,7 +67,7 @@ def __init__(self, complexity): nn.Linear(complexity, 1), ) self.labels = nn.Sequential( - nn.Linear(complexity, DEPTH), + nn.Linear(complexity, cfg.k_label), ) self.optimizer = get_optimizer() # maybe it will help! 
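For orientation before the generator diff: the two discriminator heads defined above feed GANLoss.D_loss_fixem from the previous commit, which splits the real/fake logits of the concatenated [real; fake] batch with chunk(2) and supervises the label head on the real half only. A hedged, self-contained sketch of that loss assembly; the batch size and tensors are random placeholders, while the soft targets 0.8/0.2 and k_label = 2 come from fixem_instructor.py and run_fixem.py.

import torch
import torch.nn as nn

batch = 8
real_fake_logits = torch.randn(2 * batch, 1)   # D output for [real; fake]
label_logits = torch.randn(2 * batch, 2)       # k_label = 2 classes
true_labels = torch.randint(0, 2, (batch,))    # labels of the real half

real_fake_criterion = nn.BCEWithLogitsLoss()
label_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Soft targets: 0.8 for the real half, 0.2 for the fake half.
target_real = torch.full_like(real_fake_logits.chunk(2)[0], 0.8)
target_fake = torch.full_like(real_fake_logits.chunk(2)[1], 0.2)

d_loss = real_fake_criterion(real_fake_logits, torch.cat((target_real, target_fake)))
# The label term is only added in the categorical (cat_fixemgan) setting.
d_loss = d_loss + label_criterion(label_logits[:batch], true_labels)
print(d_loss.item())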
diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py index b3f3054b..40052313 100644 --- a/models/FixemGAN_G.py +++ b/models/FixemGAN_G.py @@ -23,10 +23,10 @@ def __init__(self, complexity, noise_size, w2v): self.main = nn.Sequential( # 1 layer Concatenate(1), - nn.Linear(NOISE_SIZE + DEPTH, TARGET_LEN // 2 // 2 * complexity), - nn.BatchNorm1d(TARGET_LEN // 2 // 2 * complexity), + nn.Linear(cfg.noise_size + cfg.k_label, cfg.max_seq_len // 2 // 2 * complexity), + nn.BatchNorm1d(cfg.max_seq_len // 2 // 2 * complexity), nn.LeakyReLU(alpha), - Reshape(complexity, TARGET_LEN // 2 // 2), + Reshape(complexity, cfg.max_seq_len // 2 // 2), # 2 layer MyConvLayerNorm(complexity, complexity, alpha=alpha), # 3 layer @@ -50,7 +50,7 @@ def __init__(self, complexity, noise_size, w2v): # adding/concatenating positional encoding PositionalEncoding( dim_pe=parameters.complexity, - max_len=TARGET_LEN, + max_len=cfg.max_seq_len, concatenate_pe=parameters.concatenate_pe, ), # 5 layer @@ -63,7 +63,7 @@ def __init__(self, complexity, noise_size, w2v): # adding/concatenating positional encoding PositionalEncoding( dim_pe=complexity, - max_len=TARGET_LEN, + max_len=cfg.max_seq_len, concatenate_pe=parameters.concatenate_pe, ), # 6 layer @@ -118,12 +118,12 @@ def __init__(self, complexity, noise_size, w2v): self.to(device) def forward(self, noise, target_labels): - target_labels = torch.nn.functional.one_hot(target_labels, num_classes=DEPTH) + target_labels = torch.nn.functional.one_hot(target_labels, num_classes=cfg.k_label) x = self.main([noise, target_labels]) return x def sample(self, num_samples, batch_size, start_letter=cfg.start_letter): - noise = create_noise(num_samples, self.noise_size) + noise = create_noise(num_samples, self.noise_size, cfg.k_label) fakes = self.forward(*noise) fakes = fakes.detach().cpu().numpy() assert len(fakes.shape) == 3 diff --git a/run/run_fixem.py b/run/run_fixem.py index 202e5ffe..1fade2ed 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -41,7 +41,7 @@ noise_size = 1000 max_epochs = 20 batches_per_epoch = 200 - +tips = '{} experiments' # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True), int(False), int(True), int(True)] @@ -49,9 +49,8 @@ w2v_embedding_size = [100, 100, 100, 100, 100, 100] w2v_window = 5 w2v_min_count = 30 -.w2v_workers = 30 +w2v_workers = 30 vocab_size = [5000, 0, 0, 5000, 0, 0] -target_len = [16, 40, 20, 16, 52] # ===CatGAN Param=== n_parent = 1 @@ -60,11 +59,6 @@ eval_type = 'Ra' temp_adpt = 'exp' d_out_mean = int(False) -embedding_size = [100, 512, 512, 100, 512, 512] -embedding_filename = 'w2v_{}.model'.format(embedding_size[job_id]) -w2v_window = 5 -w2v_min_count = 50 -w2v_workers = 1 # === Basic Param === data_shuffle = int(False) @@ -73,7 +67,7 @@ dis_init = 'uniform' samples_num = 10000 batch_size = 64 -max_seq_len = 20 +max_seq_len = [16, 40, 20, 16, 52] gen_lr = 0.01 gen_adv_lr = 1e-4 dis_lr = 1e-4 @@ -110,12 +104,6 @@ '--k_label', k_label, '--cuda', CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', ora_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--clas_pre_epoch', clas_pre_epoch, - '--adv_epoch', ADV_train_epoch, '--tips', tips.format(run_model[job_id]), # Oracle or Real @@ -123,17 +111,19 @@ '--dataset', dataset[job_id], '--vocab_size', vocab_size[job_id], + # W2V embeddings + '--w2v_embedding_size', w2v_embedding_size[job_id], + '--w2v_window', w2v_window, + '--w2v_min_count', w2v_min_count, + '--w2v_workers', w2v_workers, + # 
CatGAN Param '--n_parent', n_parent, '--loss_type', loss_type, '--mu_type', mu_type, '--eval_type', eval_type, '--temp_adpt', temp_adpt, - '--temperature', temperature[job_id], '--d_out_mean', d_out_mean, - '--lambda_fq', lambda_fq, - '--lambda_fd', lambda_fd, - '--eval_b_num', eval_b_num, # Basic Param '--shuffle', data_shuffle, @@ -142,7 +132,7 @@ '--dis_init', dis_init, '--samples_num', samples_num, '--batch_size', batch_size, - '--max_seq_len', max_seq_len, + '--max_seq_len', max_seq_len[job_id], '--gen_lr', gen_lr, '--gen_adv_lr', gen_adv_lr, '--dis_lr', dis_lr, diff --git a/utils/data_loader.py b/utils/data_loader.py index 90161c9e..ea9b5a59 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -141,14 +141,14 @@ def __init__(self, tokenized, labels, w2v, verbose, batch_size, batches_per_epoc self.labels = torch.tensor(labels, dtype=int) - self.vectors = [vectorize_sentence(tokens, w2v, padding_token = cfg.padding_token) for tokens in tokenized] + self.vectors = [vectorize_sentence(tokens, w2v, target_len=cfg.max_seq_len, padding_token = cfg.padding_token) for tokens in tokenized] self.vectors = np.stack(vectors, axis=0) self.vectors = torch.tensor(vectors, dtype=torch.float32) self.batches_per_epoch = batches_per_epoch self.batch_size = batch_size - self.texts = set(" ".join(text[-TARGET_LEN:]) for text in texts) + self.texts = set(" ".join(text[-cfg.max_seq_len:]) for text in texts) if self.verbose: print('texts examples', [txt for txt in self.texts][:3]) diff --git a/utils/nn_helpers.py b/utils/nn_helpers.py index b5327e42..b8b9ee20 100644 --- a/utils/nn_helpers.py +++ b/utils/nn_helpers.py @@ -9,10 +9,10 @@ import matplotlib.pyplot as plt -def create_noise(sample_size, noise_size): +def create_noise(sample_size, noise_size, k_label): return ( torch.randn(sample_size, noise_size).to(device), - torch.randint(0, DEPTH, (sample_size,)).sort().values.to(device), + torch.randint(0, k_label, (sample_size,)).sort().values.to(device), ) @@ -34,7 +34,7 @@ def get_optimizer(parameters, lr=0.0001, betas=(0.5, 0.999)): class PositionalEncoding(nn.Module): - def __init__(self, dim_pe: int, max_len: int = TARGET_LEN, concatenate_pe=False): + def __init__(self, dim_pe: int, max_len: int, concatenate_pe=False): super().__init__() position = torch.arange(max_len).unsqueeze(1) From 939ac46ae244278ce1fe19b99cd7fdab91edcec6 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 14:33:29 +0000 Subject: [PATCH 05/81] WIP --- config.py | 8 +++++--- instructor/real_data/fixem_instructor.py | 9 ++++----- models/FixemGAN_G.py | 2 +- run/run_fixem.py | 2 +- utils/create_embeddings.py | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/config.py b/config.py index 41e1086a..e3e96a90 100644 --- a/config.py +++ b/config.py @@ -183,6 +183,8 @@ samples_num) pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) + +pretrain_embedding_path = pretrain_root + 'w2v_embedding_size{}.model'.format(w2v_embedding_size) signal_file = 'run_signal.txt' tips = '' @@ -205,7 +207,7 @@ def init_param(opt): lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, \ multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ - w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers + w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, 
pretrain_embedding_path if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -275,7 +277,7 @@ def init_param(opt): dis_hidden_dim = opt.dis_hidden_dim num_rep = opt.num_rep - w2v_embedding_size = opt.w2v_embeddging_size + w2v_embedding_size = opt.w2v_embedding_size w2v_window = opt.w2v_window w2v_min_count = opt.w2v_min_count w2v_workers = opt.w2v_workers @@ -336,7 +338,7 @@ def init_param(opt): samples_num) pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) - pretrain_embeddgin_path = pretrain_root + 'w2v_embedding_size{}.model'.format(opt.w2v_embedding_size) + pretrain_embedding_path = pretrain_root + 'w2v_embedding_size{}.model'.format(opt.w2v_embedding_size) # Assertion assert k_label >= 2, 'Error: k_label = {}, which should be >=2!'.format(k_label) assert eval_b_num >= n_parent * ADV_d_step, 'Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!'.format( diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index c61b6fc3..2da79f87 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -13,7 +13,7 @@ from utils.text_process import text_file_iterator from utils.data_loader import DataSupplier, GANDataset from utils.nn_helpers import create_noise, number_of_parameters -from utils.create_embedding import EmbeddingsTrainer, load_embedding +from utils.create_embeddings import EmbeddingsTrainer, load_embedding from models.FixemGAN_G import Generator from models.FixemGAN_D import Discriminator @@ -33,12 +33,12 @@ class FixemGANInstructor(BasicInstructor): def __init__(self, cfg): super(FixemGANInstructor, self).__init__(cfg) # check if embeddings already exist for current oracle - if not os.path.exists(cfg.pretrain_embeddgin_path): + if not os.path.exists(cfg.pretrain_embedding_path): # train embedding on available dataset or oracle sources = list(Path(texts_data).glob('*.txt')) - EmbeddingsTrainer(sources, cfg.pretrain_embeddgin_path).make_embeddings() + EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() - w2v = load_embedding(cfg.pretrain_embeddgin_path) + w2v = load_embedding(cfg.pretrain_embedding_path) if cfg.run_model == 'fixemgan': labels, train_data = zip(*[(0, line) for line in text_file_iterator(cfg.train_data)]) @@ -72,7 +72,6 @@ def __init__(self, cfg): def generator_train_one_batch(self): self.generator.optimizer.zero_grad() noise = create_noise(cfg.batch_size, cfg.noise_size. 
cfg.k_label) - ones = label_ones(cfg.batch_size) fakes = self.generator(*noise) real_fake_predicts, label_predicts = self.discriminator(fakes) diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py index 40052313..cd7894ad 100644 --- a/models/FixemGAN_G.py +++ b/models/FixemGAN_G.py @@ -97,7 +97,7 @@ def __init__(self, complexity, noise_size, w2v): MyLSTMLayerNorm( complexity, complexity//2, - ) if include_lstm else Dummy(),, + ) if include_lstm else Dummy(), # 11 layer MyConvTransposeLayer( complexity, diff --git a/run/run_fixem.py b/run/run_fixem.py index 1fade2ed..99f546d5 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -44,7 +44,7 @@ tips = '{} experiments' # ===Oracle or Real=== -if_real_data = [int(False), int(True), int(True), int(False), int(True), int(True)] +if_real_data = [int(True), int(True), int(True), int(True), int(True)] dataset = ['mr15', 'amazon_app_book', 'image_coco', 'emnlp_news'] w2v_embedding_size = [100, 100, 100, 100, 100, 100] w2v_window = 5 diff --git a/utils/create_embeddings.py b/utils/create_embeddings.py index 866a65fb..6bd4c5d9 100644 --- a/utils/create_embeddings.py +++ b/utils/create_embeddings.py @@ -1,7 +1,7 @@ from gensim.models import Word2Vec import config as cfg -from utils.text_process import text_file_iterator, get_tokenized_from_file +from utils.text_process import text_file_iterator class MultipleFilesIterator: From 6ba65d38eb4f80e81748ade850368592741589e8 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 14:55:23 +0000 Subject: [PATCH 06/81] WIP --- config.py | 14 +++++++++----- instructor/real_data/fixem_instructor.py | 12 ++++++------ main.py | 1 + run/run_fixem.py | 1 + utils/create_embeddings.py | 8 ++++---- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/config.py b/config.py index e3e96a90..4079a2f2 100644 --- a/config.py +++ b/config.py @@ -40,6 +40,7 @@ freeze_clas = False use_all_real_fake = False use_population = False +batches_per_epoch = 200 # ===Embedding=== w2v_embedding_size = 100 @@ -184,7 +185,10 @@ pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) -pretrain_embedding_path = pretrain_root + 'w2v_embedding_size{}.model'.format(w2v_embedding_size) +emebedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/real_data/' +pretrain_embedding_path = emebedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size) +texts_pile = 'dataset/' # do not include testdata + signal_file = 'run_signal.txt' tips = '' @@ -207,7 +211,7 @@ def init_param(opt): lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, \ multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ - w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path + w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -235,6 +239,7 @@ def init_param(opt): freeze_clas = opt.freeze_clas use_all_real_fake = opt.use_all_real_fake use_population = opt.use_population + batches_per_epoch = opt.batches_per_epoch samples_num = opt.samples_num vocab_size = opt.vocab_size @@ -325,8 +330,6 @@ def init_param(opt): cat_train_data = 'dataset/' + dataset + '_cat{}.txt' cat_test_data = 'dataset/testdata/' + dataset + '_cat{}_test.txt' - 
texts_data = 'dataset/' # do not include testdata - if max_seq_len == 40: oracle_samples_path = 'pretrain/oracle_data/oracle_lstm_samples_{}_sl40.pt' multi_oracle_samples_path = 'pretrain/oracle_data/oracle{}_lstm_samples_{}_sl40.pt' @@ -338,7 +341,8 @@ def init_param(opt): samples_num) pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) - pretrain_embedding_path = pretrain_root + 'w2v_embedding_size{}.model'.format(opt.w2v_embedding_size) + emebedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/real_data/' + pretrain_embedding_path = emebedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size) # Assertion assert k_label >= 2, 'Error: k_label = {}, which should be >=2!'.format(k_label) assert eval_b_num >= n_parent * ADV_d_step, 'Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!'.format( diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 2da79f87..b6dfbcab 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -8,7 +8,7 @@ from tqdm import trange - +import config as cfg from instructor.real_data.instructor import BasicInstructor from utils.text_process import text_file_iterator from utils.data_loader import DataSupplier, GANDataset @@ -30,12 +30,12 @@ class FixemGANInstructor(BasicInstructor): - def __init__(self, cfg): - super(FixemGANInstructor, self).__init__(cfg) + def __init__(self, opt): + super(FixemGANInstructor, self).__init__(opt) # check if embeddings already exist for current oracle if not os.path.exists(cfg.pretrain_embedding_path): # train embedding on available dataset or oracle - sources = list(Path(texts_data).glob('*.txt')) + sources = list(Path(texts_pile).glob('*.txt')) EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() w2v = load_embedding(cfg.pretrain_embedding_path) @@ -119,9 +119,9 @@ def _run(self): while self.one_more_batch_for_generator(generator_acc): generator_acc = self.generator_train_one_batch() - if cfg.run_model = 'fixemgan': + if cfg.run_model == 'fixemgan': scores = self.cal_metrics(fmt_str=True) - if cfg.run_model = 'cat_fixemgan': + if cfg.run_model == 'cat_fixemgan': scores = self.cal_metrics_with_label(fmt_str=True) print('epoch:', i, scores) diff --git a/main.py b/main.py index 4f73be19..c671fe4d 100644 --- a/main.py +++ b/main.py @@ -43,6 +43,7 @@ def program_config(parser): parser.add_argument('--freeze_clas', default=cfg.freeze_clas, type=int) parser.add_argument('--use_all_real_fake', default=cfg.use_all_real_fake, type=int) parser.add_argument('--use_population', default=cfg.use_population, type=int) + parser.add_argument('--batches_per_epoch', default=cfg.batches_per_epoch, type=int) # Basic Train parser.add_argument('--samples_num', default=cfg.samples_num, type=int) diff --git a/run/run_fixem.py b/run/run_fixem.py index 99f546d5..fabcc9a3 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -124,6 +124,7 @@ '--eval_type', eval_type, '--temp_adpt', temp_adpt, '--d_out_mean', d_out_mean, + '--batches_per_epoch', batches_per_epoch, # Basic Param '--shuffle', data_shuffle, diff --git a/utils/create_embeddings.py b/utils/create_embeddings.py index 6bd4c5d9..e6acd1d7 100644 --- a/utils/create_embeddings.py +++ b/utils/create_embeddings.py @@ -4,24 +4,24 @@ from utils.text_process import text_file_iterator -class MultipleFilesIterator: +class MultipleFilesEmbeddingIterator: def __init__(self, 
files): self.files = files def __iter__(self): for file in self.files: - yield from [cfg.padding_token] * 5 + text_file_iterator(file) + for tokens in text_file_iterator(file): + yield [cfg.padding_token] * 5 + tokens class EmbeddingsTrainer: def __init__(self, sources, save_filename): self.sources = sources - self.size = size self.save_filename = save_filename def make_embeddings(self): w2v = Word2Vec( - sentences=MultipleFilesIterator(self.sources), + sentences=MultipleFilesEmbeddingIterator(self.sources), size=cfg.w2v_embedding_size, window=cfg.w2v_window, min_count=cfg.w2v_min_count, From 8678509dea8ffa96b8d104dac9344b885705761f Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 15:29:42 +0000 Subject: [PATCH 07/81] WIP --- config.py | 7 ++++++- instructor/real_data/fixem_instructor.py | 4 +++- main.py | 2 ++ run/run_fixem.py | 6 +++++- utils/create_embeddings.py | 5 ++++- utils/data_loader.py | 6 +++--- utils/text_process.py | 3 ++- 7 files changed, 25 insertions(+), 8 deletions(-) diff --git a/config.py b/config.py index 4079a2f2..eb1ca59a 100644 --- a/config.py +++ b/config.py @@ -110,6 +110,7 @@ mem_slots = 1 # RelGAN-1 num_heads = 2 # RelGAN-2 head_size = 256 # RelGAN-256 +generator_complexity = 512 # ===Discriminator=== d_step = 5 # SeqGAN-50, LeakGAN-5 @@ -120,6 +121,7 @@ dis_embed_dim = 64 dis_hidden_dim = 64 num_rep = 64 # RelGAN +discriminator_complexity = 512 # ===log=== log_time_str = strftime("%m%d_%H%M_%S", localtime()) @@ -211,7 +213,8 @@ def init_param(opt): lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, \ multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ - w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch + w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, + generator_complexity, discriminator_complexity if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -273,6 +276,7 @@ def init_param(opt): mem_slots = opt.mem_slots num_heads = opt.num_heads head_size = opt.head_size + generator_complexity = opt.generator_complexity d_step = opt.d_step d_epoch = opt.d_epoch @@ -281,6 +285,7 @@ def init_param(opt): dis_embed_dim = opt.dis_embed_dim dis_hidden_dim = opt.dis_hidden_dim num_rep = opt.num_rep + discriminator_complexity = opt.discriminator_complexity w2v_embedding_size = opt.w2v_embedding_size w2v_window = opt.w2v_window diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index b6dfbcab..2a9effb5 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -35,7 +35,9 @@ def __init__(self, opt): # check if embeddings already exist for current oracle if not os.path.exists(cfg.pretrain_embedding_path): # train embedding on available dataset or oracle - sources = list(Path(texts_pile).glob('*.txt')) + print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") + print("Will train new one, it may take a while...") + sources = list(Path(cfg.texts_pile).glob('*.txt')) EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() w2v = load_embedding(cfg.pretrain_embedding_path) diff --git a/main.py b/main.py index c671fe4d..32c4054f 100644 --- a/main.py +++ b/main.py @@ -81,6 +81,7 @@ def program_config(parser): parser.add_argument('--mem_slots', 
default=cfg.mem_slots, type=int) parser.add_argument('--num_heads', default=cfg.num_heads, type=int) parser.add_argument('--head_size', default=cfg.head_size, type=int) + parser.add_argument('--generator_complexity', default=cfg.generator_complexity, type=int) # Discriminator parser.add_argument('--d_step', default=cfg.d_step, type=int) @@ -90,6 +91,7 @@ def program_config(parser): parser.add_argument('--dis_embed_dim', default=cfg.dis_embed_dim, type=int) parser.add_argument('--dis_hidden_dim', default=cfg.dis_hidden_dim, type=int) parser.add_argument('--num_rep', default=cfg.num_rep, type=int) + parser.add_argument('--discriminator_complexity', default=cfg.discriminator_complexity, type=int) # W2V embeddings parser.add_argument('--w2v_embedding_size', default=cfg.w2v_embedding_size, type=int) diff --git a/run/run_fixem.py b/run/run_fixem.py index fabcc9a3..2c8162ae 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -67,7 +67,7 @@ dis_init = 'uniform' samples_num = 10000 batch_size = 64 -max_seq_len = [16, 40, 20, 16, 52] +max_seq_len = [16, 40, 20, 16, 52, 36] gen_lr = 0.01 gen_adv_lr = 1e-4 dis_lr = 1e-4 @@ -81,12 +81,14 @@ mem_slots = 1 num_heads = 2 head_size = [512, 512, 512, 256, 256, 256] +generator_complexity = [256, 512, 512, 512, 512, 512] # ===Discriminator=== ADV_d_step = 3 dis_embed_dim = 64 dis_hidden_dim = 64 num_rep = 64 +discriminator_complexity = [512, 512, 512, 512, 512] # ===Metrics=== use_nll_oracle = int(True) @@ -147,12 +149,14 @@ '--mem_slots', mem_slots, '--num_heads', num_heads, '--head_size', head_size[job_id], + '--generator_complexity', generator_complexity[job_id] # Discriminator '--adv_d_step', ADV_d_step, '--dis_embed_dim', dis_embed_dim, '--dis_hidden_dim', dis_hidden_dim, '--num_rep', num_rep, + '--discriminator_complexity', discriminator_complexity[job_id] # Metrics '--use_nll_oracle', use_nll_oracle, diff --git a/utils/create_embeddings.py b/utils/create_embeddings.py index e6acd1d7..4145553c 100644 --- a/utils/create_embeddings.py +++ b/utils/create_embeddings.py @@ -1,4 +1,6 @@ from gensim.models import Word2Vec +from tqdm import tqdm +from pathlib import Path import config as cfg from utils.text_process import text_file_iterator @@ -9,7 +11,7 @@ def __init__(self, files): self.files = files def __iter__(self): - for file in self.files: + for file in tqdm(self.files, desc='iterating files'): for tokens in text_file_iterator(file): yield [cfg.padding_token] * 5 + tokens @@ -27,6 +29,7 @@ def make_embeddings(self): min_count=cfg.w2v_min_count, workers=cfg.w2v_workers, ) + Path(self.save_filename).parents[0].mkdir(parents=True, exist_ok=True) w2v.save(self.save_filename) diff --git a/utils/data_loader.py b/utils/data_loader.py index ea9b5a59..e186ed90 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -142,13 +142,13 @@ def __init__(self, tokenized, labels, w2v, verbose, batch_size, batches_per_epoc self.labels = torch.tensor(labels, dtype=int) self.vectors = [vectorize_sentence(tokens, w2v, target_len=cfg.max_seq_len, padding_token = cfg.padding_token) for tokens in tokenized] - self.vectors = np.stack(vectors, axis=0) - self.vectors = torch.tensor(vectors, dtype=torch.float32) + self.vectors = np.stack(self.vectors, axis=0) + self.vectors = torch.tensor(self.vectors, dtype=torch.float32) self.batches_per_epoch = batches_per_epoch self.batch_size = batch_size - self.texts = set(" ".join(text[-cfg.max_seq_len:]) for text in texts) + self.texts = set(" ".join(tokens[-cfg.max_seq_len:]) for tokens in tokenized) if self.verbose: 
print('texts examples', [txt for txt in self.texts][:3]) diff --git a/utils/text_process.py b/utils/text_process.py index 454c772e..97e4c704 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -352,7 +352,7 @@ def build_embedding_matrix(dataset): return embedding_matrix def pad_sequences( - sequence, target_len: int = 52, embedding_size: int = 300, padding_token = None + sequence, w2v, target_len: int = 52, embedding_size: int = 300, padding_token = None ) -> np.array: sequence = np.array(sequence) current_length = sequence.shape[0] @@ -367,6 +367,7 @@ def pad_sequences( def vectorize_sentence(tokens, w2v, target_len: int = 52, embedding_size: int = 300, padding_token=None): vectorized = pad_sequences( [w2v.wv[token] for token in tokens], + w2v, target_len=target_len, embedding_size=embedding_size, padding_token=padding_token, From 209427842801e3d885e430c445300e0150cc34e7 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 15:48:28 +0000 Subject: [PATCH 08/81] WIP --- config.py | 2 +- instructor/real_data/fixem_instructor.py | 6 ++++++ models/FixemGAN_D.py | 5 ++--- models/FixemGAN_G.py | 14 ++++++-------- run/run_fixem.py | 4 ++-- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/config.py b/config.py index eb1ca59a..5fb8e6b5 100644 --- a/config.py +++ b/config.py @@ -213,7 +213,7 @@ def init_param(opt): lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, \ multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ - w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, + w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, \ generator_complexity, discriminator_complexity if_test = True if opt.if_test == 1 else False diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 2a9effb5..f0397b8d 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -66,6 +66,10 @@ def __init__(self, opt): number_of_parameters(self.generator.parameters()) ) + if cfg.CUDA:: + self.discriminator = self.discriminator.cuda() + self.generator = self.generator.cuda() + self.G_criterion = GANLoss(cfg.run_model, which_net=None, which_D=None, ) self.D_criterion = GANLoss(cfg.run_model, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2) @@ -114,6 +118,8 @@ def discriminator_train_one_batch(self, real_vector, labels): def _run(self): for i in trange(cfg.max_epochs): for labels, text_vector in self.train_data_supplier: + if cgf.CUDA: + labels, text_vector = labels.cuda(), text_vector.cuda() discriminator_acc = self.discriminator_train_one_batch(text_vector, labels) generator_acc = 1 - 2 * (discriminator_acc - 0.5) diff --git a/models/FixemGAN_D.py b/models/FixemGAN_D.py index 3351c501..ec60d2b3 100644 --- a/models/FixemGAN_D.py +++ b/models/FixemGAN_D.py @@ -6,14 +6,13 @@ class Discriminator(CNNDiscriminator): def __init__(self, complexity): - super(Discriminator, self).__init__() alpha = 0.2 drop_rate = 0.0 include_transformer = False self.main = nn.Sequential( # 1 layer - MyConvLayer(EMBEDDING_SIZE, complexity, alpha=alpha, drop_rate=drop_rate), + MyConvLayer(cfg.w2v_embedding_size, complexity, alpha=alpha, drop_rate=drop_rate), # 2 layer MyConvLayer( complexity, @@ -30,7 +29,7 @@ def 
__init__(self, complexity): MyTransformerEncoderLayer( d_model=complexity, - n_layers=parameters.transformer_layers, + n_layers=3, ) if include_transformer else Dummy(), diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py index cd7894ad..63751502 100644 --- a/models/FixemGAN_G.py +++ b/models/FixemGAN_G.py @@ -10,7 +10,6 @@ class Generator(LSTMGenerator): def __init__(self, complexity, noise_size, w2v): - super(Generator, self).__init__() alpha = 0.2 added_dim_pe = 0 include_batch_norm = True @@ -49,9 +48,9 @@ def __init__(self, complexity, noise_size, w2v): ), # adding/concatenating positional encoding PositionalEncoding( - dim_pe=parameters.complexity, + dim_pe=complexity, max_len=cfg.max_seq_len, - concatenate_pe=parameters.concatenate_pe, + concatenate_pe=False, ), # 5 layer MyConvLayerNorm( @@ -64,12 +63,12 @@ def __init__(self, complexity, noise_size, w2v): PositionalEncoding( dim_pe=complexity, max_len=cfg.max_seq_len, - concatenate_pe=parameters.concatenate_pe, + concatenate_pe=False, ), # 6 layer MyTransformerEncoderLayer( d_model=complexity + added_dim_pe, - n_layers=parameters.transformer_layers, + n_layers=3, ) if include_transformer else Dummy(), @@ -108,14 +107,13 @@ def __init__(self, complexity, noise_size, w2v): # 12 layer nn.Conv1d( complexity, - EMBEDDING_SIZE, + cfg.w2v_embedding_size, kernel_size=1, stride=1, padding=0, ), ) self.optimizer = get_optimizer() - self.to(device) def forward(self, noise, target_labels): target_labels = torch.nn.functional.one_hot(target_labels, num_classes=cfg.k_label) @@ -127,7 +125,7 @@ def sample(self, num_samples, batch_size, start_letter=cfg.start_letter): fakes = self.forward(*noise) fakes = fakes.detach().cpu().numpy() assert len(fakes.shape) == 3 - return [recover_sentence(fake) for fake in fakes] + return [self.recover_sentence(fake) for fake in fakes] def recover_sentence(self, fake): fake = fake.T diff --git a/run/run_fixem.py b/run/run_fixem.py index 2c8162ae..18641aa9 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -149,14 +149,14 @@ '--mem_slots', mem_slots, '--num_heads', num_heads, '--head_size', head_size[job_id], - '--generator_complexity', generator_complexity[job_id] + '--generator_complexity', generator_complexity[job_id], # Discriminator '--adv_d_step', ADV_d_step, '--dis_embed_dim', dis_embed_dim, '--dis_hidden_dim', dis_hidden_dim, '--num_rep', num_rep, - '--discriminator_complexity', discriminator_complexity[job_id] + '--discriminator_complexity', discriminator_complexity[job_id], # Metrics '--use_nll_oracle', use_nll_oracle, From dddd9577fb48bd16e57973daf3de8cb92587c691 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 17:07:34 +0000 Subject: [PATCH 09/81] WIP --- config.py | 11 ++++++++- instructor/real_data/fixem_instructor.py | 23 +++++++++++------- main.py | 3 +++ models/FixemGAN_D.py | 11 +++++---- models/FixemGAN_G.py | 30 +++++++++++++++++------- run/run_fixem.py | 6 ++++- utils/data_loader.py | 9 ++++--- utils/gan_loss.py | 2 +- utils/nn_helpers.py | 4 ++-- 9 files changed, 69 insertions(+), 30 deletions(-) diff --git a/config.py b/config.py index 5fb8e6b5..c5c77510 100644 --- a/config.py +++ b/config.py @@ -40,7 +40,12 @@ freeze_clas = False use_all_real_fake = False use_population = False + +# ===FixemGAN=== batches_per_epoch = 200 +noise_size = 1000 +max_epochs = 20 +target_len = 40 # ===Embedding=== w2v_embedding_size = 100 @@ -214,7 +219,7 @@ def init_param(opt): multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, 
use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, \ - generator_complexity, discriminator_complexity + generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -242,7 +247,11 @@ def init_param(opt): freeze_clas = opt.freeze_clas use_all_real_fake = opt.use_all_real_fake use_population = opt.use_population + batches_per_epoch = opt.batches_per_epoch + noise_size = opt.noise_size + max_epochs = opt.max_epochs + target_len = opt.target_len samples_num = opt.samples_num vocab_size = opt.vocab_size diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index f0397b8d..0c3f1421 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -1,4 +1,6 @@ import os +import random +from itertools import chain from pathlib import Path import numpy as np @@ -10,6 +12,7 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor +from utils.gan_loss import GANLoss from utils.text_process import text_file_iterator from utils.data_loader import DataSupplier, GANDataset from utils.nn_helpers import create_noise, number_of_parameters @@ -60,24 +63,26 @@ def __init__(self, opt): "discriminator total tranable parameters:", number_of_parameters(self.discriminator.parameters()) ) - self.generator = Generator(cfg.generator_complexity, cgf.noise_size, w2v) + self.generator = Generator(cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size) print( "generator total tranable parameters:", number_of_parameters(self.generator.parameters()) ) - if cfg.CUDA:: + if cfg.CUDA: self.discriminator = self.discriminator.cuda() self.generator = self.generator.cuda() - self.G_criterion = GANLoss(cfg.run_model, which_net=None, which_D=None, ) - self.D_criterion = GANLoss(cfg.run_model, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2) + self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, ) + self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2) self.all_metrics = [self.bleu, self.self_bleu] def generator_train_one_batch(self): self.generator.optimizer.zero_grad() noise = create_noise(cfg.batch_size, cfg.noise_size. cfg.k_label) + if cfg.CUDA: + noise = tuple(tt.cuda() for tt in noise) fakes = self.generator(*noise) real_fake_predicts, label_predicts = self.discriminator(fakes) @@ -95,18 +100,20 @@ def discriminator_train_one_batch(self, real_vector, labels): this_batch_size = real_vector.shape[0] # create input - noise = create_noise(cfg.batch_size, cfg.noise_size. 
cfg.k_label) + noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) + if cfg.CUDA: + noise = tuple(tt.cuda() for tt in noise) fake = self.generator(*noise).detach() text_input_vectors = torch.cat((real_vector, fake)) # optmizer step - discriminator.optimizer.zero_grad() + self.discriminator.optimizer.zero_grad() real_fake_predicts, label_predicts = self.discriminator(text_input_vectors) loss = self.D_criterion.D_loss_fixem(real_fake_predicts, label_predicts[:this_batch_size], labels) loss.backward() - discriminator.optimizer.step() + self.discriminator.optimizer.step() - discriminator_acc = torch.cat( + self.discriminator_acc = torch.cat( ( real_fake_predicts.chunk(2)[0] > 0.5, real_fake_predicts.chunk(2)[1] < 0.5 diff --git a/main.py b/main.py index 32c4054f..43b8d528 100644 --- a/main.py +++ b/main.py @@ -44,6 +44,9 @@ def program_config(parser): parser.add_argument('--use_all_real_fake', default=cfg.use_all_real_fake, type=int) parser.add_argument('--use_population', default=cfg.use_population, type=int) parser.add_argument('--batches_per_epoch', default=cfg.batches_per_epoch, type=int) + parser.add_argument('--noise_size', default=cfg.noise_size, type=int) + parser.add_argument('--max_epochs', default=cfg.max_epochs, type=int) + parser.add_argument('--target_len', default=cfg.target_len, type=int) # Basic Train parser.add_argument('--samples_num', default=cfg.samples_num, type=int) diff --git a/models/FixemGAN_D.py b/models/FixemGAN_D.py index ec60d2b3..6f609f76 100644 --- a/models/FixemGAN_D.py +++ b/models/FixemGAN_D.py @@ -1,11 +1,12 @@ import torch.nn as nn -from utils.nn_helpers import get_optimizer, MyConvLayer, MyTransformerEncoderLayer, Flatten - +import config as cfg +from utils.nn_helpers import get_optimizer, MyConvLayer, MyTransformerEncoderLayer, Flatten, Dummy from models.discriminator import CNNDiscriminator -class Discriminator(CNNDiscriminator): +class Discriminator(nn.Module): def __init__(self, complexity): + super(Discriminator, self).__init__() alpha = 0.2 drop_rate = 0.0 include_transformer = False @@ -57,7 +58,7 @@ def __init__(self, complexity): ), # 8 layer Flatten(), - nn.Linear(complexity * cfg.max_seq_len // 2 // 2, complexity), + nn.Linear(complexity * cfg.target_len // 2 // 2, complexity), nn.LeakyReLU(alpha), nn.Dropout(drop_rate), ) @@ -68,7 +69,7 @@ def __init__(self, complexity): self.labels = nn.Sequential( nn.Linear(complexity, cfg.k_label), ) - self.optimizer = get_optimizer() + self.optimizer = get_optimizer(self.parameters()) # maybe it will help! 
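# Side note on the get_optimizer(self.parameters()) call above: nn.Module.parameters()
# only yields parameters of submodules that are already assigned, so the optimizer has to
# be created after self.main, self.real_fake and self.labels. A tiny illustration
# (hypothetical Toy module, standard PyTorch only):

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        early = len(list(self.parameters()))   # 2: only self.a's weight and bias so far
        self.b = nn.Linear(4, 2)
        late = len(list(self.parameters()))    # 4: self.b is now registered as well
        assert (early, late) == (2, 4)

Toy()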
# self.init_params() diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py index 63751502..3af9e87b 100644 --- a/models/FixemGAN_G.py +++ b/models/FixemGAN_G.py @@ -1,7 +1,18 @@ from dataclasses import dataclass +import torch import torch.nn as nn -from utils.nn_helpers import get_optimizer, create_noise, Concatenate, Reshape, MyConvLayerNorm, MyConvTransposeLayer, PositionalEncoding, MyLSTMLayerNorm +from utils.nn_helpers import ( + get_optimizer, + create_noise, + Concatenate, + Reshape, + MyConvLayerNorm, + MyConvTransposeLayer, + PositionalEncoding, + MyLSTMLayerNorm, + Dummy, +) import config as cfg from models.generator import LSTMGenerator @@ -9,7 +20,8 @@ class Generator(LSTMGenerator): - def __init__(self, complexity, noise_size, w2v): + def __init__(self, complexity, noise_size, w2v, w2v_embedding_size): + super(Generator, self).__init__(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.target_len, cfg.padding_idx) alpha = 0.2 added_dim_pe = 0 include_batch_norm = True @@ -17,15 +29,15 @@ def __init__(self, complexity, noise_size, w2v): include_lstm = True self.noise_size = noise_size self.w2v = w2v - self.embedding_size = embedding_size + self.embedding_size = w2v_embedding_size self.main = nn.Sequential( # 1 layer Concatenate(1), - nn.Linear(cfg.noise_size + cfg.k_label, cfg.max_seq_len // 2 // 2 * complexity), - nn.BatchNorm1d(cfg.max_seq_len // 2 // 2 * complexity), + nn.Linear(cfg.noise_size + cfg.k_label, cfg.target_len // 2 // 2 * complexity), + nn.BatchNorm1d(cfg.target_len // 2 // 2 * complexity), nn.LeakyReLU(alpha), - Reshape(complexity, cfg.max_seq_len // 2 // 2), + Reshape(complexity, cfg.target_len // 2 // 2), # 2 layer MyConvLayerNorm(complexity, complexity, alpha=alpha), # 3 layer @@ -49,7 +61,7 @@ def __init__(self, complexity, noise_size, w2v): # adding/concatenating positional encoding PositionalEncoding( dim_pe=complexity, - max_len=cfg.max_seq_len, + max_len=cfg.target_len, concatenate_pe=False, ), # 5 layer @@ -62,7 +74,7 @@ def __init__(self, complexity, noise_size, w2v): # adding/concatenating positional encoding PositionalEncoding( dim_pe=complexity, - max_len=cfg.max_seq_len, + max_len=cfg.target_len, concatenate_pe=False, ), # 6 layer @@ -113,7 +125,7 @@ def __init__(self, complexity, noise_size, w2v): padding=0, ), ) - self.optimizer = get_optimizer() + self.optimizer = get_optimizer(self.parameters()) def forward(self, noise, target_labels): target_labels = torch.nn.functional.one_hot(target_labels, num_classes=cfg.k_label) diff --git a/run/run_fixem.py b/run/run_fixem.py index 18641aa9..f2113009 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -67,7 +67,7 @@ dis_init = 'uniform' samples_num = 10000 batch_size = 64 -max_seq_len = [16, 40, 20, 16, 52, 36] +target_len = [16, 40, 20, 16, 52, 36] gen_lr = 0.01 gen_adv_lr = 1e-4 dis_lr = 1e-4 @@ -126,7 +126,11 @@ '--eval_type', eval_type, '--temp_adpt', temp_adpt, '--d_out_mean', d_out_mean, + + '--max_epochs', max_epochs, '--batches_per_epoch', batches_per_epoch, + '--noise_size', noise_size, + '--target_len', target_len, # Basic Param '--shuffle', data_shuffle, diff --git a/utils/data_loader.py b/utils/data_loader.py index e186ed90..cac8758a 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -8,7 +8,10 @@ # Copyrights (C) 2018. All Rights Reserved. 
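# For reference, the PositionalEncoding(dim_pe=..., max_len=cfg.target_len, concatenate_pe=False)
# blocks in the generator above presumably follow the usual sinusoidal scheme and can either add
# the encoding to the feature map or concatenate it along the channel axis. A rough stand-in,
# not the repository's exact implementation (assumes an even dim_pe and [B, C, L] inputs):

import math
import torch
import torch.nn as nn

class SinusoidalPE(nn.Module):
    def __init__(self, dim_pe, max_len, concatenate_pe=False):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim_pe, 2) * (-math.log(10000.0) / dim_pe))
        pe = torch.zeros(max_len, dim_pe)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.T)                 # [dim_pe, max_len]
        self.concatenate_pe = concatenate_pe

    def forward(self, x):                                # x: [B, C, L]
        pe = self.pe[:, :x.size(-1)].expand(x.size(0), -1, -1)
        return torch.cat((x, pe), dim=1) if self.concatenate_pe else x + pe

x = torch.randn(2, 128, 16)
print(SinusoidalPE(128, 16)(x).shape)                    # torch.Size([2, 128, 16]) when adding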
import random + +import torch from torch.utils.data import Dataset, DataLoader +from tqdm import trange import config as cfg from utils.text_process import * @@ -164,11 +167,11 @@ def __iter__(self): index += self.batch_size if index > len(self): # concatenating beginning of self.vectors - yield (torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])).to(device), - torch.cat((self.vectors[index - self.batch_size: index], self.vectors[:index-len(self)])).to(device)) + yield (torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])), + torch.cat((self.vectors[index - self.batch_size: index], self.vectors[:index-len(self)])).) index = index % len(self) else: - yield self.labels[index - self.batch_size: index].to(device), self.vectors[index - self.batch_size: index].to(device) + yield self.labels[index - self.batch_size: index], self.vectors[index - self.batch_size: index] def __len__(self): diff --git a/utils/gan_loss.py b/utils/gan_loss.py index bb373839..8efb7b46 100644 --- a/utils/gan_loss.py +++ b/utils/gan_loss.py @@ -45,7 +45,7 @@ def __init__(self, loss_mode, which_net, which_D, target_real_label=1.0, target_ self.loss = nn.BCEWithLogitsLoss() elif loss_mode in ['wgan', 'hinge']: self.loss = None - elif loss_mode == 'fixem': + elif loss_mode == 'fixemgan': self.real_fake_criterion = nn.BCEWithLogitsLoss() self.label_criterion = nn.CrossEntropyLoss(label_smoothing=0.1) self.diversity_criterion = DiversityLoss() diff --git a/utils/nn_helpers.py b/utils/nn_helpers.py index b8b9ee20..2fbc9587 100644 --- a/utils/nn_helpers.py +++ b/utils/nn_helpers.py @@ -11,8 +11,8 @@ def create_noise(sample_size, noise_size, k_label): return ( - torch.randn(sample_size, noise_size).to(device), - torch.randint(0, k_label, (sample_size,)).sort().values.to(device), + torch.randn(sample_size, noise_size), + torch.randint(0, k_label, (sample_size,)), ) From 03570c8857de853721acfe17d4b993c0e69e8535 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 27 Jan 2023 19:07:37 +0000 Subject: [PATCH 10/81] fixemgan implementation for uncategorized type of data --- instructor/real_data/fixem_instructor.py | 46 +++++++++++++++--------- models/FixemGAN_G.py | 4 ++- run/run_fixem.py | 3 +- utils/data_loader.py | 14 +++----- utils/gan_loss.py | 2 +- utils/nn_helpers.py | 1 + 6 files changed, 41 insertions(+), 29 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 0c3f1421..daaf29f4 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -23,13 +23,19 @@ # TO DO: # 1. train embedding if not exists (if oracle, then always retrain ) -# 2. create data generator (categorical and non categorical) (based on given dataset) -# 3. create disc and gen # 4. train epochs and each 10 epochs print metrics -# 5. show metrics # 6. save? or save each 10 epochs - +# fix bleu score +# add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) +# logger +# cat_fixemgan +# oracle +# cat_oracle +# make run_fixem clean + +# afterwards: # chack target real/fake to be right (Uniform or const) +# random data portion generator? 
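# Context for the training loop below: the discriminator's accuracy is remapped to a
# pseudo-accuracy for the generator (1 - 2 * (acc - 0.5)), one_more_batch_for_generator
# clamps it to [0.1, 0.9] and uses it as the probability of *stopping*, so a stronger
# discriminator buys the generator extra update batches. A standalone sketch of that gate
# (same arithmetic, hypothetical function name):

import random

def more_generator_batches(discriminator_acc, lo=0.1, hi=0.9):
    generator_acc = 1 - 2 * (discriminator_acc - 0.5)   # acc 1.0 -> 0.0, acc 0.5 -> 1.0
    generator_acc = min(hi, max(lo, generator_acc))
    return random.random() > generator_acc              # True -> run one more generator batch

# With discriminator_acc = 0.9 the gate fires roughly 80% of the time:
print(sum(more_generator_batches(0.9) for _ in range(10_000)) / 10_000)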
class FixemGANInstructor(BasicInstructor): @@ -56,7 +62,7 @@ def __init__(self, opt): ) ) - self.train_data_supplier = DataSupplier(train_data, labels, w2v, True, cfg.batch_size, cfg.batches_per_epoch) + self.train_data_supplier = DataSupplier(train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch) self.discriminator = Discriminator(cfg.discriminator_complexity) print( @@ -73,20 +79,21 @@ def __init__(self, opt): self.discriminator = self.discriminator.cuda() self.generator = self.generator.cuda() - self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, ) - self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2) + self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA) + self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2, CUDA=cfg.CUDA) self.all_metrics = [self.bleu, self.self_bleu] def generator_train_one_batch(self): self.generator.optimizer.zero_grad() - noise = create_noise(cfg.batch_size, cfg.noise_size. cfg.k_label) + noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) if cfg.CUDA: noise = tuple(tt.cuda() for tt in noise) fakes = self.generator(*noise) real_fake_predicts, label_predicts = self.discriminator(fakes) - loss = self.G_criterion.G_loss_fixem(real_fake_predicts, label_predicts, fakes) + loss = self.G_criterion.G_loss_fixem(real_fake_predicts, label_predicts, noise[1], fakes) + loss.backward() self.generator.optimizer.step() @@ -113,11 +120,14 @@ def discriminator_train_one_batch(self, real_vector, labels): loss.backward() self.discriminator.optimizer.step() - self.discriminator_acc = torch.cat( - ( - real_fake_predicts.chunk(2)[0] > 0.5, - real_fake_predicts.chunk(2)[1] < 0.5 - ) + discriminator_acc = float( + torch.tensor( + torch.cat(( + real_fake_predicts.chunk(2)[0] > 0.5, + real_fake_predicts.chunk(2)[1] < 0.5 + )), + dtype = float, + ).mean() ) return discriminator_acc @@ -125,7 +135,7 @@ def discriminator_train_one_batch(self, real_vector, labels): def _run(self): for i in trange(cfg.max_epochs): for labels, text_vector in self.train_data_supplier: - if cgf.CUDA: + if cfg.CUDA: labels, text_vector = labels.cuda(), text_vector.cuda() discriminator_acc = self.discriminator_train_one_batch(text_vector, labels) @@ -139,7 +149,11 @@ def _run(self): if cfg.run_model == 'cat_fixemgan': scores = self.cal_metrics_with_label(fmt_str=True) - print('epoch:', i, scores) + samples = self.generator.sample(20, 20) + for sample in samples: + print(sample) + if (i + 1) % 10 == 0: + print('epoch:', i, scores) def one_more_batch_for_generator( diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py index 3af9e87b..ef69a71c 100644 --- a/models/FixemGAN_G.py +++ b/models/FixemGAN_G.py @@ -134,6 +134,8 @@ def forward(self, noise, target_labels): def sample(self, num_samples, batch_size, start_letter=cfg.start_letter): noise = create_noise(num_samples, self.noise_size, cfg.k_label) + if cfg.CUDA: + noise = tuple(tt.cuda() for tt in noise) fakes = self.forward(*noise) fakes = fakes.detach().cpu().numpy() assert len(fakes.shape) == 3 @@ -143,7 +145,7 @@ def recover_sentence(self, fake): fake = fake.T tokens = [] for token_vector in fake: - token = self.w2v.wv.most_similar([token_vec])[0][0] + token = self.w2v.wv.most_similar([token_vector])[0][0] if token == cfg.padding_token: continue tokens.append(token) diff --git a/run/run_fixem.py b/run/run_fixem.py index f2113009..7fad0e39 100644 --- 
a/run/run_fixem.py +++ b/run/run_fixem.py @@ -130,7 +130,7 @@ '--max_epochs', max_epochs, '--batches_per_epoch', batches_per_epoch, '--noise_size', noise_size, - '--target_len', target_len, + '--target_len', target_len[job_id], # Basic Param '--shuffle', data_shuffle, @@ -139,7 +139,6 @@ '--dis_init', dis_init, '--samples_num', samples_num, '--batch_size', batch_size, - '--max_seq_len', max_seq_len[job_id], '--gen_lr', gen_lr, '--gen_adv_lr', gen_adv_lr, '--dis_lr', dis_lr, diff --git a/utils/data_loader.py b/utils/data_loader.py index cac8758a..6f86a4f6 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -132,10 +132,7 @@ def prepare(self, pos_samples, neg_samples, gpu=False): class DataSupplier: - def __init__(self, tokenized, labels, w2v, verbose, batch_size, batches_per_epoch): - self.verbose=verbose - - + def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): labels, tokenized = zip(*[ (label, tokens) for label, tokens in zip(labels, tokenized) @@ -144,31 +141,30 @@ def __init__(self, tokenized, labels, w2v, verbose, batch_size, batches_per_epoc self.labels = torch.tensor(labels, dtype=int) - self.vectors = [vectorize_sentence(tokens, w2v, target_len=cfg.max_seq_len, padding_token = cfg.padding_token) for tokens in tokenized] + self.vectors = [vectorize_sentence(tokens, w2v, target_len=cfg.target_len, padding_token = cfg.padding_token) for tokens in tokenized] self.vectors = np.stack(self.vectors, axis=0) self.vectors = torch.tensor(self.vectors, dtype=torch.float32) self.batches_per_epoch = batches_per_epoch self.batch_size = batch_size - self.texts = set(" ".join(tokens[-cfg.max_seq_len:]) for tokens in tokenized) + self.texts = set(" ".join(tokens[-cfg.target_len:]) for tokens in tokenized) if self.verbose: print('texts examples', [txt for txt in self.texts][:3]) def __iter__(self): - batch_iterator = trange(self.batches_per_epoch) if self.verbose else range(self.batches_per_epoch) permutation = torch.randperm(len(self)) self.vectors = self.vectors[permutation] self.labels = self.labels[permutation] - for _ in batch_iterator: + for _ in range(self.batches_per_epoch): index = 0 index += self.batch_size if index > len(self): # concatenating beginning of self.vectors yield (torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])), - torch.cat((self.vectors[index - self.batch_size: index], self.vectors[:index-len(self)])).) 
+ torch.cat((self.vectors[index - self.batch_size: index], self.vectors[:index-len(self)]))) index = index % len(self) else: yield self.labels[index - self.batch_size: index], self.vectors[index - self.batch_size: index] diff --git a/utils/gan_loss.py b/utils/gan_loss.py index 8efb7b46..bb373839 100644 --- a/utils/gan_loss.py +++ b/utils/gan_loss.py @@ -45,7 +45,7 @@ def __init__(self, loss_mode, which_net, which_D, target_real_label=1.0, target_ self.loss = nn.BCEWithLogitsLoss() elif loss_mode in ['wgan', 'hinge']: self.loss = None - elif loss_mode == 'fixemgan': + elif loss_mode == 'fixem': self.real_fake_criterion = nn.BCEWithLogitsLoss() self.label_criterion = nn.CrossEntropyLoss(label_smoothing=0.1) self.diversity_criterion = DiversityLoss() diff --git a/utils/nn_helpers.py b/utils/nn_helpers.py index 2fbc9587..ac25ccff 100644 --- a/utils/nn_helpers.py +++ b/utils/nn_helpers.py @@ -1,3 +1,4 @@ +import numpy as np import torch import torch.nn as nn from torch import Tensor From 16d10ea1705b9d666e41b901775c533e3c39c848 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sun, 29 Jan 2023 15:29:14 +0000 Subject: [PATCH 11/81] cat fixem WIP --- instructor/real_data/fixem_instructor.py | 68 +++++++++++++++++------- instructor/real_data/instructor.py | 2 +- main.py | 3 +- models/FixemGAN_G.py | 5 +- run/run_fixem.py | 2 +- utils/data_loader.py | 3 +- 6 files changed, 57 insertions(+), 26 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index daaf29f4..c4fc0e2a 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -22,16 +22,17 @@ # TO DO: -# 1. train embedding if not exists (if oracle, then always retrain ) -# 4. train epochs and each 10 epochs print metrics -# 6. save? or save each 10 epochs -# fix bleu score -# add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) -# logger -# cat_fixemgan -# oracle -# cat_oracle -# make run_fixem clean +# 2. test cat gan +# 1. test oracle (# 1. train embedding if not exists (if oracle, then always retrain )) +# 3. train epochs and each 10 epochs print metrics +# 4. save? or save each 10 epochs +# 5. fix bleu score +# 6. add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) +# 7. logger +# 8. cat_fixemgan +# 9. oracle +# 10. cat_oracle +# 11. 
make run_fixem clean # afterwards: # chack target real/fake to be right (Uniform or const) @@ -120,14 +121,14 @@ def discriminator_train_one_batch(self, real_vector, labels): loss.backward() self.discriminator.optimizer.step() + real_fake_predicts = real_fake_predicts.clone().detach() + real_fake_predicts = real_fake_predicts.chunk(2) #splitting to realand fake parks + discriminator_acc = float( - torch.tensor( torch.cat(( - real_fake_predicts.chunk(2)[0] > 0.5, - real_fake_predicts.chunk(2)[1] < 0.5 - )), - dtype = float, - ).mean() + real_fake_predicts[0] > 0.5, + real_fake_predicts[1] < 0.5 + )).mean(dtype=float) ) return discriminator_acc @@ -144,15 +145,16 @@ def _run(self): while self.one_more_batch_for_generator(generator_acc): generator_acc = self.generator_train_one_batch() - if cfg.run_model == 'fixemgan': - scores = self.cal_metrics(fmt_str=True) - if cfg.run_model == 'cat_fixemgan': - scores = self.cal_metrics_with_label(fmt_str=True) samples = self.generator.sample(20, 20) for sample in samples: print(sample) + if (i + 1) % 10 == 0: + if cfg.run_model == 'fixemgan': + scores = self.cal_metrics(fmt_str=True) + if cfg.run_model == 'cat_fixemgan': + scores = ' '.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) print('epoch:', i, scores) @@ -166,6 +168,32 @@ def one_more_batch_for_generator( return False + def cal_metrics_with_label(self, label_i, fmt_str=False): + assert type(label_i) == int, 'missing label' + with torch.no_grad(): + # Prepare data for evaluation + # eval_samples = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + # gen_data = GenDataIter(eval_samples) + # gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) + gen_tokens = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + # gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) + gen_tokens_s = self.generator.sample(200, 200, label_i=label_i) + # clas_data = CatClasDataIter([eval_samples], label_i) + + # Reset metrics + self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) + # self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i) + # self.nll_div.reset(self.gen, gen_data.loader, label_i) + self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) + # self.clas_acc.reset(self.clas, clas_data.loader) + # self.ppl.reset(gen_tokens) + + if fmt_str: + return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + + return [metric.get_score() for metric in self.all_metrics] + + def cal_metrics(self, fmt_str=False): """ Calculate metrics diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 622c382a..7b787215 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -4,7 +4,7 @@ # @FileName : instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
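# On the discriminator accuracy bookkeeping patched above: real vectors are concatenated
# before the fakes, so chunk(2) recovers the two halves, and a prediction counts as correct
# when the real half is above the threshold and the fake half below it. A sketch with dummy
# scores (assumes scores already squashed to [0, 1]; thresholding raw logits directly would
# need a sigmoid or a 0.0 cut-off instead):

import torch

preds = torch.tensor([0.9, 0.7, 0.2, 0.4, 0.3, 0.1])   # first half: real inputs, second half: fakes
real_half, fake_half = preds.chunk(2)
acc = torch.cat((real_half > 0.5, fake_half < 0.5)).float().mean()
print(acc)   # tensor(0.8333): 2 of 3 reals above 0.5, all 3 fakes below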
import numpy as np diff --git a/main.py b/main.py index 43b8d528..04f32a0e 100644 --- a/main.py +++ b/main.py @@ -175,7 +175,8 @@ def program_config(parser): 'catgan': CatGANInstructor, 'dgsan': DGSANInstructor, 'cot': CoTInstructor, - 'fixemgan': FixemGANInstructor + 'fixemgan': FixemGANInstructor, + 'cat_fixemgan': FixemGANInstructor } inst = instruction_dict[cfg.run_model](opt) diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py index ef69a71c..959eefcc 100644 --- a/models/FixemGAN_G.py +++ b/models/FixemGAN_G.py @@ -132,8 +132,11 @@ def forward(self, noise, target_labels): x = self.main([noise, target_labels]) return x - def sample(self, num_samples, batch_size, start_letter=cfg.start_letter): + def sample(self, num_samples, batch_size, label_i = 'random', start_letter=cfg.start_letter): noise = create_noise(num_samples, self.noise_size, cfg.k_label) + if label_i != 'random': + noise = (noise[0], torch.tensor(label_i).expand_as(noise[1])) + if cfg.CUDA: noise = tuple(tt.cuda() for tt in noise) fakes = self.forward(*noise) diff --git a/run/run_fixem.py b/run/run_fixem.py index 7fad0e39..d5d264ad 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -46,7 +46,7 @@ # ===Oracle or Real=== if_real_data = [int(True), int(True), int(True), int(True), int(True)] dataset = ['mr15', 'amazon_app_book', 'image_coco', 'emnlp_news'] -w2v_embedding_size = [100, 100, 100, 100, 100, 100] +w2v_embedding_size = [512, 100, 100, 100, 100, 100] w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 diff --git a/utils/data_loader.py b/utils/data_loader.py index 6f86a4f6..7e7dc0ee 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -149,8 +149,7 @@ def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): self.batch_size = batch_size self.texts = set(" ".join(tokens[-cfg.target_len:]) for tokens in tokenized) - if self.verbose: - print('texts examples', [txt for txt in self.texts][:3]) + print('dataset random texts examples', [txt for txt in self.texts][:3]) def __iter__(self): From dc814483471cf902092407a9613c2c9674b5af33 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sun, 29 Jan 2023 19:38:38 +0000 Subject: [PATCH 12/81] oracle implementation WIP --- instructor/real_data/fixem_instructor.py | 4 ++-- run/run_fixem.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index c4fc0e2a..15dfe870 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -22,8 +22,8 @@ # TO DO: -# 2. test cat gan -# 1. test oracle (# 1. train embedding if not exists (if oracle, then always retrain )) +# 1. test oracle +# 1. train embedding if not exists (if oracle, then always retrain )) # 3. train epochs and each 10 epochs print metrics # 4. save? or save each 10 epochs # 5. 
fix bleu score diff --git a/run/run_fixem.py b/run/run_fixem.py index d5d264ad..7706ee06 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -34,7 +34,7 @@ # CatGAN: Catgory text generation model # EvoGAN: General text generation model if_test = int(False) -run_model = ['fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan', 'cat_fixemgan', 'cat_fixemgan'] +run_model = ['fixemgan', 'cat_fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan', 'cat_fixemgan', 'cat_fixemgan'] k_label = 2 CUDA = int(True) batch_size = 32 @@ -45,8 +45,8 @@ # ===Oracle or Real=== if_real_data = [int(True), int(True), int(True), int(True), int(True)] -dataset = ['mr15', 'amazon_app_book', 'image_coco', 'emnlp_news'] -w2v_embedding_size = [512, 100, 100, 100, 100, 100] +dataset = ['oracle', 'mr15', 'amazon_app_book', 'image_coco', 'emnlp_news'] +w2v_embedding_size = [100, 512, 100, 100, 100, 100] w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 @@ -81,7 +81,7 @@ mem_slots = 1 num_heads = 2 head_size = [512, 512, 512, 256, 256, 256] -generator_complexity = [256, 512, 512, 512, 512, 512] +generator_complexity = [768, 512, 512, 512, 512, 512] # ===Discriminator=== ADV_d_step = 3 From 7dda0c82ed5cef8a6cd2dff2b841e3c13cf1bce6 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sun, 29 Jan 2023 20:38:58 +0000 Subject: [PATCH 13/81] oracle implementation WIP --- config.py | 6 +- instructor/oracle_data/fixem_instructor.py | 227 +++++++++++++++++++++ instructor/real_data/fixem_instructor.py | 2 +- main.py | 2 + run/run_fixem.py | 2 + 5 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 instructor/oracle_data/fixem_instructor.py diff --git a/config.py b/config.py index c5c77510..728846c7 100644 --- a/config.py +++ b/config.py @@ -52,6 +52,7 @@ w2v_window = 5 w2v_min_count = 1 w2v_workers = 1 +w2v_samples_num = 5_000_000 # ===Oracle or Real, type=== if_real_data = False # if use real data @@ -219,7 +220,7 @@ def init_param(opt): multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, \ - generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len + generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -300,6 +301,7 @@ def init_param(opt): w2v_window = opt.w2v_window w2v_min_count = opt.w2v_min_count w2v_workers = opt.w2v_workers + w2v_samples_num = opt.w2v_samples_num use_nll_oracle = True if opt.use_nll_oracle == 1 else False use_nll_gen = True if opt.use_nll_gen == 1 else False @@ -355,7 +357,7 @@ def init_param(opt): samples_num) pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) - emebedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/real_data/' + emebedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/oracle_data/' pretrain_embedding_path = emebedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size) # Assertion assert k_label >= 2, 'Error: k_label = {}, which should be >=2!'.format(k_label) diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py new file mode 100644 index 00000000..5ae5617b --- /dev/null +++ b/instructor/oracle_data/fixem_instructor.py @@ -0,0 +1,227 
@@ +import os +import random +from itertools import chain + +from pathlib import Path +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader +import torchtext +from tqdm import trange + + +import config as cfg +from instructor.oracle_data.instructor import BasicInstructor +from instructor.real_data.fixem_instructor import FixemGANInstructor +from utils.gan_loss import GANLoss +from utils.text_process import text_file_iterator +from utils.data_loader import DataSupplier, GANDataset +from utils.nn_helpers import create_noise, number_of_parameters +from utils.create_embeddings import EmbeddingsTrainer, load_embedding +from models.FixemGAN_G import Generator +from models.FixemGAN_D import Discriminator + + +# TO DO: +# 1. test oracle +# 1. train embedding if not exists (if oracle, then always retrain )) +# 3. train epochs and each 10 epochs print metrics +# 4. save? or save each 10 epochs +# 5. fix bleu score +# 6. add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) +# 7. logger +# 8. cat_fixemgan +# 9. oracle +# 10. cat_oracle +# 11. make run_fixem clean + +# afterwards: +# chack target real/fake to be right (Uniform or const) +# random data portion generator? + + +class FixemGANInstructor(BasicInstructor, FixemGANInstructor): + def __init__(self, opt): + super(FixemGANInstructor, self).__init__(opt) + # check if embeddings already exist for current oracle + + if not os.path.exists(cfg.pretrain_embedding_path): + # train embedding on available dataset or oracle + print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") + print("Will train new one, it may take a while...") + giant_samples = self.oracle.sample(cfg.w2v_samples_num, 4 * cfg.batch_size) + file_name + + with open(cfg.oracle_samples_path.format(cfg.w2v_samples_num), 'w') as f: + for sample in tqdm(giant_samples): + f.write(" ".join(str(int(idx)) for idx in sample)) + f.write("\n") + + sources = [cfg.oracle_samples_path.format(cfg.w2v_samples_num)] + EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() + + w2v = load_embedding(cfg.pretrain_embedding_path) + + if cfg.run_model == 'fixemgan': + labels, train_data = zip(*[(0, line) for line in text_file_iterator(cfg.train_data)]) + + if cfg.run_model == 'cat_fixemgan': + labels, train_data = zip( + *chain( + *[[(i, line) for line in text_file_iterator(cfg.cat_train_data.format(i))] + for i in range(cfg.k_label)] + ) + ) + + self.train_data_supplier = DataSupplier(train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch) + + self.discriminator = Discriminator(cfg.discriminator_complexity) + print( + "discriminator total tranable parameters:", + number_of_parameters(self.discriminator.parameters()) + ) + self.generator = Generator(cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size) + print( + "generator total tranable parameters:", + number_of_parameters(self.generator.parameters()) + ) + + if cfg.CUDA: + self.discriminator = self.discriminator.cuda() + self.generator = self.generator.cuda() + + self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA) + self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2, CUDA=cfg.CUDA) + + self.all_metrics = [self.bleu, self.self_bleu] + + def generator_train_one_batch(self): + self.generator.optimizer.zero_grad() + noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) + if cfg.CUDA: + noise = tuple(tt.cuda() for tt in noise) + fakes = 
self.generator(*noise) + + real_fake_predicts, label_predicts = self.discriminator(fakes) + loss = self.G_criterion.G_loss_fixem(real_fake_predicts, label_predicts, noise[1], fakes) + + loss.backward() + self.generator.optimizer.step() + + generator_acc = float( + np.array(real_fake_predicts.detach().cpu().numpy() > 0.5, dtype=int).mean() + ) + return generator_acc + + def discriminator_train_one_batch(self, real_vector, labels): + # important to have equal batch size for fake and real vectors + this_batch_size = real_vector.shape[0] + + # create input + noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) + if cfg.CUDA: + noise = tuple(tt.cuda() for tt in noise) + fake = self.generator(*noise).detach() + text_input_vectors = torch.cat((real_vector, fake)) + + # optmizer step + self.discriminator.optimizer.zero_grad() + real_fake_predicts, label_predicts = self.discriminator(text_input_vectors) + loss = self.D_criterion.D_loss_fixem(real_fake_predicts, label_predicts[:this_batch_size], labels) + loss.backward() + self.discriminator.optimizer.step() + + real_fake_predicts = real_fake_predicts.clone().detach() + real_fake_predicts = real_fake_predicts.chunk(2) #splitting to realand fake parks + + discriminator_acc = float( + torch.cat(( + real_fake_predicts[0] > 0.5, + real_fake_predicts[1] < 0.5 + )).mean(dtype=float) + ) + return discriminator_acc + + + def _run(self): + for i in trange(cfg.max_epochs): + for labels, text_vector in self.train_data_supplier: + if cfg.CUDA: + labels, text_vector = labels.cuda(), text_vector.cuda() + discriminator_acc = self.discriminator_train_one_batch(text_vector, labels) + + generator_acc = 1 - 2 * (discriminator_acc - 0.5) + # run the generator until generator acc not get high enought + while self.one_more_batch_for_generator(generator_acc): + generator_acc = self.generator_train_one_batch() + + + samples = self.generator.sample(20, 20) + for sample in samples: + print(sample) + + if (i + 1) % 10 == 0: + if cfg.run_model == 'fixemgan': + scores = self.cal_metrics(fmt_str=True) + if cfg.run_model == 'cat_fixemgan': + scores = ' '.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) + print('epoch:', i, scores) + + + def one_more_batch_for_generator( + self, generator_acc, leave_in_generator_min=0.1, leave_in_generator_max=0.9 + ): + generator_acc = min(leave_in_generator_max, generator_acc) + generator_acc = max(leave_in_generator_min, generator_acc) + if random.random() > generator_acc: + return True + return False + + + def cal_metrics_with_label(self, label_i, fmt_str=False): + assert type(label_i) == int, 'missing label' + with torch.no_grad(): + # Prepare data for evaluation + # eval_samples = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + # gen_data = GenDataIter(eval_samples) + # gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) + gen_tokens = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + # gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) + gen_tokens_s = self.generator.sample(200, 200, label_i=label_i) + # clas_data = CatClasDataIter([eval_samples], label_i) + + # Reset metrics + self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) + # self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i) + # self.nll_div.reset(self.gen, gen_data.loader, label_i) + self.self_bleu.reset(test_text=gen_tokens_s, 
real_text=gen_tokens) + # self.clas_acc.reset(self.clas, clas_data.loader) + # self.ppl.reset(gen_tokens) + + if fmt_str: + return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + + return [metric.get_score() for metric in self.all_metrics] + + + def cal_metrics(self, fmt_str=False): + """ + Calculate metrics + :param fmt_str: if return format string for logging + """ + with torch.no_grad(): + # Prepare data for evaluation + gen_tokens = self.generator.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_tokens_s = self.generator.sample(200, 200) + + # Reset metrics + self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) + # self.nll_gen.reset(self.gen, self.train_data.loader) + # self.nll_div.reset(self.gen, gen_data.loader) + self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) + # self.ppl.reset(gen_tokens) + + if fmt_str: + return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + else: + return [metric.get_score() for metric in self.all_metrics] diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 15dfe870..81cb57d3 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -42,7 +42,7 @@ class FixemGANInstructor(BasicInstructor): def __init__(self, opt): super(FixemGANInstructor, self).__init__(opt) - # check if embeddings already exist for current oracle + # check if embeddings already exist if not os.path.exists(cfg.pretrain_embedding_path): # train embedding on available dataset or oracle print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") diff --git a/main.py b/main.py index 04f32a0e..9f745194 100644 --- a/main.py +++ b/main.py @@ -101,6 +101,7 @@ def program_config(parser): parser.add_argument('--w2v_window', default=cfg.w2v_window, type=int) parser.add_argument('--w2v_min_count', default=cfg.w2v_min_count, type=int) parser.add_argument('--w2v_workers', default=cfg.w2v_workers, type=int) + parser.add_argument('--w2v_samples_num', default=cfg.w2v_samples_num, type=int) # Metrics parser.add_argument('--use_nll_oracle', default=cfg.use_nll_oracle, type=int) @@ -162,6 +163,7 @@ def program_config(parser): from instructor.oracle_data.catgan_instructor import CatGANInstructor from instructor.oracle_data.dgsan_instructor import DGSANInstructor from instructor.oracle_data.cot_instructor import CoTInstructor + from instructor.oracle_data.fixem_instructor import FixemGANInstructor instruction_dict = { 'seqgan': SeqGANInstructor, diff --git a/run/run_fixem.py b/run/run_fixem.py index 7706ee06..c4efc4d1 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -50,6 +50,7 @@ w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 +w2v_samples_num = 5_000_000 vocab_size = [5000, 0, 0, 5000, 0, 0] # ===CatGAN Param=== @@ -118,6 +119,7 @@ '--w2v_window', w2v_window, '--w2v_min_count', w2v_min_count, '--w2v_workers', w2v_workers, + '--w2v_samples_num', w2v_samples_num, # CatGAN Param '--n_parent', n_parent, From cf8508434300b6f55027e91f27f1acd604cacc16 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sun, 29 Jan 2023 22:17:17 +0000 Subject: [PATCH 14/81] oracle fixem implementation WIP --- config.py | 4 +- instructor/oracle_data/fixem_instructor.py | 188 ++------------------- run/run_fixem.py | 2 +- utils/data_utils.py | 2 +- utils/helpers.py | 14 ++ 5 files changed, 33 insertions(+), 177 deletions(-) diff --git a/config.py b/config.py index 728846c7..5d5cbd98 100644 
--- a/config.py +++ b/config.py @@ -341,9 +341,9 @@ def init_param(opt): save_samples_root = save_root + 'samples/' save_model_root = save_root + 'models/' - train_data = 'dataset/' + dataset + '.txt' + train_data = 'dataset/' + dataset + '.txt' if if_real_data else 'pretrain/oracle_data/' + dataset + '.txt' test_data = 'dataset/testdata/' + dataset + '_test.txt' - cat_train_data = 'dataset/' + dataset + '_cat{}.txt' + cat_train_data = 'dataset/' + dataset + '_cat{}.txt' if if_real_data else 'pretrain/oracle_data/' + dataset + '_cat{}.txt' cat_test_data = 'dataset/testdata/' + dataset + '_cat{}_test.txt' if max_seq_len == 40: diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index 5ae5617b..b97b9505 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -7,7 +7,7 @@ import torch from torch.utils.data import Dataset, DataLoader import torchtext -from tqdm import trange +from tqdm import tqdm, trange import config as cfg @@ -17,6 +17,7 @@ from utils.text_process import text_file_iterator from utils.data_loader import DataSupplier, GANDataset from utils.nn_helpers import create_noise, number_of_parameters +from utils.helpers import create_oracle from utils.create_embeddings import EmbeddingsTrainer, load_embedding from models.FixemGAN_G import Generator from models.FixemGAN_D import Discriminator @@ -30,13 +31,11 @@ # 5. fix bleu score # 6. add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) # 7. logger -# 8. cat_fixemgan -# 9. oracle -# 10. cat_oracle # 11. make run_fixem clean +# 10. cat_oracle ? # afterwards: -# chack target real/fake to be right (Uniform or const) +# check target real/fake to be right (Uniform or const) # random data portion generator? 
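Patch 14 moves the oracle-sample dump into create_oracle() so the word2vec embedding can be retrained from those sample files. For reference, a minimal sketch of what that embedding step amounts to, assuming EmbeddingsTrainer (utils/create_embeddings.py, not shown in these patches) wraps gensim; the function name and defaults below are illustrative, with the hyperparameters taken from the w2v_* settings in run_fixem.py:

# Hedged sketch only: the real implementation lives in utils/create_embeddings.py.
# For the oracle, each source line is a sample written as space-separated token
# indices, so the "words" the model learns are just index strings.
from gensim.models import Word2Vec

def train_token_embeddings(sources, save_path, size=100, window=5,
                           min_count=30, workers=30):
    sentences = []
    for path in sources:
        with open(path) as f:
            sentences.extend(line.split() for line in f)
    # gensim >= 4 calls this argument vector_size; older releases call it size
    model = Word2Vec(sentences=sentences, vector_size=size, window=window,
                     min_count=min_count, workers=workers)
    model.save(save_path)  # later reloaded via load_embedding(cfg.pretrain_embedding_path)
    return model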
@@ -45,13 +44,19 @@ def __init__(self, opt): super(FixemGANInstructor, self).__init__(opt) # check if embeddings already exist for current oracle + if cfg.oracle_pretrain: + if not os.path.exists(cfg.oracle_state_dict_path): + create_oracle() + self.oracle.load_state_dict( + torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + + if cfg.CUDA: + self.oracle = self.oracle.cuda() + if not os.path.exists(cfg.pretrain_embedding_path): # train embedding on available dataset or oracle print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") print("Will train new one, it may take a while...") - giant_samples = self.oracle.sample(cfg.w2v_samples_num, 4 * cfg.batch_size) - file_name - with open(cfg.oracle_samples_path.format(cfg.w2v_samples_num), 'w') as f: for sample in tqdm(giant_samples): f.write(" ".join(str(int(idx)) for idx in sample)) @@ -60,168 +65,5 @@ def __init__(self, opt): sources = [cfg.oracle_samples_path.format(cfg.w2v_samples_num)] EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() - w2v = load_embedding(cfg.pretrain_embedding_path) - - if cfg.run_model == 'fixemgan': - labels, train_data = zip(*[(0, line) for line in text_file_iterator(cfg.train_data)]) - - if cfg.run_model == 'cat_fixemgan': - labels, train_data = zip( - *chain( - *[[(i, line) for line in text_file_iterator(cfg.cat_train_data.format(i))] - for i in range(cfg.k_label)] - ) - ) - - self.train_data_supplier = DataSupplier(train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch) - - self.discriminator = Discriminator(cfg.discriminator_complexity) - print( - "discriminator total tranable parameters:", - number_of_parameters(self.discriminator.parameters()) - ) - self.generator = Generator(cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size) - print( - "generator total tranable parameters:", - number_of_parameters(self.generator.parameters()) - ) - - if cfg.CUDA: - self.discriminator = self.discriminator.cuda() - self.generator = self.generator.cuda() - - self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA) - self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2, CUDA=cfg.CUDA) - - self.all_metrics = [self.bleu, self.self_bleu] - - def generator_train_one_batch(self): - self.generator.optimizer.zero_grad() - noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) - if cfg.CUDA: - noise = tuple(tt.cuda() for tt in noise) - fakes = self.generator(*noise) - - real_fake_predicts, label_predicts = self.discriminator(fakes) - loss = self.G_criterion.G_loss_fixem(real_fake_predicts, label_predicts, noise[1], fakes) - - loss.backward() - self.generator.optimizer.step() - - generator_acc = float( - np.array(real_fake_predicts.detach().cpu().numpy() > 0.5, dtype=int).mean() - ) - return generator_acc - - def discriminator_train_one_batch(self, real_vector, labels): - # important to have equal batch size for fake and real vectors - this_batch_size = real_vector.shape[0] - - # create input - noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) - if cfg.CUDA: - noise = tuple(tt.cuda() for tt in noise) - fake = self.generator(*noise).detach() - text_input_vectors = torch.cat((real_vector, fake)) - - # optmizer step - self.discriminator.optimizer.zero_grad() - real_fake_predicts, label_predicts = self.discriminator(text_input_vectors) - loss = self.D_criterion.D_loss_fixem(real_fake_predicts, label_predicts[:this_batch_size], 
labels) - loss.backward() - self.discriminator.optimizer.step() - - real_fake_predicts = real_fake_predicts.clone().detach() - real_fake_predicts = real_fake_predicts.chunk(2) #splitting to realand fake parks - - discriminator_acc = float( - torch.cat(( - real_fake_predicts[0] > 0.5, - real_fake_predicts[1] < 0.5 - )).mean(dtype=float) - ) - return discriminator_acc - - - def _run(self): - for i in trange(cfg.max_epochs): - for labels, text_vector in self.train_data_supplier: - if cfg.CUDA: - labels, text_vector = labels.cuda(), text_vector.cuda() - discriminator_acc = self.discriminator_train_one_batch(text_vector, labels) - - generator_acc = 1 - 2 * (discriminator_acc - 0.5) - # run the generator until generator acc not get high enought - while self.one_more_batch_for_generator(generator_acc): - generator_acc = self.generator_train_one_batch() - - - samples = self.generator.sample(20, 20) - for sample in samples: - print(sample) - - if (i + 1) % 10 == 0: - if cfg.run_model == 'fixemgan': - scores = self.cal_metrics(fmt_str=True) - if cfg.run_model == 'cat_fixemgan': - scores = ' '.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) - print('epoch:', i, scores) - - - def one_more_batch_for_generator( - self, generator_acc, leave_in_generator_min=0.1, leave_in_generator_max=0.9 - ): - generator_acc = min(leave_in_generator_max, generator_acc) - generator_acc = max(leave_in_generator_min, generator_acc) - if random.random() > generator_acc: - return True - return False - - - def cal_metrics_with_label(self, label_i, fmt_str=False): - assert type(label_i) == int, 'missing label' - with torch.no_grad(): - # Prepare data for evaluation - # eval_samples = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) - # gen_data = GenDataIter(eval_samples) - # gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) - # gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) - gen_tokens_s = self.generator.sample(200, 200, label_i=label_i) - # clas_data = CatClasDataIter([eval_samples], label_i) - - # Reset metrics - self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) - # self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i) - # self.nll_div.reset(self.gen, gen_data.loader, label_i) - self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) - # self.clas_acc.reset(self.clas, clas_data.loader) - # self.ppl.reset(gen_tokens) - - if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) - - return [metric.get_score() for metric in self.all_metrics] - - - def cal_metrics(self, fmt_str=False): - """ - Calculate metrics - :param fmt_str: if return format string for logging - """ - with torch.no_grad(): - # Prepare data for evaluation - gen_tokens = self.generator.sample(cfg.samples_num, 4 * cfg.batch_size) - gen_tokens_s = self.generator.sample(200, 200) - - # Reset metrics - self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) - # self.nll_gen.reset(self.gen, self.train_data.loader) - # self.nll_div.reset(self.gen, gen_data.loader) - self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) - # self.ppl.reset(gen_tokens) - - if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) - else: - 
return [metric.get_score() for metric in self.all_metrics] + # Metrics + self.all_metrics = [self.nll_oracle] diff --git a/run/run_fixem.py b/run/run_fixem.py index c4efc4d1..a951e91c 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -44,7 +44,7 @@ tips = '{} experiments' # ===Oracle or Real=== -if_real_data = [int(True), int(True), int(True), int(True), int(True)] +if_real_data = [int(False), int(True), int(True), int(True), int(True)] dataset = ['oracle', 'mr15', 'amazon_app_book', 'image_coco', 'emnlp_news'] w2v_embedding_size = [100, 512, 100, 100, 100, 100] w2v_window = 5 diff --git a/utils/data_utils.py b/utils/data_utils.py index 67b4bda3..7cea33a1 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -4,7 +4,7 @@ # @FileName : data_utils.py # @Time : Created at 2019-03-16 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. from time import strftime, localtime diff --git a/utils/helpers.py b/utils/helpers.py index a7abfe10..264007b4 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -78,6 +78,20 @@ def create_oracle(): torch.save(oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size), cfg.oracle_samples_path.format(cfg.samples_num // 2)) + #giant for W2V + giant_samples = self.oracle.sample(cfg.w2v_samples_num, 4 * cfg.batch_size) + with open(cfg.oracle_samples_path.format(cfg.w2v_samples_num), 'w') as f: + for sample in tqdm(giant_samples): + f.write(" ".join(str(int(idx)) for idx in sample)) + f.write("\n") + + # moderate for training for W2V + train_samples = self.oracle.sample(cfg.train_samples_num, 4 * cfg.batch_size) + with open(cfg.train_data, 'w') as f: + for sample in tqdm(train_samples): + f.write(" ".join(str(int(idx)) for idx in sample)) + f.write("\n") + oracle_data = GenDataIter(big_samples) mle_criterion = nn.NLLLoss() groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) From bde02b3c1a37ceb77a79073330f8bdde085dd09d Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 30 Jan 2023 11:59:30 +0000 Subject: [PATCH 15/81] oracle fuxem implementation WIP --- config.py | 4 +- instructor/oracle_data/fixem_instructor.py | 34 +++++++------- instructor/real_data/fixem_instructor.py | 31 ++++--------- instructor/real_data/instructor.py | 7 +-- main.py | 1 + run/run_fixem.py | 52 +++------------------- utils/helpers.py | 5 ++- 7 files changed, 39 insertions(+), 95 deletions(-) diff --git a/config.py b/config.py index 5d5cbd98..209a4606 100644 --- a/config.py +++ b/config.py @@ -46,6 +46,7 @@ noise_size = 1000 max_epochs = 20 target_len = 40 +oracle_train_samples_num = 100_000 # ===Embedding=== w2v_embedding_size = 100 @@ -220,7 +221,7 @@ def init_param(opt): multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, \ - generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num + generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num, oracle_train_samples_num if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -253,6 +254,7 @@ def init_param(opt): noise_size = opt.noise_size max_epochs = opt.max_epochs target_len = opt.target_len + oracle_train_samples_num = opt.oracle_train_samples_num samples_num = opt.samples_num vocab_size = opt.vocab_size diff 
--git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index b97b9505..d58149ce 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -12,24 +12,23 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from instructor.real_data.fixem_instructor import FixemGANInstructor +from instructor.real_data.fixem_instructor import FixemGANInstructor as RealDataFixemGANInstructor from utils.gan_loss import GANLoss from utils.text_process import text_file_iterator from utils.data_loader import DataSupplier, GANDataset from utils.nn_helpers import create_noise, number_of_parameters from utils.helpers import create_oracle +from metrics.nll import NLL from utils.create_embeddings import EmbeddingsTrainer, load_embedding +from models.Oracle import Oracle from models.FixemGAN_G import Generator from models.FixemGAN_D import Discriminator # TO DO: -# 1. test oracle -# 1. train embedding if not exists (if oracle, then always retrain )) -# 3. train epochs and each 10 epochs print metrics -# 4. save? or save each 10 epochs # 5. fix bleu score # 6. add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) +# 4. save? or save each 10 epochs # 7. logger # 11. make run_fixem clean # 10. cat_oracle ? @@ -39,11 +38,9 @@ # random data portion generator? -class FixemGANInstructor(BasicInstructor, FixemGANInstructor): +class FixemGANInstructor(RealDataFixemGANInstructor, BasicInstructor): def __init__(self, opt): - super(FixemGANInstructor, self).__init__(opt) - # check if embeddings already exist for current oracle - + self.oracle = Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len,cfg.padding_idx, gpu=cfg.CUDA) if cfg.oracle_pretrain: if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() @@ -53,17 +50,16 @@ def __init__(self, opt): if cfg.CUDA: self.oracle = self.oracle.cuda() - if not os.path.exists(cfg.pretrain_embedding_path): - # train embedding on available dataset or oracle - print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") - print("Will train new one, it may take a while...") - with open(cfg.oracle_samples_path.format(cfg.w2v_samples_num), 'w') as f: - for sample in tqdm(giant_samples): - f.write(" ".join(str(int(idx)) for idx in sample)) - f.write("\n") + super().__init__(opt) - sources = [cfg.oracle_samples_path.format(cfg.w2v_samples_num)] - EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() # Metrics + self.nll_oracle = NLL('NLL_oracle', if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) self.all_metrics = [self.nll_oracle] + + def build_embedding(self): + # train embedding on available dataset or oracle + print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") + print("Will train new one, it may take a while...") + sources = [cfg.oracle_samples_path.format(cfg.w2v_samples_num)] + EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 81cb57d3..aeec076b 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -21,34 +21,13 @@ from models.FixemGAN_D import Discriminator -# TO DO: -# 1. test oracle -# 1. train embedding if not exists (if oracle, then always retrain )) -# 3. train epochs and each 10 epochs print metrics -# 4. save? or save each 10 epochs -# 5. fix bleu score -# 6. 
add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) -# 7. logger -# 8. cat_fixemgan -# 9. oracle -# 10. cat_oracle -# 11. make run_fixem clean - -# afterwards: -# chack target real/fake to be right (Uniform or const) -# random data portion generator? - - class FixemGANInstructor(BasicInstructor): def __init__(self, opt): super(FixemGANInstructor, self).__init__(opt) # check if embeddings already exist if not os.path.exists(cfg.pretrain_embedding_path): - # train embedding on available dataset or oracle - print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") - print("Will train new one, it may take a while...") - sources = list(Path(cfg.texts_pile).glob('*.txt')) - EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() + # train embedding on available datasets + self.build_embedding() w2v = load_embedding(cfg.pretrain_embedding_path) @@ -85,6 +64,12 @@ def __init__(self, opt): self.all_metrics = [self.bleu, self.self_bleu] + def build_embedding(self): + print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") + print("Will train new one, it may take a while...") + sources = list(Path(cfg.texts_pile).glob('*.txt')) + EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() + def generator_train_one_batch(self): self.generator.optimizer.zero_grad() noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 7b787215..90ca5ba4 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -34,7 +34,8 @@ def __init__(self, opt): self.clas = None # load dictionary - self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset) + if cfg.if_real_data: + self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset) # Dataloader try: @@ -69,8 +70,8 @@ def __init__(self, opt): self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA) self.self_bleu = BLEU('Self-BLEU', gram=[2, 3, 4], if_use=cfg.use_self_bleu) self.clas_acc = ACC(if_use=cfg.use_clas_acc) - self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl) - self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ppl] + # self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl) + self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu]#, self.ppl] def _run(self): print('Nothing to run in Basic Instructor!') diff --git a/main.py b/main.py index 9f745194..b515da52 100644 --- a/main.py +++ b/main.py @@ -47,6 +47,7 @@ def program_config(parser): parser.add_argument('--noise_size', default=cfg.noise_size, type=int) parser.add_argument('--max_epochs', default=cfg.max_epochs, type=int) parser.add_argument('--target_len', default=cfg.target_len, type=int) + parser.add_argument('--oracle_train_samples_num', default=cfg.oracle_train_samples_num, type=int) # Basic Train parser.add_argument('--samples_num', default=cfg.samples_num, type=int) diff --git a/run/run_fixem.py b/run/run_fixem.py index a951e91c..2e7504a5 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -31,7 +31,6 @@ scriptname = 'main.py' # ===Program=== -# CatGAN: Catgory text generation model # EvoGAN: General text generation model if_test = int(False) run_model = ['fixemgan', 'cat_fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan', 'cat_fixemgan', 'cat_fixemgan'] @@ -45,8 +44,8 @@ # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True), int(True), int(True)] -dataset 
= ['oracle', 'mr15', 'amazon_app_book', 'image_coco', 'emnlp_news'] -w2v_embedding_size = [100, 512, 100, 100, 100, 100] +dataset = ['mr20', 'mr15', 'oracle', 'amazon_app_book', 'image_coco', 'emnlp_news'] +w2v_embedding_size = [512, 100, 512, 100, 100, 100] w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 @@ -54,12 +53,8 @@ vocab_size = [5000, 0, 0, 5000, 0, 0] # ===CatGAN Param=== -n_parent = 1 loss_type = 'fixem' -mu_type = 'ragan rsgan' -eval_type = 'Ra' -temp_adpt = 'exp' -d_out_mean = int(False) +oracle_train_samples_num = 100_000 # === Basic Param === data_shuffle = int(False) @@ -68,27 +63,12 @@ dis_init = 'uniform' samples_num = 10000 batch_size = 64 -target_len = [16, 40, 20, 16, 52, 36] -gen_lr = 0.01 -gen_adv_lr = 1e-4 -dis_lr = 1e-4 -pre_log_step = 10 -adv_log_step = 20 +target_len = [20, 40, 20, 16, 52, 36] # ===Generator=== -ADV_g_step = 1 -gen_embed_dim = 32 -gen_hidden_dim = 32 -mem_slots = 1 -num_heads = 2 -head_size = [512, 512, 512, 256, 256, 256] generator_complexity = [768, 512, 512, 512, 512, 512] # ===Discriminator=== -ADV_d_step = 3 -dis_embed_dim = 64 -dis_hidden_dim = 64 -num_rep = 64 discriminator_complexity = [512, 512, 512, 512, 512] # ===Metrics=== @@ -133,34 +113,12 @@ '--batches_per_epoch', batches_per_epoch, '--noise_size', noise_size, '--target_len', target_len[job_id], - - # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--gen_lr', gen_lr, - '--gen_adv_lr', gen_adv_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, + '--oracle_train_samples_num', oracle_train_samples_num, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - '--mem_slots', mem_slots, - '--num_heads', num_heads, - '--head_size', head_size[job_id], '--generator_complexity', generator_complexity[job_id], # Discriminator - '--adv_d_step', ADV_d_step, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - '--num_rep', num_rep, '--discriminator_complexity', discriminator_complexity[job_id], # Metrics diff --git a/utils/helpers.py b/utils/helpers.py index 264007b4..bd681bb7 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -5,6 +5,7 @@ import numpy as np import torch import torch.nn as nn +from tqdm import tqdm from metrics.nll import NLL from utils.data_loader import GenDataIter @@ -79,14 +80,14 @@ def create_oracle(): cfg.oracle_samples_path.format(cfg.samples_num // 2)) #giant for W2V - giant_samples = self.oracle.sample(cfg.w2v_samples_num, 4 * cfg.batch_size) + giant_samples = oracle.sample(cfg.w2v_samples_num, 4 * cfg.batch_size) with open(cfg.oracle_samples_path.format(cfg.w2v_samples_num), 'w') as f: for sample in tqdm(giant_samples): f.write(" ".join(str(int(idx)) for idx in sample)) f.write("\n") # moderate for training for W2V - train_samples = self.oracle.sample(cfg.train_samples_num, 4 * cfg.batch_size) + train_samples = oracle.sample(cfg.oracle_train_samples_num, 4 * cfg.batch_size) with open(cfg.train_data, 'w') as f: for sample in tqdm(train_samples): f.write(" ".join(str(int(idx)) for idx in sample)) From e5c80020fd293d7478b5aaee260df2fd836500a4 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 31 Jan 2023 15:18:16 +0000 Subject: [PATCH 16/81] WIP --- run/run_fixem.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/run/run_fixem.py b/run/run_fixem.py index 
2e7504a5..a73ce4ef 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -26,7 +26,7 @@ print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) # Executables -executable = 'python' +executable = 'python3' rootdir = '../' scriptname = 'main.py' @@ -101,14 +101,8 @@ '--w2v_workers', w2v_workers, '--w2v_samples_num', w2v_samples_num, - # CatGAN Param - '--n_parent', n_parent, + # FixemGAN Param '--loss_type', loss_type, - '--mu_type', mu_type, - '--eval_type', eval_type, - '--temp_adpt', temp_adpt, - '--d_out_mean', d_out_mean, - '--max_epochs', max_epochs, '--batches_per_epoch', batches_per_epoch, '--noise_size', noise_size, From 1d0dc9e2ba07a721cc327d909d3817d00db4eb67 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 1 Feb 2023 14:55:57 +0000 Subject: [PATCH 17/81] metrics WIP --- instructor/oracle_data/fixem_instructor.py | 4 ++++ run/run_fixem.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index d58149ce..dfa44111 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -38,6 +38,10 @@ # random data portion generator? +# half of page idea explained +# current status +# plans for future, implementation tweaks + class FixemGANInstructor(RealDataFixemGANInstructor, BasicInstructor): def __init__(self, opt): self.oracle = Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len,cfg.padding_idx, gpu=cfg.CUDA) diff --git a/run/run_fixem.py b/run/run_fixem.py index a73ce4ef..73758654 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -26,7 +26,7 @@ print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) # Executables -executable = 'python3' +executable = 'python' rootdir = '../' scriptname = 'main.py' From b629ebc13c0c105551c5ce677d2a35a0f33613e3 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 1 Feb 2023 15:35:30 +0000 Subject: [PATCH 18/81] WIP --- metrics/bleu.py | 5 +++-- run/run_fixem.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/metrics/bleu.py b/metrics/bleu.py index 3cc7d5d4..d5330274 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -4,7 +4,7 @@ # @FileName : bleu.py # @Time : Created at 2019-05-31 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
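The metrics/bleu.py changes in this and the following patches only adjust how much of the test set is sampled (the portion attribute) and a temporary debug print; the score itself stays the standard smoothed sentence-level BLEU used by cal_bleu. A compact sketch of that evaluation, with the portion handling and multiprocessing simplified away (sample_size=200 mirrors the class attribute; everything else here is illustrative):

import random
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

def avg_bleu(reference_tokens, hypothesis_tokens, gram=3, sample_size=200):
    # uniform n-gram weights, e.g. (1/3, 1/3, 1/3) for gram=3
    weights = tuple(1.0 / gram for _ in range(gram))
    sample = random.sample(hypothesis_tokens,
                           min(sample_size, len(hypothesis_tokens)))
    scores = [
        sentence_bleu(reference_tokens, hyp, weights,
                      smoothing_function=SmoothingFunction().method1)
        for hyp in sample
    ]
    return sum(scores) / len(scores)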
from multiprocessing import Pool @@ -28,7 +28,7 @@ def __init__(self, name=None, test_text=None, real_text=None, gram=3, portion=1, self.sample_size = 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 self.reference = None self.is_first = True - self.portion = portion # how many portions to use in the evaluation, default to use the whole test dataset + self.portion = 0.01#portion # how many portions to use in the evaluation, default to use the whole test dataset def get_score(self, is_fast=True, given_gram=None): """ @@ -81,6 +81,7 @@ def get_bleu(self, given_gram=None): @staticmethod def cal_bleu(reference, hypothesis, weight): + print(reference, hypothesis) return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight, smoothing_function=SmoothingFunction().method1) diff --git a/run/run_fixem.py b/run/run_fixem.py index 73758654..7ad262f3 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -43,13 +43,13 @@ tips = '{} experiments' # ===Oracle or Real=== -if_real_data = [int(False), int(True), int(True), int(True), int(True)] +if_real_data = [int(True), int(True), int(True), int(True), int(True)] dataset = ['mr20', 'mr15', 'oracle', 'amazon_app_book', 'image_coco', 'emnlp_news'] -w2v_embedding_size = [512, 100, 512, 100, 100, 100] +w2v_embedding_size = [100, 100, 512, 100, 100, 100] w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 -w2v_samples_num = 5_000_000 +w2v_samples_num = 100_000 vocab_size = [5000, 0, 0, 5000, 0, 0] # ===CatGAN Param=== From 0c756135ac5779ce078329fb75a9f2a03cfca200 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 1 Feb 2023 17:03:02 +0000 Subject: [PATCH 19/81] WIP --- config.py | 2 +- instructor/real_data/fixem_instructor.py | 14 ++++++++------ metrics/bleu.py | 1 - run/run_fixem.py | 1 - 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/config.py b/config.py index 209a4606..0abbb69b 100644 --- a/config.py +++ b/config.py @@ -74,7 +74,7 @@ temperature = 1 # ===Basic Train=== -samples_num = 10000 # 10000, mr15: 2000, +samples_num = 1000 # 10000, mr15: 2000, MLE_train_epoch = 150 # SeqGAN-80, LeakGAN-8, RelGAN-150 PRE_clas_epoch = 10 inter_epoch = 15 # LeakGAN-10 diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index aeec076b..b891db0a 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -135,12 +135,12 @@ def _run(self): for sample in samples: print(sample) - if (i + 1) % 10 == 0: - if cfg.run_model == 'fixemgan': - scores = self.cal_metrics(fmt_str=True) - if cfg.run_model == 'cat_fixemgan': - scores = ' '.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) - print('epoch:', i, scores) + # if (i + 1) % 10 == 0: + if cfg.run_model == 'fixemgan': + scores = self.cal_metrics(fmt_str=True) + if cfg.run_model == 'cat_fixemgan': + scores = ' '.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) + print('epoch:', i, scores) def one_more_batch_for_generator( @@ -161,8 +161,10 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): # gen_data = GenDataIter(eval_samples) # gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) gen_tokens = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + gen_tokens = [sample.split() for sample in gen_tokens] # gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) gen_tokens_s = self.generator.sample(200, 
200, label_i=label_i) + gen_tokens_s = [sample.split() for sample in gen_tokens_s] # clas_data = CatClasDataIter([eval_samples], label_i) # Reset metrics diff --git a/metrics/bleu.py b/metrics/bleu.py index d5330274..08eebd98 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -81,7 +81,6 @@ def get_bleu(self, given_gram=None): @staticmethod def cal_bleu(reference, hypothesis, weight): - print(reference, hypothesis) return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight, smoothing_function=SmoothingFunction().method1) diff --git a/run/run_fixem.py b/run/run_fixem.py index 7ad262f3..1c2aaa21 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -61,7 +61,6 @@ model_type = 'fixem' gen_init = 'truncated_normal' dis_init = 'uniform' -samples_num = 10000 batch_size = 64 target_len = [20, 40, 20, 16, 52, 36] From 9e4eb036a71e25342f44dd6127953e76d17fd4d1 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 2 Feb 2023 16:21:25 +0000 Subject: [PATCH 20/81] new metrics WIP --- config.py | 2 + instructor/oracle_data/fixem_instructor.py | 5 +-- instructor/real_data/fixem_instructor.py | 6 ++- instructor/real_data/instructor.py | 10 ++++- metrics/bleu.py | 2 +- metrics/gpt_nll.py | 50 ++++++++++++++++++++++ metrics/ioc.py | 38 ++++++++++++++++ utils/text_process.py | 2 +- 8 files changed, 107 insertions(+), 8 deletions(-) create mode 100644 metrics/gpt_nll.py create mode 100644 metrics/ioc.py diff --git a/config.py b/config.py index 0abbb69b..dfd5c786 100644 --- a/config.py +++ b/config.py @@ -103,6 +103,8 @@ use_nll_div = True use_bleu = True use_self_bleu = False +use_ioc = True +use_gpt_nll = True use_clas_acc = True use_ppl = False diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index dfa44111..2ff976f4 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -26,18 +26,17 @@ # TO DO: -# 5. fix bleu score # 6. add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) # 4. save? or save each 10 epochs # 7. logger # 11. make run_fixem clean -# 10. cat_oracle ? +# 10. cat_oracle +# 12. class accuracy # afterwards: # check target real/fake to be right (Uniform or const) # random data portion generator? 
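Patch 20 introduces two new evaluation metrics, an index of coincidence and a GPT-2 NLL oracle, wired into the instructors below. Both follow the repository's Metrics protocol from metrics/basic.py (not shown in this series): a constructor taking if_use and optional reference text, reset() called with freshly generated tokens before each evaluation, and get_score()/get_name() consumed by cal_metrics(). A minimal, hedged sketch of a metric plugging into that protocol (the base class is assumed to provide get_name(), and the toy statistic is purely illustrative):

from metrics.basic import Metrics

class AvgLength(Metrics):
    """Toy metric: mean token count of generated samples."""

    def __init__(self, test_text=None, if_use=True):
        super(AvgLength, self).__init__('avg_length')
        self.if_use = if_use
        self.test_text = test_text or []

    def reset(self, test_text=None, real_text=None):
        # cal_metrics() calls reset(test_text=gen_tokens) before scoring
        if test_text is not None:
            self.test_text = test_text

    def get_score(self):
        if not self.if_use or not self.test_text:
            return 0
        return sum(len(t) for t in self.test_text) / len(self.test_text)

An instructor then only has to append the instance to self.all_metrics and reset it inside cal_metrics(), exactly as done for self.ioc and the GPT-2 metric in the hunks that follow.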
- # half of page idea explained # current status # plans for future, implementation tweaks diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index b891db0a..d61cb2a2 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -62,7 +62,7 @@ def __init__(self, opt): self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA) self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2, CUDA=cfg.CUDA) - self.all_metrics = [self.bleu, self.self_bleu] + self.all_metrics = [self.bleu, self.self_bleu, self.ioc, self.gpt_nll] def build_embedding(self): print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") @@ -174,6 +174,8 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) # self.clas_acc.reset(self.clas, clas_data.loader) # self.ppl.reset(gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.gpt_nll.reset(test_text=gen_tokens) if fmt_str: return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) @@ -189,7 +191,9 @@ def cal_metrics(self, fmt_str=False): with torch.no_grad(): # Prepare data for evaluation gen_tokens = self.generator.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_tokens = [sample.split() for sample in gen_tokens] gen_tokens_s = self.generator.sample(200, 200) + gen_tokens_s = [sample.split() for sample in gen_tokens_s] # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 90ca5ba4..0f1dd482 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -14,6 +14,8 @@ import config as cfg from metrics.bleu import BLEU from metrics.clas_acc import ACC +from metrics.ioc import IOC +from metrics.gpt_nll import GPTNLL from metrics.nll import NLL from metrics.ppl import PPL from utils.cat_data_loader import CatClasDataIter @@ -70,8 +72,10 @@ def __init__(self, opt): self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA) self.self_bleu = BLEU('Self-BLEU', gram=[2, 3, 4], if_use=cfg.use_self_bleu) self.clas_acc = ACC(if_use=cfg.use_clas_acc) + self.ioc = IOC(if_use=cfg.use_ioc, real_text=self.test_data.tokens) + self.gpt_nll = GPTNLL(if_use=cfg.use_gpt_nll, real_text=self.test_data.tokens) # self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl) - self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu]#, self.ppl] + self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.gpt_nll]#, self.ppl] def _run(self): print('Nothing to run in Basic Instructor!') @@ -217,7 +221,9 @@ def cal_metrics(self, fmt_str=False): self.nll_gen.reset(self.gen, self.train_data.loader) self.nll_div.reset(self.gen, gen_data.loader) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) - self.ppl.reset(gen_tokens) + self.ppl.reset(test_text=gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.gpt_nll.reset(test_text=gen_tokens) if fmt_str: return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) diff --git a/metrics/bleu.py b/metrics/bleu.py index 08eebd98..6c9ff2eb 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -28,7 +28,7 @@ def __init__(self, name=None, test_text=None, real_text=None, gram=3, 
portion=1, self.sample_size = 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 self.reference = None self.is_first = True - self.portion = 0.01#portion # how many portions to use in the evaluation, default to use the whole test dataset + self.portion = portion # how many portions to use in the evaluation, default to use the whole test dataset def get_score(self, is_fast=True, given_gram=None): """ diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py new file mode 100644 index 00000000..ca809786 --- /dev/null +++ b/metrics/gpt_nll.py @@ -0,0 +1,50 @@ +from collections import Counter +from itertools import chain +import os +import random + +import torch.nn.functional as F +from transformers import GPT2LMHeadModel, GPT2Tokenizer + +from metrics.basic import Metrics + + +class GPTNLL(Metrics): + def __init__(self, name=None, test_text=None, real_text=None, if_use=True): + super(GPTNLL, self).__init__('GPT2 as oracle') + + self.if_use = if_use + self.test_text = test_text + self.dataset_nll = 0 + + self.NLLloss = torch.nn.NLLLoss() + self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + self.model = GPT2LMHeadModel.from_pretrained("gpt2") + + self.real_text_nll = self.get_ioc(real_text) if real_text else None + + def get_score(self): + """Get gpt2 NLL score.""" + if not self.if_use: + return 0 + return self.get_NLL(self.test_text) - self.dataset_nll + + def reset(self, test_text=None, real_text=None): + self.test_text = test_text if test_text else self.test_text + self.real_text_nll = self.get_NLL(real_text) if real_text else self.real_text_nll + + def get_NLL(self, messages, baseline=0): + if type(messages[0]) == list: #we received list of tokens + messages = [' '.join(msg) for msg in messages] + + all_logits = [] + for message in messages: + message = "<|endoftext|>" + message + "<|endoftext|>" + inputs = self.tokenizer(message, return_tensors="pt") + logits = self.model(**inputs)[0][0] + logits = F.log_softmax(logits) + # calculating NLL loss on token appearing on it's position + all_logits.append( + self.NLLLloss(logits[:-1], inputs["input_ids"][0][1:]).detach().numpy() + ) + return np.mean(all_logits) diff --git a/metrics/ioc.py b/metrics/ioc.py new file mode 100644 index 00000000..6f8c947c --- /dev/null +++ b/metrics/ioc.py @@ -0,0 +1,38 @@ +from collections import Counter +from itertools import chain + +import nltk +import os +import random + +from metrics.basic import Metrics + + +class IOC(Metrics): + def __init__(self, name=None, test_text=None, real_text=None, if_use=True): + super(IOC, self).__init__('Index of Coincedense') + + self.if_use = if_use + self.test_text = test_text + self.real_text_ioc = self.get_ioc(real_text) if real_text else None + self.reference = None + self.is_first = True + self.portion = 0.01#portion # how many portions to use in the evaluation, default to use the whole test dataset + + def get_score(self): + """Get IOC score.""" + if not self.if_use: + return 0 + return self.get_ioc(self.test_text) / self.real_text_ioc + + def reset(self, test_text=None, real_text=None): + self.test_text = test_text if test_text else self.test_text + self.real_text_ioc = self.get_ioc(real_text) if real_text else self.real_text_ioc + + def get_ioc(self, list_tokens): + """Index Of Coincedense: probability of 2 random tokens in text to equal.""" + tokens = list(chain(*list_tokens)) + counts = Counter(tokens) + total = sum(ni * (ni - 1) for ni in counts.values()) + N = len(tokens) + return total / N / (N - 1) diff --git a/utils/text_process.py 
b/utils/text_process.py index 97e4c704..3a4217fc 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -19,7 +19,7 @@ def text_file_iterator(file): with open(file) as raw: for line in raw.readlines(): - yield line.strip().split() + yield line.strip('\n').split() def get_tokenlized(file): From 4a1817c6669b0b08e34d5765f214184dbeda6f5a Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Feb 2023 12:13:58 +0000 Subject: [PATCH 21/81] fin text preprocessing, metrics WIP --- metrics/gpt_nll.py | 9 +++++---- utils/text_process.py | 35 +++++++++++++++++------------------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index ca809786..f01517a2 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -3,6 +3,7 @@ import os import random +import torch import torch.nn.functional as F from transformers import GPT2LMHeadModel, GPT2Tokenizer @@ -21,7 +22,7 @@ def __init__(self, name=None, test_text=None, real_text=None, if_use=True): self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") self.model = GPT2LMHeadModel.from_pretrained("gpt2") - self.real_text_nll = self.get_ioc(real_text) if real_text else None + self.real_text_nll = self.get_NLL(real_text) if real_text else None def get_score(self): """Get gpt2 NLL score.""" @@ -39,12 +40,12 @@ def get_NLL(self, messages, baseline=0): all_logits = [] for message in messages: - message = "<|endoftext|>" + message + "<|endoftext|>" + message = self.tokenizer.eos_token + message + self.tokenizer.eos_token inputs = self.tokenizer(message, return_tensors="pt") logits = self.model(**inputs)[0][0] - logits = F.log_softmax(logits) + logits = F.log_softmax(logits, dim=1) # calculating NLL loss on token appearing on it's position all_logits.append( - self.NLLLloss(logits[:-1], inputs["input_ids"][0][1:]).detach().numpy() + self.NLLloss(logits[:-1], inputs["input_ids"][0][1:]).detach().numpy() ) return np.mean(all_logits) diff --git a/utils/text_process.py b/utils/text_process.py index 3a4217fc..80090173 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -7,7 +7,6 @@ # @Description : # Copyrights (C) 2018. All Rights Reserved. 
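A quick worked example for the index-of-coincidence metric added above (metrics/ioc.py): IOC is the probability that two distinct token positions drawn at random hold the same token, and get_score() reports the ratio of generated-text IOC to dataset IOC.

from collections import Counter
from itertools import chain

def ioc(list_tokens):
    tokens = list(chain(*list_tokens))
    counts = Counter(tokens)
    n = len(tokens)
    return sum(c * (c - 1) for c in counts.values()) / (n * (n - 1))

sample = [['the', 'cat', 'sat'], ['the', 'dog', 'sat']]
# counts: the=2, sat=2, cat=1, dog=1; N=6 tokens
# IOC = (2*1 + 2*1 + 0 + 0) / (6*5) = 4/30, roughly 0.133
print(ioc(sample))

A ratio near 1 means the generator's token-repetition statistics match the real data; collapsed output that reuses a handful of tokens drives the ratio well above 1.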
-import nltk import numpy as np import os import torch @@ -27,7 +26,7 @@ def get_tokenlized(file): tokenlized = list() with open(file) as raw: for text in raw: - text = nltk.word_tokenize(text.lower()) + text = text.strip('\n').lower().split() tokenlized.append(text) return tokenlized @@ -375,21 +374,6 @@ def vectorize_sentence(tokens, w2v, target_len: int = 52, embedding_size: int = vectorized = vectorized.T # required for pytorch return vectorized -import os -import nltk -nltk.download('punkt') -from tqdm.notebook import tqdm -from pathlib import Path - - -def tokenize_and_save(source, path, filename): - with open(Path(path) / filename, 'w') as f: - for _, line in tqdm(source, desc=filename): - line = line.strip().lower() - line = ' '.join(nltk.tokenize.word_tokenize(line)) - f.write(line) - f.write('\n') - if __name__ == '__main__': os.chdir('../') @@ -399,6 +383,22 @@ def tokenize_and_save(source, path, filename): # dataset preprocess and saving import torchtext + import os + import nltk + nltk.download('punkt') + from tqdm.notebook import tqdm + from pathlib import Path + + def tokenize_and_save(source, path, filename): + with open(Path(path) / filename, 'w') as f: + for _, line in tqdm(source, desc=filename): + line = line.strip().lower() + line = ' '.join(nltk.tokenize.word_tokenize(line)) + line = ' '.join(line.split('\n')) + line = ' '.join(line.split('\\n')) + line = ' '.join(line.split('\\')) + f.write(line) + f.write('\n') AGNEWS_train, AGNEWS_test = torchtext.datasets.AG_NEWS( root="./data", split=("train", "test") @@ -406,7 +406,6 @@ def tokenize_and_save(source, path, filename): DBpedia_train, DBpedia_test = torchtext.datasets.DBpedia( root="./data", split=("train", "test") ) - WikiText103_train, WikiText103_valid, WikiText103_test = torchtext.datasets.WikiText103( root="./data", split=("train", "valid", "test") ) From b8662691b69adabb947b0edd05c2d4283e15b684 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Feb 2023 16:23:39 +0000 Subject: [PATCH 22/81] metrics done --- instructor/real_data/fixem_instructor.py | 2 ++ metrics/gpt_nll.py | 10 ++++++---- metrics/ioc.py | 1 + run/run_fixem.py | 4 ++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index d61cb2a2..f3faf8f2 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -201,6 +201,8 @@ def cal_metrics(self, fmt_str=False): # self.nll_div.reset(self.gen, gen_data.loader) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) # self.ppl.reset(gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.gpt_nll.reset(test_text=gen_tokens) if fmt_str: return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index f01517a2..f1dbe0a1 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -3,6 +3,7 @@ import os import random +import numpy as np import torch import torch.nn.functional as F from transformers import GPT2LMHeadModel, GPT2Tokenizer @@ -16,19 +17,20 @@ def __init__(self, name=None, test_text=None, real_text=None, if_use=True): self.if_use = if_use self.test_text = test_text - self.dataset_nll = 0 self.NLLloss = torch.nn.NLLLoss() self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") self.model = GPT2LMHeadModel.from_pretrained("gpt2") - - self.real_text_nll = self.get_NLL(real_text) if real_text else None + print('Calculating dataset NLL') + 
self.real_text_nll = self.get_NLL(random.sample(real_text, 500)) if real_text else None + print(f'dataset NLL based on GPT2 is {self.real_text_nll}') + print('GPT2 as oracle metric will be calculated relative to this value') def get_score(self): """Get gpt2 NLL score.""" if not self.if_use: return 0 - return self.get_NLL(self.test_text) - self.dataset_nll + return self.get_NLL(self.test_text) - self.real_text_nll def reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text else self.test_text diff --git a/metrics/ioc.py b/metrics/ioc.py index 6f8c947c..4721eb11 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -15,6 +15,7 @@ def __init__(self, name=None, test_text=None, real_text=None, if_use=True): self.if_use = if_use self.test_text = test_text self.real_text_ioc = self.get_ioc(real_text) if real_text else None + print(f'Dataset Index of coincedense: {self.real_text_ioc}') self.reference = None self.is_first = True self.portion = 0.01#portion # how many portions to use in the evaluation, default to use the whole test dataset diff --git a/run/run_fixem.py b/run/run_fixem.py index 1c2aaa21..36bb65ea 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -44,8 +44,8 @@ # ===Oracle or Real=== if_real_data = [int(True), int(True), int(True), int(True), int(True)] -dataset = ['mr20', 'mr15', 'oracle', 'amazon_app_book', 'image_coco', 'emnlp_news'] -w2v_embedding_size = [100, 100, 512, 100, 100, 100] +dataset = ['amazon_app_book', 'mr20', 'mr15', 'oracle', 'amazon_app_book', 'image_coco', 'emnlp_news'] +w2v_embedding_size = [128, 256, 512, 128, 128, 128] w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 From 9004cdaefcf735eb543c0fa7cd8db9249c85c906 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sun, 5 Feb 2023 12:38:44 +0000 Subject: [PATCH 23/81] metrics WIP --- config.py | 1 - instructor/oracle_data/instructor.py | 2 +- instructor/real_data/fixem_instructor.py | 32 +++++------------------- instructor/real_data/instructor.py | 21 ++++++++++------ run/run_fixem.py | 4 +-- 5 files changed, 22 insertions(+), 38 deletions(-) diff --git a/config.py b/config.py index dfd5c786..62587d1b 100644 --- a/config.py +++ b/config.py @@ -104,7 +104,6 @@ use_bleu = True use_self_bleu = False use_ioc = True -use_gpt_nll = True use_clas_acc = True use_ppl = False diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index a0939c11..4f48f3e4 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -4,7 +4,7 @@ # @FileName : instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
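For the GPT-2-as-oracle metric patched above: the score is the mean per-token negative log-likelihood of generated sentences under pretrained GPT-2, reported relative to the same quantity computed on a sample of 500 real test sentences. A standalone hedged sketch of that computation using the Hugging Face transformers loss head, which shifts the labels internally and matches the manual log-softmax loop in metrics/gpt_nll.py:

import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

def mean_gpt2_nll(sentences):
    losses = []
    with torch.no_grad():
        for text in sentences:
            enc = tokenizer(tokenizer.eos_token + text + tokenizer.eos_token,
                            return_tensors="pt")
            out = model(**enc, labels=enc["input_ids"])
            losses.append(out.loss.item())  # mean NLL over this sentence's tokens
    return float(np.mean(losses))

get_score() is then mean_gpt2_nll(generated) minus the same value on the real sample, so values near zero mean the generated text is about as probable under GPT-2 as the dataset itself.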
import numpy as np diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index f3faf8f2..c6ae8db9 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -182,29 +182,9 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): return [metric.get_score() for metric in self.all_metrics] - - def cal_metrics(self, fmt_str=False): - """ - Calculate metrics - :param fmt_str: if return format string for logging - """ - with torch.no_grad(): - # Prepare data for evaluation - gen_tokens = self.generator.sample(cfg.samples_num, 4 * cfg.batch_size) - gen_tokens = [sample.split() for sample in gen_tokens] - gen_tokens_s = self.generator.sample(200, 200) - gen_tokens_s = [sample.split() for sample in gen_tokens_s] - - # Reset metrics - self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) - # self.nll_gen.reset(self.gen, self.train_data.loader) - # self.nll_div.reset(self.gen, gen_data.loader) - self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) - # self.ppl.reset(gen_tokens) - self.ioc.reset(test_text=gen_tokens) - self.gpt_nll.reset(test_text=gen_tokens) - - if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) - else: - return [metric.get_score() for metric in self.all_metrics] + def sample_for_metrics(self): + gen_tokens = self.generator.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_tokens = [sample.split() for sample in gen_tokens] + gen_tokens_s = self.generator.sample(200, 200) + gen_tokens_s = [sample.split() for sample in gen_tokens_s] + return None, gen_tokens, gen_tokens_s diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 0f1dd482..48d55d31 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -38,6 +38,8 @@ def __init__(self, opt): # load dictionary if cfg.if_real_data: self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset) + else: + self.word2idx_dict, self.idx2word_dict = {}, {} # Dataloader try: @@ -73,9 +75,9 @@ def __init__(self, opt): self.self_bleu = BLEU('Self-BLEU', gram=[2, 3, 4], if_use=cfg.use_self_bleu) self.clas_acc = ACC(if_use=cfg.use_clas_acc) self.ioc = IOC(if_use=cfg.use_ioc, real_text=self.test_data.tokens) - self.gpt_nll = GPTNLL(if_use=cfg.use_gpt_nll, real_text=self.test_data.tokens) + self.nll_oracle = GPTNLL(if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) # self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl) - self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.gpt_nll]#, self.ppl] + self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.nll_oracle, self.ppl] def _run(self): print('Nothing to run in Basic Instructor!') @@ -204,6 +206,13 @@ def show_config(self): self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg))) self.log.info(100 * '=') + def sample_for_metrics(self): + eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_data = GenDataIter(eval_samples) + gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) + gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200), self.idx2word_dict) + return gen_data, gen_tokens, gen_tokens_s + def cal_metrics(self, fmt_str=False): """ Calculate metrics @@ -211,11 +220,7 @@ def cal_metrics(self, fmt_str=False): """ with torch.no_grad(): # Prepare data for evaluation - eval_samples = 
self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) - gen_data = GenDataIter(eval_samples) - gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200), self.idx2word_dict) - + gen_data, gen_tokens, gen_tokens_s = sample_for_metrics() # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) self.nll_gen.reset(self.gen, self.train_data.loader) @@ -223,7 +228,7 @@ def cal_metrics(self, fmt_str=False): self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) self.ppl.reset(test_text=gen_tokens) self.ioc.reset(test_text=gen_tokens) - self.gpt_nll.reset(test_text=gen_tokens) + self.nll_oracle.reset(test_text=gen_tokens) if fmt_str: return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) diff --git a/run/run_fixem.py b/run/run_fixem.py index 36bb65ea..e0953acf 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -72,8 +72,8 @@ # ===Metrics=== use_nll_oracle = int(True) -use_nll_gen = int(True) -use_nll_div = int(True) +use_nll_gen = int(False) +use_nll_div = int(False) use_bleu = int(True) use_self_bleu = int(True) use_clas_acc = int(True) From 429712fbc56cfe3610aa667e4bad29e74a347fcb Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sun, 5 Feb 2023 12:45:38 +0000 Subject: [PATCH 24/81] metrics WIP --- instructor/oracle_data/fixem_instructor.py | 1 - instructor/real_data/fixem_instructor.py | 32 +++++++++++----------- instructor/real_data/instructor.py | 2 +- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index 2ff976f4..8238ed04 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -58,7 +58,6 @@ def __init__(self, opt): # Metrics self.nll_oracle = NLL('NLL_oracle', if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) - self.all_metrics = [self.nll_oracle] def build_embedding(self): # train embedding on available dataset or oracle diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index c6ae8db9..8e965afe 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -44,20 +44,20 @@ def __init__(self, opt): self.train_data_supplier = DataSupplier(train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch) - self.discriminator = Discriminator(cfg.discriminator_complexity) + self.disc = Discriminator(cfg.discriminator_complexity) print( "discriminator total tranable parameters:", - number_of_parameters(self.discriminator.parameters()) + number_of_parameters(self.dis.parameters()) ) - self.generator = Generator(cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size) + self.gen = Generator(cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size) print( "generator total tranable parameters:", - number_of_parameters(self.generator.parameters()) + number_of_parameters(self.gen.parameters()) ) if cfg.CUDA: - self.discriminator = self.discriminator.cuda() - self.generator = self.generator.cuda() + self.dis = self.dis.cuda() + self.gen = self.gen.cuda() self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA) self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2, CUDA=cfg.CUDA) @@ -77,11 +77,11 @@ def generator_train_one_batch(self): noise = tuple(tt.cuda() for tt in noise) fakes = 
self.generator(*noise) - real_fake_predicts, label_predicts = self.discriminator(fakes) + real_fake_predicts, label_predicts = self.dis(fakes) loss = self.G_criterion.G_loss_fixem(real_fake_predicts, label_predicts, noise[1], fakes) loss.backward() - self.generator.optimizer.step() + self.gen.optimizer.step() generator_acc = float( np.array(real_fake_predicts.detach().cpu().numpy() > 0.5, dtype=int).mean() @@ -100,11 +100,11 @@ def discriminator_train_one_batch(self, real_vector, labels): text_input_vectors = torch.cat((real_vector, fake)) # optmizer step - self.discriminator.optimizer.zero_grad() - real_fake_predicts, label_predicts = self.discriminator(text_input_vectors) + self.dis.optimizer.zero_grad() + real_fake_predicts, label_predicts = self.dis(text_input_vectors) loss = self.D_criterion.D_loss_fixem(real_fake_predicts, label_predicts[:this_batch_size], labels) loss.backward() - self.discriminator.optimizer.step() + self.dis.optimizer.step() real_fake_predicts = real_fake_predicts.clone().detach() real_fake_predicts = real_fake_predicts.chunk(2) #splitting to realand fake parks @@ -131,7 +131,7 @@ def _run(self): generator_acc = self.generator_train_one_batch() - samples = self.generator.sample(20, 20) + samples = self.gene.sample(20, 20) for sample in samples: print(sample) @@ -160,10 +160,10 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): # eval_samples = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) # gen_data = GenDataIter(eval_samples) # gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + gen_tokens = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) gen_tokens = [sample.split() for sample in gen_tokens] # gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) - gen_tokens_s = self.generator.sample(200, 200, label_i=label_i) + gen_tokens_s = self.gen.sample(200, 200, label_i=label_i) gen_tokens_s = [sample.split() for sample in gen_tokens_s] # clas_data = CatClasDataIter([eval_samples], label_i) @@ -183,8 +183,8 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): return [metric.get_score() for metric in self.all_metrics] def sample_for_metrics(self): - gen_tokens = self.generator.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_tokens = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) gen_tokens = [sample.split() for sample in gen_tokens] - gen_tokens_s = self.generator.sample(200, 200) + gen_tokens_s = self.gen.sample(200, 200) gen_tokens_s = [sample.split() for sample in gen_tokens_s] return None, gen_tokens, gen_tokens_s diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 48d55d31..a1ba3bcb 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -76,7 +76,7 @@ def __init__(self, opt): self.clas_acc = ACC(if_use=cfg.use_clas_acc) self.ioc = IOC(if_use=cfg.use_ioc, real_text=self.test_data.tokens) self.nll_oracle = GPTNLL(if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) - # self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl) + self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl) self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.nll_oracle, self.ppl] def _run(self): From ac63f679a2711f3bea0fe0280d76491de4c93c34 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sun, 5 Feb 2023 
16:10:04 +0000 Subject: [PATCH 25/81] consistent metrics WIP --- instructor/real_data/fixem_instructor.py | 16 +++++++--------- instructor/real_data/instructor.py | 2 +- utils/data_loader.py | 4 ++-- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 8e965afe..2c685d94 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -14,7 +14,7 @@ from instructor.real_data.instructor import BasicInstructor from utils.gan_loss import GANLoss from utils.text_process import text_file_iterator -from utils.data_loader import DataSupplier, GANDataset +from utils.data_loader import DataSupplier, GenDataIter, GANDataset from utils.nn_helpers import create_noise, number_of_parameters from utils.create_embeddings import EmbeddingsTrainer, load_embedding from models.FixemGAN_G import Generator @@ -44,7 +44,7 @@ def __init__(self, opt): self.train_data_supplier = DataSupplier(train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch) - self.disc = Discriminator(cfg.discriminator_complexity) + self.dis = Discriminator(cfg.discriminator_complexity) print( "discriminator total tranable parameters:", number_of_parameters(self.dis.parameters()) @@ -62,8 +62,6 @@ def __init__(self, opt): self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA) self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2, CUDA=cfg.CUDA) - self.all_metrics = [self.bleu, self.self_bleu, self.ioc, self.gpt_nll] - def build_embedding(self): print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") print("Will train new one, it may take a while...") @@ -71,11 +69,11 @@ def build_embedding(self): EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() def generator_train_one_batch(self): - self.generator.optimizer.zero_grad() + self.gen.optimizer.zero_grad() noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) if cfg.CUDA: noise = tuple(tt.cuda() for tt in noise) - fakes = self.generator(*noise) + fakes = self.gen(*noise) real_fake_predicts, label_predicts = self.dis(fakes) loss = self.G_criterion.G_loss_fixem(real_fake_predicts, label_predicts, noise[1], fakes) @@ -96,7 +94,7 @@ def discriminator_train_one_batch(self, real_vector, labels): noise = create_noise(cfg.batch_size, cfg.noise_size, cfg.k_label) if cfg.CUDA: noise = tuple(tt.cuda() for tt in noise) - fake = self.generator(*noise).detach() + fake = self.gen(*noise).detach() text_input_vectors = torch.cat((real_vector, fake)) # optmizer step @@ -131,7 +129,7 @@ def _run(self): generator_acc = self.generator_train_one_batch() - samples = self.gene.sample(20, 20) + samples = self.gen.sample(20, 20) for sample in samples: print(sample) @@ -187,4 +185,4 @@ def sample_for_metrics(self): gen_tokens = [sample.split() for sample in gen_tokens] gen_tokens_s = self.gen.sample(200, 200) gen_tokens_s = [sample.split() for sample in gen_tokens_s] - return None, gen_tokens, gen_tokens_s + return GenDataIter(gen_tokens), gen_tokens, gen_tokens_s diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index a1ba3bcb..06fc4a7e 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -220,7 +220,7 @@ def cal_metrics(self, fmt_str=False): """ with torch.no_grad(): # Prepare data for evaluation - gen_data, gen_tokens, gen_tokens_s = sample_for_metrics() + 
gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics() # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) self.nll_gen.reset(self.gen, self.train_data.loader) diff --git a/utils/data_loader.py b/utils/data_loader.py index 7e7dc0ee..4f81900a 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -45,8 +45,8 @@ def __init__(self, samples, if_test_data=False, shuffle=None): shuffle=self.shuffle, drop_last=True) - self.input = self._all_data_('input') - self.target = self._all_data_('target') + # self.input = self._all_data_('input') + # self.target = self._all_data_('target') def __read_data__(self, samples): """ From 9dd6de17b14864a7e625844eaf6fe5896f0fa40e Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 09:34:39 +0000 Subject: [PATCH 26/81] metrics clean WIP --- instructor/real_data/fixem_instructor.py | 38 +++++------------------- instructor/real_data/instructor.py | 26 +++++++++------- 2 files changed, 23 insertions(+), 41 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 2c685d94..8c60c960 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -15,6 +15,7 @@ from utils.gan_loss import GANLoss from utils.text_process import text_file_iterator from utils.data_loader import DataSupplier, GenDataIter, GANDataset +from utils.cat_data_loader import CatClasDataIter from utils.nn_helpers import create_noise, number_of_parameters from utils.create_embeddings import EmbeddingsTrainer, load_embedding from models.FixemGAN_G import Generator @@ -150,39 +151,16 @@ def one_more_batch_for_generator( return True return False - - def cal_metrics_with_label(self, label_i, fmt_str=False): - assert type(label_i) == int, 'missing label' - with torch.no_grad(): - # Prepare data for evaluation - # eval_samples = self.generator.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) - # gen_data = GenDataIter(eval_samples) - # gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) - gen_tokens = [sample.split() for sample in gen_tokens] - # gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) - gen_tokens_s = self.gen.sample(200, 200, label_i=label_i) - gen_tokens_s = [sample.split() for sample in gen_tokens_s] - # clas_data = CatClasDataIter([eval_samples], label_i) - - # Reset metrics - self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) - # self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i) - # self.nll_div.reset(self.gen, gen_data.loader, label_i) - self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) - # self.clas_acc.reset(self.clas, clas_data.loader) - # self.ppl.reset(gen_tokens) - self.ioc.reset(test_text=gen_tokens) - self.gpt_nll.reset(test_text=gen_tokens) - - if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) - - return [metric.get_score() for metric in self.all_metrics] - def sample_for_metrics(self): gen_tokens = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) gen_tokens = [sample.split() for sample in gen_tokens] gen_tokens_s = self.gen.sample(200, 200) gen_tokens_s = [sample.split() for sample in gen_tokens_s] return GenDataIter(gen_tokens), gen_tokens, gen_tokens_s + + def sample_for_metrics_with_label(self, label_i): + 
gen_tokens = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + gen_tokens = [sample.split() for sample in gen_tokens] + gen_tokens_s = self.gen.sample(200, 200, label_i=label_i) + gen_tokens_s = [sample.split() for sample in gen_tokens_s] + return GenDataIter(gen_tokens), gen_tokens, gen_tokens_s, CatClasDataIter([gen_tokens], label_i) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 06fc4a7e..f3978ff9 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -213,6 +213,14 @@ def sample_for_metrics(self): gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200), self.idx2word_dict) return gen_data, gen_tokens, gen_tokens_s + def sample_for_metrics_with_label(self, label_i): + eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + gen_data = GenDataIter(eval_samples) + gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) + gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) + clas_data = CatClasDataIter([eval_samples], label_i) + return gen_data, gen_tokens, gen_tokens_s, clas_data + def cal_metrics(self, fmt_str=False): """ Calculate metrics @@ -226,26 +234,20 @@ def cal_metrics(self, fmt_str=False): self.nll_gen.reset(self.gen, self.train_data.loader) self.nll_div.reset(self.gen, gen_data.loader) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) - self.ppl.reset(test_text=gen_tokens) + self.ppl.reset(gen_tokens=gen_tokens) self.ioc.reset(test_text=gen_tokens) self.nll_oracle.reset(test_text=gen_tokens) if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) - else: - return [metric.get_score() for metric in self.all_metrics] + return '\n'.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + return [metric.get_score() for metric in self.all_metrics] - def cal_metrics_with_label(self, label_i): + def cal_metrics_with_label(self, label_i, fmt_str=False): assert type(label_i) == int, 'missing label' with torch.no_grad(): # Prepare data for evaluation - eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) - gen_data = GenDataIter(eval_samples) - gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) - clas_data = CatClasDataIter([eval_samples], label_i) - + gen_data, gen_tokens, gen_tokens_s, clas_data = sample_for_metrics_with_label(label_i) # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i) @@ -254,6 +256,8 @@ def cal_metrics_with_label(self, label_i): self.clas_acc.reset(self.clas, clas_data.loader) self.ppl.reset(gen_tokens) + if fmt_str: + return '\n'.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): From 1e3fcfbd32c13793310a8a6d9fc602c8f495a24e Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 10:14:24 +0000 Subject: [PATCH 27/81] seed everything --- main.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/main.py b/main.py index b515da52..e535374b 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,11 @@ # Copyrights (C) 2018. All Rights Reserved. 
from __future__ import print_function +import random + import argparse +import torch +import numpy as np import config as cfg from utils.text_process import load_test_dict, text_process @@ -124,6 +128,11 @@ def program_config(parser): # MAIN if __name__ == '__main__': + #seed everything + torch.manual_seed(0) + random.seed(0) + np.random.seed(0) + # Hyper Parameters parser = argparse.ArgumentParser() parser = program_config(parser) From 045675cedef64e18df70b517cf3f09acd738ac76 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 10:14:56 +0000 Subject: [PATCH 28/81] added new commands to get started in README --- README.md | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 1f557a02..8dc28b37 100644 --- a/README.md +++ b/README.md @@ -67,9 +67,20 @@ To install, run `pip install -r requirements.txt`. In case of CUDA problems, con git clone https://github.com/williamSYSU/TextGAN-PyTorch.git cd TextGAN-PyTorch ``` +- Downlaod dataset and pretrained embeddings from kaggle dataset or manually: +```bash +kaggle datasets download -d salaxieb/texts-corpus-preprocessed +kaggle datasets download -d salaxieb/pretrained-embeddings + +unzip ../texts-corpus-preprocessed.zip -d dataset +mkdir dataset/testdata +mv dataset/*_test.txt dataset/testdata/ + +mkdir pretrain +unzip ../pretrained-embeddings.zip -d pretrain/real_data +``` -- For real data experiments, all datasets (`Image COCO`, `EMNLP NEWs`, `Movie Review`, `Amazon Review`) can be downloaded from [here](https://drive.google.com/drive/folders/1XvT3GqbK1wh3XhTgqBLWUtH_mLzGnKZP?usp=sharing). -- Run with a specific model +- Manually (`Image COCO`, `EMNLP NEWs`, `Movie Review`, `Amazon Review`) can be downloaded from [here](https://drive.google.com/drive/folders/1XvT3GqbK1wh3XhTgqBLWUtH_mLzGnKZP?usp=sharing). ```bash cd run @@ -86,13 +97,13 @@ python3 run_seqgan.py 0 0 For each model, the entire runing process is defined in `instructor/oracle_data/seqgan_instructor.py`. (Take SeqGAN in Synthetic data experiment for example). Some basic functions like `init_model()`and `optimize()` are defined in the base class `BasicInstructor` in `instructor.py`. If you want to add a new GAN-based text generation model, please create a new instructor under `instructor/oracle_data` and define the training process for the model. 2. **Visualization** - + Use `utils/visualization.py` to visualize the log file, including model loss and metrics scores. Custom your log files in `log_file_list`, no more than `len(color_list)`. The log filename should exclude `.txt`. - + 3. **Logging** The TextGAN-PyTorch use the `logging` module in Python to record the running process, like generator's loss and metric scores. For the convenience of visualization, there would be two same log file saved in `log/log_****_****.txt` and `save/**/log.txt` respectively. Furthermore, The code would automatically save the state dict of models and a batch-size of generator's samples in `./save/**/models` and `./save/**/samples` per log step, where `**` depends on your hyper-parameters. - + 4. **Running Signal** You can easily control the training process with the class `Signal` (please refer to `utils/helpers.py`) based on dictionary file `run_signal.txt`. 
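The `Signal` mechanism referenced above works by re-reading `run_signal.txt` between training steps, so a running job can be paused or stopped by editing that file from another shell. Below is a minimal hypothetical sketch of the idea: only the `pre_sig` flag is visible in the instructors in this series, so the `adv_sig` key and the literal-dict file format are assumptions, and the real class in `utils/helpers.py` may differ.

```python
import ast


class RunSignal:
    """Re-reads a small signal file on every update() so a running job can be
    paused or stopped by editing the file while training is in progress."""

    def __init__(self, signal_file='run_signal.txt'):
        self.signal_file = signal_file
        self.pre_sig = True   # keep MLE pre-training going by default
        self.adv_sig = True   # assumed flag for the adversarial phase
        self.update()

    def update(self):
        try:
            with open(self.signal_file) as f:
                # expected content, e.g.: {'pre_sig': True, 'adv_sig': True}
                flags = ast.literal_eval(f.read())
        except (OSError, ValueError, SyntaxError):
            return  # keep the previous flags if the file is missing or malformed
        if isinstance(flags, dict):
            self.pre_sig = flags.get('pre_sig', self.pre_sig)
            self.adv_sig = flags.get('adv_sig', self.adv_sig)


# usage mirroring the instructors in this series: check the flag once per epoch
sig = RunSignal()
for epoch in range(150):
    sig.update()
    if not sig.pre_sig:
        break  # stop pre-training when the signal file flips the flag
    # ... one epoch of training ...
```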
@@ -164,7 +175,7 @@ python3 run_seqgan.py 0 0 - Structure (from my understanding) ![model_relgan](assets/model_relgan.png) - + ### DPGAN - run file: [run_dpgan.py](run/run_dpgan.py) @@ -176,7 +187,7 @@ python3 run_seqgan.py 0 0 - Structure (from [DPGAN](https://arxiv.org/abs/1802.01345)) ![model_dpgan](assets/model_dpgan.png) - + ### DGSAN - run file: [run_dgsan.py](run/run_dgsan.py) @@ -221,9 +232,8 @@ python3 run_seqgan.py 0 0 ![model_catgan](assets/model_catgan.png) - + ## Licence **MIT lincense** - From 112517d03aa430542e259f8e987766ace635d43a Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 10:15:10 +0000 Subject: [PATCH 29/81] clean run files --- run/run_catgan.py | 8 ++++---- run/run_fixem.py | 26 +++++++++++++------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/run/run_catgan.py b/run/run_catgan.py index 8efef61e..edda3f01 100644 --- a/run/run_catgan.py +++ b/run/run_catgan.py @@ -4,7 +4,7 @@ # @FileName : run_catgan.py # @Time : Created at 2019-08-04 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys from subprocess import call @@ -34,7 +34,7 @@ # CatGAN: Catgory text generation model # EvoGAN: General text generation model if_test = int(False) -run_model = ['catgan', 'catgan', 'catgan', 'evogan', 'evogan', 'evogan'] +run_model = ['evogan', 'catgan', 'catgan', 'catgan', 'evogan', 'evogan'] k_label = 2 CUDA = int(True) ora_pretrain = int(True) @@ -46,8 +46,8 @@ tips = '{} experiments' # ===Oracle or Real=== -if_real_data = [int(False), int(True), int(True), int(False), int(True), int(True)] -dataset = ['oracle', 'mr15', 'amazon_app_book', 'oracle', 'image_coco', 'emnlp_news'] +if_real_data = [int(True), int(False), int(True), int(False), int(True), int(True)] +dataset = ['amazon_app_book', 'oracle', 'mr15', 'oracle', 'image_coco', 'emnlp_news'] vocab_size = [5000, 0, 0, 5000, 0, 0] # ===CatGAN Param=== diff --git a/run/run_fixem.py b/run/run_fixem.py index e0953acf..83a223ae 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -33,7 +33,7 @@ # ===Program=== # EvoGAN: General text generation model if_test = int(False) -run_model = ['fixemgan', 'cat_fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan', 'cat_fixemgan', 'cat_fixemgan'] +run_model = ['fixemgan', 'cat_fixemgan', 'fixemgan', 'cat_fixemgan', 'fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan', 'fixemgan'] k_label = 2 CUDA = int(True) batch_size = 32 @@ -43,14 +43,14 @@ tips = '{} experiments' # ===Oracle or Real=== -if_real_data = [int(True), int(True), int(True), int(True), int(True)] -dataset = ['amazon_app_book', 'mr20', 'mr15', 'oracle', 'amazon_app_book', 'image_coco', 'emnlp_news'] -w2v_embedding_size = [128, 256, 512, 128, 128, 128] +if_real_data = [int(True), int(True), int(True), int(True), int(True), int(True), int(True), int(False), int(False)] +dataset = ['amazon_app_book', 'mr20', 'mr20', 'mr15', 'mr15', 'image_coco', 'emnlp_news', 'oracle', 'oracle'] +w2v_embedding_size = 512 #hyperparam w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 -w2v_samples_num = 100_000 -vocab_size = [5000, 0, 0, 5000, 0, 0] +w2v_samples_num = 5_000_000 +vocab_size = 5000 # ===CatGAN Param=== loss_type = 'fixem' @@ -62,13 +62,13 @@ gen_init = 'truncated_normal' dis_init = 'uniform' batch_size = 64 -target_len = [20, 40, 20, 16, 52, 36] +target_len = [40, 20, 20, 16, 16, 16, 48, 20, 20] # ===Generator=== -generator_complexity = [768, 512, 512, 512, 512, 512] +generator_complexity = 768 #hyperparam # ===Discriminator=== 
-discriminator_complexity = [512, 512, 512, 512, 512] +discriminator_complexity = 512 #hyperparam # ===Metrics=== use_nll_oracle = int(True) @@ -91,10 +91,10 @@ # Oracle or Real '--if_real_data', if_real_data[job_id], '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], + '--vocab_size', vocab_size, # W2V embeddings - '--w2v_embedding_size', w2v_embedding_size[job_id], + '--w2v_embedding_size', w2v_embedding_size, '--w2v_window', w2v_window, '--w2v_min_count', w2v_min_count, '--w2v_workers', w2v_workers, @@ -109,10 +109,10 @@ '--oracle_train_samples_num', oracle_train_samples_num, # Generator - '--generator_complexity', generator_complexity[job_id], + '--generator_complexity', generator_complexity, # Discriminator - '--discriminator_complexity', discriminator_complexity[job_id], + '--discriminator_complexity', discriminator_complexity, # Metrics '--use_nll_oracle', use_nll_oracle, From 58f3f780b5ed2d1f57b28badc4be0de3a4fe0054 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 10:21:46 +0000 Subject: [PATCH 30/81] logger used properly --- instructor/oracle_data/fixem_instructor.py | 9 +++------ instructor/real_data/fixem_instructor.py | 12 ++++++------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index 8238ed04..29d230af 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -26,16 +26,13 @@ # TO DO: -# 6. add new interested scores (IOC, NLL on GPT) (split quick metric and slow metric) # 4. save? or save each 10 epochs -# 7. logger -# 11. make run_fixem clean # 10. cat_oracle # 12. class accuracy # afterwards: # check target real/fake to be right (Uniform or const) -# random data portion generator? 
+# random data portion generator - data supplier sample from randomint # half of page idea explained # current status @@ -61,7 +58,7 @@ def __init__(self, opt): def build_embedding(self): # train embedding on available dataset or oracle - print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") - print("Will train new one, it may take a while...") + self.log.info(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") + self.log.info("Will train new one, it may take a while...") sources = [cfg.oracle_samples_path.format(cfg.w2v_samples_num)] EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 8c60c960..ad6a5c28 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -46,12 +46,12 @@ def __init__(self, opt): self.train_data_supplier = DataSupplier(train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch) self.dis = Discriminator(cfg.discriminator_complexity) - print( + self.log.info( "discriminator total tranable parameters:", number_of_parameters(self.dis.parameters()) ) self.gen = Generator(cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size) - print( + self.log.info( "generator total tranable parameters:", number_of_parameters(self.gen.parameters()) ) @@ -64,8 +64,8 @@ def __init__(self, opt): self.D_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2, CUDA=cfg.CUDA) def build_embedding(self): - print(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") - print("Will train new one, it may take a while...") + self.log.info(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") + self.log.info("Will train new one, it may take a while...") sources = list(Path(cfg.texts_pile).glob('*.txt')) EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings() @@ -132,14 +132,14 @@ def _run(self): samples = self.gen.sample(20, 20) for sample in samples: - print(sample) + self.log.info(sample) # if (i + 1) % 10 == 0: if cfg.run_model == 'fixemgan': scores = self.cal_metrics(fmt_str=True) if cfg.run_model == 'cat_fixemgan': scores = ' '.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) - print('epoch:', i, scores) + self.log.info('epoch:', i, scores) def one_more_batch_for_generator( From fdfa125a91992d48620f890d6acf352973658d10 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 11:53:15 +0000 Subject: [PATCH 31/81] fixing logger error --- instructor/real_data/fixem_instructor.py | 14 ++++---------- utils/data_loader.py | 10 +++++++++- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index ad6a5c28..f23a5686 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -46,15 +46,9 @@ def __init__(self, opt): self.train_data_supplier = DataSupplier(train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch) self.dis = Discriminator(cfg.discriminator_complexity) - self.log.info( - "discriminator total tranable parameters:", - number_of_parameters(self.dis.parameters()) - ) + self.log.info(f"discriminator total tranable parameters: {number_of_parameters(self.dis.parameters())}") self.gen = Generator(cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size) - self.log.info( - "generator total 
tranable parameters:", - number_of_parameters(self.gen.parameters()) - ) + self.log.info(f"generator total tranable parameters: {number_of_parameters(self.gen.parameters())}") if cfg.CUDA: self.dis = self.dis.cuda() @@ -138,8 +132,8 @@ def _run(self): if cfg.run_model == 'fixemgan': scores = self.cal_metrics(fmt_str=True) if cfg.run_model == 'cat_fixemgan': - scores = ' '.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) - self.log.info('epoch:', i, scores) + scores = '\n\n'.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) + self.log.info('epoch: {i} \n {scores}') def one_more_batch_for_generator( diff --git a/utils/data_loader.py b/utils/data_loader.py index 4f81900a..9d31948e 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -141,7 +141,15 @@ def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): self.labels = torch.tensor(labels, dtype=int) - self.vectors = [vectorize_sentence(tokens, w2v, target_len=cfg.target_len, padding_token = cfg.padding_token) for tokens in tokenized] + self.vectors = [ + vectorize_sentence( + tokens, + w2v, + target_len=cfg.target_len, + padding_token = cfg.padding_token, + ) + for tokens in tqdm(tokenized, desc='vectorizing dataset') + ] self.vectors = np.stack(self.vectors, axis=0) self.vectors = torch.tensor(self.vectors, dtype=torch.float32) From c9588769036fe2166f25a2771b0a09238068809d Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 11:53:52 +0000 Subject: [PATCH 32/81] decresing ram usage --- run/run_fixem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run/run_fixem.py b/run/run_fixem.py index 83a223ae..b2ae731b 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -45,7 +45,7 @@ # ===Oracle or Real=== if_real_data = [int(True), int(True), int(True), int(True), int(True), int(True), int(True), int(False), int(False)] dataset = ['amazon_app_book', 'mr20', 'mr20', 'mr15', 'mr15', 'image_coco', 'emnlp_news', 'oracle', 'oracle'] -w2v_embedding_size = 512 #hyperparam +w2v_embedding_size = 128 #low on ram #hyperparam w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 From d5bad1c30c1b2a71c6b83bd48bffcada00ab524b Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 14:48:14 +0000 Subject: [PATCH 33/81] ready for tests --- instructor/oracle_data/fixem_instructor.py | 4 ---- instructor/real_data/fixem_instructor.py | 7 +++---- run/run_fixem.py | 6 +++--- utils/data_loader.py | 2 +- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index 29d230af..f6c7bfca 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -34,10 +34,6 @@ # check target real/fake to be right (Uniform or const) # random data portion generator - data supplier sample from randomint -# half of page idea explained -# current status -# plans for future, implementation tweaks - class FixemGANInstructor(RealDataFixemGANInstructor, BasicInstructor): def __init__(self, opt): self.oracle = Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len,cfg.padding_idx, gpu=cfg.CUDA) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index f23a5686..5abe120b 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -124,16 +124,15 @@ def _run(self): generator_acc = 
self.generator_train_one_batch() - samples = self.gen.sample(20, 20) - for sample in samples: - self.log.info(sample) + self.log.info('\n'.join(self.gen.sample(20, 20))) # if (i + 1) % 10 == 0: if cfg.run_model == 'fixemgan': scores = self.cal_metrics(fmt_str=True) if cfg.run_model == 'cat_fixemgan': scores = '\n\n'.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) - self.log.info('epoch: {i} \n {scores}') + self.log.info(f'epoch: {i}') + self.log.info(f'{scores}') def one_more_batch_for_generator( diff --git a/run/run_fixem.py b/run/run_fixem.py index b2ae731b..bc20c2ed 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -44,8 +44,8 @@ # ===Oracle or Real=== if_real_data = [int(True), int(True), int(True), int(True), int(True), int(True), int(True), int(False), int(False)] -dataset = ['amazon_app_book', 'mr20', 'mr20', 'mr15', 'mr15', 'image_coco', 'emnlp_news', 'oracle', 'oracle'] -w2v_embedding_size = 128 #low on ram #hyperparam +dataset = ['image_coco', 'mr20', 'mr20', 'mr15', 'mr15', 'amazon_app_book', 'emnlp_news', 'oracle', 'oracle'] +w2v_embedding_size = 512 #low on ram #hyperparam w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 @@ -62,7 +62,7 @@ gen_init = 'truncated_normal' dis_init = 'uniform' batch_size = 64 -target_len = [40, 20, 20, 16, 16, 16, 48, 20, 20] +target_len = [16, 20, 20, 16, 16, 40, 48, 20, 20] # architechture requires to be divisible by 4 # ===Generator=== generator_complexity = 768 #hyperparam diff --git a/utils/data_loader.py b/utils/data_loader.py index 9d31948e..73b987ae 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -11,7 +11,7 @@ import torch from torch.utils.data import Dataset, DataLoader -from tqdm import trange +from tqdm import tqdm, trange import config as cfg from utils.text_process import * From 73513c498ba51582ce7b2fe3b719e914fcb6f14d Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Mon, 6 Feb 2023 20:53:13 +0000 Subject: [PATCH 34/81] type and progress bar for evogan --- instructor/real_data/evogan_instructor.py | 11 ++++++++--- metrics/ioc.py | 7 +++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/instructor/real_data/evogan_instructor.py b/instructor/real_data/evogan_instructor.py index 5fcee9e9..6e410d47 100644 --- a/instructor/real_data/evogan_instructor.py +++ b/instructor/real_data/evogan_instructor.py @@ -12,7 +12,7 @@ import torch import torch.nn.functional as F import torch.optim as optim -from tqdm import tqdm +from tqdm import tqdm, trange import config as cfg from instructor.real_data.instructor import BasicInstructor @@ -93,7 +93,7 @@ def _run(self): # # ===ADVERSARIAL TRAINING=== self.log.info('Starting Adversarial Training...') - progress = tqdm(range(cfg.ADV_train_epoch)) + progress = trange(cfg.ADV_train_epoch) for adv_epoch in progress: if cfg.temperature == 1: score, fit_score, select_mu = self.evolve_generator(cfg.ADV_g_step) @@ -126,7 +126,7 @@ def pretrain_generator(self, epochs): """ Max Likelihood Pre-training for the generator """ - for epoch in range(epochs): + for epoch in trange(epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== @@ -137,6 +137,11 @@ def pretrain_generator(self, epochs): self.log.info( '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + eval_samples = self.gen.sample(20, 20) + gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) + for sample in gen_tokens: + self.log.info(' '.join(sample)) + if cfg.if_save and not cfg.if_test: 
self._save('MLE', epoch) else: diff --git a/metrics/ioc.py b/metrics/ioc.py index 4721eb11..5ad98aff 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -10,15 +10,14 @@ class IOC(Metrics): def __init__(self, name=None, test_text=None, real_text=None, if_use=True): - super(IOC, self).__init__('Index of Coincedense') + super(IOC, self).__init__('Index of Coincidence') self.if_use = if_use self.test_text = test_text self.real_text_ioc = self.get_ioc(real_text) if real_text else None - print(f'Dataset Index of coincedense: {self.real_text_ioc}') + print(f'Dataset Index of coincidence: {self.real_text_ioc}') self.reference = None self.is_first = True - self.portion = 0.01#portion # how many portions to use in the evaluation, default to use the whole test dataset def get_score(self): """Get IOC score.""" @@ -31,7 +30,7 @@ def reset(self, test_text=None, real_text=None): self.real_text_ioc = self.get_ioc(real_text) if real_text else self.real_text_ioc def get_ioc(self, list_tokens): - """Index Of Coincedense: probability of 2 random tokens in text to equal.""" + """Index Of coincidence: probability of 2 random tokens in text to equal.""" tokens = list(chain(*list_tokens)) counts = Counter(tokens) total = sum(ni * (ni - 1) for ni in counts.values()) From 5a169793c936c8bbd91efcbd559af62597cad145 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 7 Feb 2023 18:10:37 +0000 Subject: [PATCH 35/81] vectorizing texts batch vise --- utils/data_loader.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/utils/data_loader.py b/utils/data_loader.py index 73b987ae..bbe15180 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -141,17 +141,7 @@ def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): self.labels = torch.tensor(labels, dtype=int) - self.vectors = [ - vectorize_sentence( - tokens, - w2v, - target_len=cfg.target_len, - padding_token = cfg.padding_token, - ) - for tokens in tqdm(tokenized, desc='vectorizing dataset') - ] - self.vectors = np.stack(self.vectors, axis=0) - self.vectors = torch.tensor(self.vectors, dtype=torch.float32) + self.tokenized = tokenized self.batches_per_epoch = batches_per_epoch self.batch_size = batch_size @@ -159,6 +149,19 @@ def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): self.texts = set(" ".join(tokens[-cfg.target_len:]) for tokens in tokenized) print('dataset random texts examples', [txt for txt in self.texts][:3]) + def vectorize_batch(self, tokenized): + vectors = [ + vectorize_sentence( + tokens, + w2v, + target_len=cfg.target_len, + padding_token = cfg.padding_token, + ) + for tokens in tokenized + ] + vectors = np.stack(vectors, axis=0) + vectors = torch.tensor(vectors, dtype=torch.float32) + return vectors def __iter__(self): permutation = torch.randperm(len(self)) @@ -170,11 +173,16 @@ def __iter__(self): index += self.batch_size if index > len(self): # concatenating beginning of self.vectors - yield (torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])), - torch.cat((self.vectors[index - self.batch_size: index], self.vectors[:index-len(self)]))) + yield ( + torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])), + torch.cat(( + self.vectorize_batch(self.vectors[index - self.batch_size: index]), + self.vectorize_batch(self.vectors[:index-len(self)]) + )) + ) index = index % len(self) else: - yield self.labels[index - self.batch_size: index], self.vectors[index - self.batch_size: 
index] + yield self.labels[index - self.batch_size: index], self.vectorize_batch(self.vectors[index - self.batch_size: index]) def __len__(self): From 25e88cef2a6b52b32272b7ac742aba27c6458cc1 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 8 Feb 2023 08:26:40 +0000 Subject: [PATCH 36/81] typos and intel to metrics --- metrics/bleu.py | 36 ++++++++++++++++++------------------ metrics/gpt_nll.py | 3 ++- metrics/nll.py | 13 ++++++------- metrics/ppl.py | 2 +- utils/data_loader.py | 29 +++++++++++++++++++---------- 5 files changed, 46 insertions(+), 37 deletions(-) diff --git a/metrics/bleu.py b/metrics/bleu.py index 6c9ff2eb..019e841e 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -8,10 +8,12 @@ # Copyrights (C) 2018. All Rights Reserved. from multiprocessing import Pool -import nltk import os import random + +import nltk from nltk.translate.bleu_score import SmoothingFunction +from tqdm import tqdm from metrics.basic import Metrics @@ -62,22 +64,20 @@ def get_reference(self): def get_bleu(self, given_gram=None): if given_gram is not None: # for single gram - bleu = list() - reference = self.get_reference() - weight = tuple((1. / given_gram for _ in range(given_gram))) - for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): - bleu.append(self.cal_bleu(reference, hypothesis, weight)) - return round(sum(bleu) / len(bleu), 3) - else: # for multiple gram - all_bleu = [] - for ngram in self.gram: - bleu = list() - reference = self.get_reference() - weight = tuple((1. / ngram for _ in range(ngram))) - for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): - bleu.append(self.cal_bleu(reference, hypothesis, weight)) - all_bleu.append(round(sum(bleu) / len(bleu), 3)) - return all_bleu + return self.get_blue_for_single_gram(given_gram) + # for multiple gram + all_bleu = [] + for ngram in self.gram: + all_bleu.append(self.get_blue_for_single_gram(ngram)) + return all_bleu + + def get_blue_for_single_gram(self, ngram): + bleu = list() + reference = self.get_reference() + weight = tuple((1. / ngram for _ in range(ngram))) + for idx, hypothesis in enumerate(tqdm(self.test_text[:self.sample_size], desc=self.name)): + bleu.append(self.cal_bleu(reference, hypothesis, weight)) + return round(sum(bleu) / len(bleu), 3) @staticmethod def cal_bleu(reference, hypothesis, weight): @@ -98,7 +98,7 @@ def get_bleu_parallel(self, ngram, reference): weight = tuple((1. 
/ ngram for _ in range(ngram))) pool = Pool(os.cpu_count()) result = list() - for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): + for idx, hypothesis in enumerate(tqdm(self.test_text[:self.sample_size], desc=self.name)): result.append(pool.apply_async(self.cal_bleu, args=(reference, hypothesis, weight))) score = 0.0 cnt = 0 diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index f1dbe0a1..79973308 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F from transformers import GPT2LMHeadModel, GPT2Tokenizer +from tqdm import tqdm from metrics.basic import Metrics @@ -41,7 +42,7 @@ def get_NLL(self, messages, baseline=0): messages = [' '.join(msg) for msg in messages] all_logits = [] - for message in messages: + for message in tqdm(messages, desc=self.name): message = self.tokenizer.eos_token + message + self.tokenizer.eos_token inputs = self.tokenizer(message, return_tensors="pt") logits = self.model(**inputs)[0][0] diff --git a/metrics/nll.py b/metrics/nll.py index a4d09983..03cc09e0 100644 --- a/metrics/nll.py +++ b/metrics/nll.py @@ -4,7 +4,7 @@ # @FileName : nll.py # @Time : Created at 2019-05-31 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -34,11 +34,10 @@ def get_score(self): if self.leak_dis is not None: # For LeakGAN return self.cal_nll_with_leak_dis(self.model, self.data_loader, self.leak_dis, self.gpu) - elif self.label_i is not None: # For category text generation + if self.label_i is not None: # For category text generation return self.cal_nll_with_label(self.model, self.data_loader, self.label_i, self.criterion, self.gpu) - else: - return self.cal_nll(self.model, self.data_loader, self.criterion, self.gpu) + return self.cal_nll(self.model, self.data_loader, self.criterion, self.gpu) def reset(self, model=None, data_loader=None, label_i=None, leak_dis=None): self.model = model @@ -51,7 +50,7 @@ def cal_nll(model, data_loader, criterion, gpu=cfg.CUDA): """NLL score for general text generation model.""" total_loss = 0 with torch.no_grad(): - for i, data in enumerate(data_loader): + for i, data in enumerate(tqdm(data_loader, desc=self.name)): inp, target = data['input'], data['target'] if gpu: inp, target = inp.cuda(), target.cuda() @@ -68,7 +67,7 @@ def cal_nll_with_label(model, data_loader, label_i, criterion, gpu=cfg.CUDA): assert type(label_i) == int, 'missing label' total_loss = 0 with torch.no_grad(): - for i, data in enumerate(data_loader): + for i, data in enumerate(tqdm(data_loader, desc=self.name)): inp, target = data['input'], data['target'] label = torch.LongTensor([label_i] * data_loader.batch_size) if gpu: @@ -88,7 +87,7 @@ def cal_nll_with_leak_dis(model, data_loader, leak_dis, gpu=cfg.CUDA): """NLL score for LeakGAN.""" total_loss = 0 with torch.no_grad(): - for i, data in enumerate(data_loader): + for i, data in enumerate(tqdm(data_loader, desc=self.name)): inp, target = data['input'], data['target'] if gpu: inp, target = inp.cuda(), target.cuda() diff --git a/metrics/ppl.py b/metrics/ppl.py index a1049a7d..764d0a7e 100644 --- a/metrics/ppl.py +++ b/metrics/ppl.py @@ -4,7 +4,7 @@ # @FileName : ppl.py # @Time : Created at 2019/12/5 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import math import string diff --git a/utils/data_loader.py b/utils/data_loader.py index bbe15180..36c2764e 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -9,12 +9,19 @@ import random +import numpy as np import torch from torch.utils.data import Dataset, DataLoader from tqdm import tqdm, trange import config as cfg -from utils.text_process import * +from utils.text_process import ( + tokens_to_tensor, + get_tokenlized, + load_dict, + load_test_dict, + vectorize_sentence, +) class GANDataset(Dataset): @@ -141,19 +148,21 @@ def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): self.labels = torch.tensor(labels, dtype=int) - self.tokenized = tokenized + self.tokenized = np.array(tokenized) self.batches_per_epoch = batches_per_epoch self.batch_size = batch_size + self.w2v = w2v + self.texts = set(" ".join(tokens[-cfg.target_len:]) for tokens in tokenized) - print('dataset random texts examples', [txt for txt in self.texts][:3]) + print('dataset random texts examples\n', '\n'.join([txt for txt in self.texts][:5])) def vectorize_batch(self, tokenized): vectors = [ vectorize_sentence( tokens, - w2v, + self.w2v, target_len=cfg.target_len, padding_token = cfg.padding_token, ) @@ -165,10 +174,10 @@ def vectorize_batch(self, tokenized): def __iter__(self): permutation = torch.randperm(len(self)) - self.vectors = self.vectors[permutation] + self.tokenized = self.tokenized[permutation] self.labels = self.labels[permutation] - for _ in range(self.batches_per_epoch): + for _ in trange(self.batches_per_epoch, leave=False, desc='epoch train'): index = 0 index += self.batch_size if index > len(self): @@ -176,17 +185,17 @@ def __iter__(self): yield ( torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])), torch.cat(( - self.vectorize_batch(self.vectors[index - self.batch_size: index]), - self.vectorize_batch(self.vectors[:index-len(self)]) + self.vectorize_batch(self.tokenized[index - self.batch_size: index]), + self.vectorize_batch(self.tokenized[:index-len(self)]) )) ) index = index % len(self) else: - yield self.labels[index - self.batch_size: index], self.vectorize_batch(self.vectors[index - self.batch_size: index]) + yield self.labels[index - self.batch_size: index], self.vectorize_batch(self.tokenized[index - self.batch_size: index]) def __len__(self): - return len(self.vectors) + return len(self.tokenized) def is_this_message_in_dataset(self, text): return text in self.texts From 64747b52214d8e280baf98b8636de6b2c81826d8 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 8 Feb 2023 09:31:39 +0000 Subject: [PATCH 37/81] decreased sample num to form 1000 to 100 --- config.py | 2 +- models/FixemGAN_G.py | 2 +- run/run_catgan.py | 2 -- run/run_cot.py | 4 +--- run/run_dgsan.py | 4 +--- run/run_dpgan.py | 2 -- run/run_jsdgan.py | 4 +--- run/run_leakgan.py | 4 +--- run/run_maligan.py | 4 +--- run/run_sentigan.py | 2 -- run/run_seqgan.py | 4 +--- utils/data_loader.py | 2 +- 12 files changed, 9 insertions(+), 27 deletions(-) diff --git a/config.py b/config.py index 62587d1b..9ad74804 100644 --- a/config.py +++ b/config.py @@ -74,7 +74,7 @@ temperature = 1 # ===Basic Train=== -samples_num = 1000 # 10000, mr15: 2000, +samples_num = 100 # 10000, mr15: 2000, MLE_train_epoch = 150 # SeqGAN-80, LeakGAN-8, RelGAN-150 PRE_clas_epoch = 10 inter_epoch = 15 # LeakGAN-10 diff --git a/models/FixemGAN_G.py b/models/FixemGAN_G.py index 959eefcc..ffab1f31 100644 --- a/models/FixemGAN_G.py +++ b/models/FixemGAN_G.py @@ -142,7 +142,7 @@ def 
sample(self, num_samples, batch_size, label_i = 'random', start_letter=cfg.s fakes = self.forward(*noise) fakes = fakes.detach().cpu().numpy() assert len(fakes.shape) == 3 - return [self.recover_sentence(fake) for fake in fakes] + return [self.recover_sentence(fake) for fake in tqdm(fakes, desc='recovering messages')] def recover_sentence(self, fake): fake = fake.T diff --git a/run/run_catgan.py b/run/run_catgan.py index edda3f01..69d4f08f 100644 --- a/run/run_catgan.py +++ b/run/run_catgan.py @@ -67,7 +67,6 @@ model_type = 'vanilla' gen_init = 'truncated_normal' dis_init = 'uniform' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -136,7 +135,6 @@ '--model_type', model_type, '--gen_init', gen_init, '--dis_init', dis_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/run/run_cot.py b/run/run_cot.py index 92f4f03c..491f984f 100644 --- a/run/run_cot.py +++ b/run/run_cot.py @@ -4,7 +4,7 @@ # @FileName : run_cot.py # @Time : Created at 2020/4/21 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -52,7 +52,6 @@ model_type = 'vanilla' gen_init = 'normal' dis_init = 'normal' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 1e-2 @@ -98,7 +97,6 @@ '--model_type', model_type, '--gen_init', gen_init, '--dis_init', dis_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/run/run_dgsan.py b/run/run_dgsan.py index b7085283..f096d87d 100644 --- a/run/run_dgsan.py +++ b/run/run_dgsan.py @@ -4,7 +4,7 @@ # @FileName : run_dgsan.py # @Time : Created at 2020/4/21 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -51,7 +51,6 @@ data_shuffle = int(False) model_type = 'vanilla' gen_init = 'truncated_normal' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 1e-2 @@ -92,7 +91,6 @@ '--shuffle', data_shuffle, '--model_type', model_type, '--gen_init', gen_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/run/run_dpgan.py b/run/run_dpgan.py index 60529c3e..fe40c1e5 100644 --- a/run/run_dpgan.py +++ b/run/run_dpgan.py @@ -52,7 +52,6 @@ model_type = 'vanilla' gen_init = 'normal' dis_init = 'uniform' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -105,7 +104,6 @@ '--model_type', model_type, '--gen_init', gen_init, '--dis_init', dis_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/run/run_jsdgan.py b/run/run_jsdgan.py index 5a6e2111..9b1ee7d5 100644 --- a/run/run_jsdgan.py +++ b/run/run_jsdgan.py @@ -4,7 +4,7 @@ # @FileName : run_jsdgan.py # @Time : Created at 2019/11/29 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import sys @@ -50,7 +50,6 @@ data_shuffle = int(False) model_type = 'vanilla' gen_init = 'normal' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -91,7 +90,6 @@ '--shuffle', data_shuffle, '--model_type', model_type, '--gen_init', gen_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/run/run_leakgan.py b/run/run_leakgan.py index bb45d796..327719c9 100644 --- a/run/run_leakgan.py +++ b/run/run_leakgan.py @@ -4,7 +4,7 @@ # @FileName : run_leakgan.py # @Time : Created at 2019-05-27 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -53,7 +53,6 @@ model_type = 'vanilla' gen_init = 'normal' dis_init = 'uniform' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 0.0015 @@ -109,7 +108,6 @@ '--model_type', model_type, '--gen_init', gen_init, '--dis_init', dis_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/run/run_maligan.py b/run/run_maligan.py index 24424213..22b021e9 100644 --- a/run/run_maligan.py +++ b/run/run_maligan.py @@ -4,7 +4,7 @@ # @FileName : run_maligan.py # @Time : Created at 2019/11/29 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -52,7 +52,6 @@ model_type = 'vanilla' gen_init = 'normal' dis_init = 'uniform' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -105,7 +104,6 @@ '--model_type', model_type, '--gen_init', gen_init, '--dis_init', dis_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/run/run_sentigan.py b/run/run_sentigan.py index fb5c0e38..1be52a6b 100644 --- a/run/run_sentigan.py +++ b/run/run_sentigan.py @@ -55,7 +55,6 @@ model_type = 'vanilla' gen_init = 'normal' dis_init = 'uniform' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -111,7 +110,6 @@ '--model_type', model_type, '--gen_init', gen_init, '--dis_init', dis_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/run/run_seqgan.py b/run/run_seqgan.py index 7d98b9cd..4a6f00da 100644 --- a/run/run_seqgan.py +++ b/run/run_seqgan.py @@ -4,7 +4,7 @@ # @FileName : run_seqgan.py # @Time : Created at 2019-05-27 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import sys @@ -52,7 +52,6 @@ model_type = 'vanilla' gen_init = 'normal' dis_init = 'uniform' -samples_num = 10000 batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -105,7 +104,6 @@ '--model_type', model_type, '--gen_init', gen_init, '--dis_init', dis_init, - '--samples_num', samples_num, '--batch_size', batch_size, '--max_seq_len', max_seq_len, '--gen_lr', gen_lr, diff --git a/utils/data_loader.py b/utils/data_loader.py index 36c2764e..ee44bc93 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -177,7 +177,7 @@ def __iter__(self): self.tokenized = self.tokenized[permutation] self.labels = self.labels[permutation] - for _ in trange(self.batches_per_epoch, leave=False, desc='epoch train'): + for _ in range(self.batches_per_epoch): index = 0 index += self.batch_size if index > len(self): From 38b7c0ffa937481936bb80178463bad2a97bff10 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 8 Feb 2023 10:06:14 +0000 Subject: [PATCH 38/81] separating generators and discriminators --- models/{ => discriminators}/CatGAN_D.py | 0 models/{ => discriminators}/CoT_D.py | 0 models/{ => discriminators}/DPGAN_D.py | 0 models/{ => discriminators}/EvoGAN_D.py | 0 models/{ => discriminators}/FixemGAN_D.py | 0 models/{ => discriminators}/LeakGAN_D.py | 0 models/{ => discriminators}/MaliGAN_D.py | 0 models/{ => discriminators}/RelGAN_D.py | 0 models/{ => discriminators}/SentiGAN_D.py | 0 models/{ => discriminators}/SeqGAN_D.py | 0 models/{ => discriminators}/discriminator.py | 0 models/{ => generators}/CatGAN_G.py | 0 models/{ => generators}/CoT_G.py | 0 models/{ => generators}/DGSAN_G.py | 0 models/{ => generators}/DPGAN_G.py | 0 models/{ => generators}/EvoGAN_G.py | 0 models/{ => generators}/FixemGAN_G.py | 0 models/{ => generators}/JSDGAN_G.py | 0 models/{ => generators}/LeakGAN_G.py | 0 models/{ => generators}/MaliGAN_G.py | 0 models/{ => generators}/Oracle.py | 0 models/{ => generators}/RelGAN_G.py | 0 models/{ => generators}/SentiGAN_G.py | 0 models/{ => generators}/SeqGAN_G.py | 0 models/{ => generators}/generator.py | 0 models/{ => generators}/relational_rnn_general.py | 0 26 files changed, 0 insertions(+), 0 deletions(-) rename models/{ => discriminators}/CatGAN_D.py (100%) rename models/{ => discriminators}/CoT_D.py (100%) rename models/{ => discriminators}/DPGAN_D.py (100%) rename models/{ => discriminators}/EvoGAN_D.py (100%) rename models/{ => discriminators}/FixemGAN_D.py (100%) rename models/{ => discriminators}/LeakGAN_D.py (100%) rename models/{ => discriminators}/MaliGAN_D.py (100%) rename models/{ => discriminators}/RelGAN_D.py (100%) rename models/{ => discriminators}/SentiGAN_D.py (100%) rename models/{ => discriminators}/SeqGAN_D.py (100%) rename models/{ => discriminators}/discriminator.py (100%) rename models/{ => generators}/CatGAN_G.py (100%) rename models/{ => generators}/CoT_G.py (100%) rename models/{ => generators}/DGSAN_G.py (100%) rename models/{ => generators}/DPGAN_G.py (100%) rename models/{ => generators}/EvoGAN_G.py (100%) rename models/{ => generators}/FixemGAN_G.py (100%) rename models/{ => generators}/JSDGAN_G.py (100%) rename models/{ => generators}/LeakGAN_G.py (100%) rename models/{ => generators}/MaliGAN_G.py (100%) rename models/{ => generators}/Oracle.py (100%) rename models/{ => generators}/RelGAN_G.py (100%) rename models/{ => generators}/SentiGAN_G.py (100%) rename models/{ => generators}/SeqGAN_G.py (100%) rename models/{ => generators}/generator.py (100%) rename models/{ => generators}/relational_rnn_general.py (100%) diff --git 
a/models/CatGAN_D.py b/models/discriminators/CatGAN_D.py similarity index 100% rename from models/CatGAN_D.py rename to models/discriminators/CatGAN_D.py diff --git a/models/CoT_D.py b/models/discriminators/CoT_D.py similarity index 100% rename from models/CoT_D.py rename to models/discriminators/CoT_D.py diff --git a/models/DPGAN_D.py b/models/discriminators/DPGAN_D.py similarity index 100% rename from models/DPGAN_D.py rename to models/discriminators/DPGAN_D.py diff --git a/models/EvoGAN_D.py b/models/discriminators/EvoGAN_D.py similarity index 100% rename from models/EvoGAN_D.py rename to models/discriminators/EvoGAN_D.py diff --git a/models/FixemGAN_D.py b/models/discriminators/FixemGAN_D.py similarity index 100% rename from models/FixemGAN_D.py rename to models/discriminators/FixemGAN_D.py diff --git a/models/LeakGAN_D.py b/models/discriminators/LeakGAN_D.py similarity index 100% rename from models/LeakGAN_D.py rename to models/discriminators/LeakGAN_D.py diff --git a/models/MaliGAN_D.py b/models/discriminators/MaliGAN_D.py similarity index 100% rename from models/MaliGAN_D.py rename to models/discriminators/MaliGAN_D.py diff --git a/models/RelGAN_D.py b/models/discriminators/RelGAN_D.py similarity index 100% rename from models/RelGAN_D.py rename to models/discriminators/RelGAN_D.py diff --git a/models/SentiGAN_D.py b/models/discriminators/SentiGAN_D.py similarity index 100% rename from models/SentiGAN_D.py rename to models/discriminators/SentiGAN_D.py diff --git a/models/SeqGAN_D.py b/models/discriminators/SeqGAN_D.py similarity index 100% rename from models/SeqGAN_D.py rename to models/discriminators/SeqGAN_D.py diff --git a/models/discriminator.py b/models/discriminators/discriminator.py similarity index 100% rename from models/discriminator.py rename to models/discriminators/discriminator.py diff --git a/models/CatGAN_G.py b/models/generators/CatGAN_G.py similarity index 100% rename from models/CatGAN_G.py rename to models/generators/CatGAN_G.py diff --git a/models/CoT_G.py b/models/generators/CoT_G.py similarity index 100% rename from models/CoT_G.py rename to models/generators/CoT_G.py diff --git a/models/DGSAN_G.py b/models/generators/DGSAN_G.py similarity index 100% rename from models/DGSAN_G.py rename to models/generators/DGSAN_G.py diff --git a/models/DPGAN_G.py b/models/generators/DPGAN_G.py similarity index 100% rename from models/DPGAN_G.py rename to models/generators/DPGAN_G.py diff --git a/models/EvoGAN_G.py b/models/generators/EvoGAN_G.py similarity index 100% rename from models/EvoGAN_G.py rename to models/generators/EvoGAN_G.py diff --git a/models/FixemGAN_G.py b/models/generators/FixemGAN_G.py similarity index 100% rename from models/FixemGAN_G.py rename to models/generators/FixemGAN_G.py diff --git a/models/JSDGAN_G.py b/models/generators/JSDGAN_G.py similarity index 100% rename from models/JSDGAN_G.py rename to models/generators/JSDGAN_G.py diff --git a/models/LeakGAN_G.py b/models/generators/LeakGAN_G.py similarity index 100% rename from models/LeakGAN_G.py rename to models/generators/LeakGAN_G.py diff --git a/models/MaliGAN_G.py b/models/generators/MaliGAN_G.py similarity index 100% rename from models/MaliGAN_G.py rename to models/generators/MaliGAN_G.py diff --git a/models/Oracle.py b/models/generators/Oracle.py similarity index 100% rename from models/Oracle.py rename to models/generators/Oracle.py diff --git a/models/RelGAN_G.py b/models/generators/RelGAN_G.py similarity index 100% rename from models/RelGAN_G.py rename to models/generators/RelGAN_G.py diff 
--git a/models/SentiGAN_G.py b/models/generators/SentiGAN_G.py similarity index 100% rename from models/SentiGAN_G.py rename to models/generators/SentiGAN_G.py diff --git a/models/SeqGAN_G.py b/models/generators/SeqGAN_G.py similarity index 100% rename from models/SeqGAN_G.py rename to models/generators/SeqGAN_G.py diff --git a/models/generator.py b/models/generators/generator.py similarity index 100% rename from models/generator.py rename to models/generators/generator.py diff --git a/models/relational_rnn_general.py b/models/generators/relational_rnn_general.py similarity index 100% rename from models/relational_rnn_general.py rename to models/generators/relational_rnn_general.py From 49cccfe1104cf8877139f4b10889f4a934cf2f70 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 8 Feb 2023 10:15:02 +0000 Subject: [PATCH 39/81] moved code references for generator and discriminator --- instructor/oracle_data/catgan_instructor.py | 4 ++-- instructor/oracle_data/cot_instructor.py | 6 +++--- instructor/oracle_data/dgsan_instructor.py | 4 ++-- instructor/oracle_data/dpgan_instructor.py | 4 ++-- instructor/oracle_data/evogan_instructor.py | 4 ++-- instructor/oracle_data/fixem_instructor.py | 4 ++-- instructor/oracle_data/instructor.py | 2 +- instructor/oracle_data/jsdgan_instructor.py | 4 ++-- instructor/oracle_data/leakgan_instructor.py | 6 +++--- instructor/oracle_data/maligan_instructor.py | 6 +++--- instructor/oracle_data/relgan_instructor.py | 6 +++--- instructor/oracle_data/sentigan_instructor.py | 6 +++--- instructor/oracle_data/seqgan_instructor.py | 6 +++--- instructor/real_data/catgan_instructor.py | 4 ++-- instructor/real_data/cot_instructor.py | 6 +++--- instructor/real_data/dgsan_instructor.py | 4 ++-- instructor/real_data/dpgan_instructor.py | 6 +++--- instructor/real_data/evogan_instructor.py | 4 ++-- instructor/real_data/fixem_instructor.py | 4 ++-- instructor/real_data/jsdgan_instructor.py | 4 ++-- instructor/real_data/leakgan_instructor.py | 6 +++--- instructor/real_data/maligan_instructor.py | 6 +++--- instructor/real_data/relgan_instructor.py | 6 +++--- instructor/real_data/sentigan_instructor.py | 4 ++-- instructor/real_data/seqgan_instructor.py | 6 +++--- models/discriminators/CatGAN_D.py | 4 ++-- models/discriminators/CoT_D.py | 4 ++-- models/discriminators/DPGAN_D.py | 2 +- models/discriminators/EvoGAN_D.py | 4 ++-- models/discriminators/FixemGAN_D.py | 2 +- models/discriminators/LeakGAN_D.py | 4 ++-- models/discriminators/MaliGAN_D.py | 4 ++-- models/discriminators/RelGAN_D.py | 4 ++-- models/discriminators/SentiGAN_D.py | 4 ++-- models/discriminators/SeqGAN_D.py | 4 ++-- models/generators/CatGAN_G.py | 4 ++-- models/generators/CoT_G.py | 4 ++-- models/generators/DGSAN_G.py | 4 ++-- models/generators/DPGAN_G.py | 2 +- models/generators/EvoGAN_G.py | 4 ++-- models/generators/FixemGAN_G.py | 2 +- models/generators/JSDGAN_G.py | 4 ++-- models/generators/MaliGAN_G.py | 4 ++-- models/generators/Oracle.py | 2 +- models/generators/RelGAN_G.py | 4 ++-- models/generators/SentiGAN_G.py | 4 ++-- models/generators/SeqGAN_G.py | 4 ++-- models/{generators => }/relational_rnn_general.py | 0 48 files changed, 100 insertions(+), 100 deletions(-) rename models/{generators => }/relational_rnn_general.py (100%) diff --git a/instructor/oracle_data/catgan_instructor.py b/instructor/oracle_data/catgan_instructor.py index cbdd75e2..efc2ae60 100644 --- a/instructor/oracle_data/catgan_instructor.py +++ b/instructor/oracle_data/catgan_instructor.py @@ -19,8 +19,8 @@ import config as cfg from 
instructor.oracle_data.instructor import BasicInstructor from metrics.nll import NLL -from models.CatGAN_D import CatGAN_D -from models.CatGAN_G import CatGAN_G +from models.discriminators.CatGAN_D import CatGAN_D +from models.generators.CatGAN_G import CatGAN_G from models.Oracle import Oracle from utils.cat_data_loader import CatGenDataIter from utils.data_loader import GenDataIter diff --git a/instructor/oracle_data/cot_instructor.py b/instructor/oracle_data/cot_instructor.py index b435ce22..0e860773 100644 --- a/instructor/oracle_data/cot_instructor.py +++ b/instructor/oracle_data/cot_instructor.py @@ -4,7 +4,7 @@ # @FileName : cot_instructor.py # @Time : Created at 2020/4/20 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import numpy as np @@ -14,8 +14,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.CoT_D import Cot_D -from models.CoT_G import CoT_G +from models.discriminators.CoT_D import Cot_D +from models.generators.CoT_G import CoT_G from utils.data_loader import GenDataIter diff --git a/instructor/oracle_data/dgsan_instructor.py b/instructor/oracle_data/dgsan_instructor.py index 4a9bbdbe..8a810eb5 100644 --- a/instructor/oracle_data/dgsan_instructor.py +++ b/instructor/oracle_data/dgsan_instructor.py @@ -4,7 +4,7 @@ # @FileName : dgsan_instructor.py # @Time : Created at 2020/4/12 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import copy import numpy as np @@ -16,7 +16,7 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.DGSAN_G import DGSAN_G +from models.generators.DGSAN_G import DGSAN_G from utils.data_loader import GenDataIter from utils.helpers import create_oracle diff --git a/instructor/oracle_data/dpgan_instructor.py b/instructor/oracle_data/dpgan_instructor.py index b0fe2597..3c39903c 100644 --- a/instructor/oracle_data/dpgan_instructor.py +++ b/instructor/oracle_data/dpgan_instructor.py @@ -12,8 +12,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.DPGAN_D import DPGAN_D -from models.DPGAN_G import DPGAN_G +from models.discriminators.DPGAN_D import DPGAN_D +from models.generators.DPGAN_G import DPGAN_G class DPGANInstructor(BasicInstructor): diff --git a/instructor/oracle_data/evogan_instructor.py b/instructor/oracle_data/evogan_instructor.py index 32793a0b..a7d22a73 100644 --- a/instructor/oracle_data/evogan_instructor.py +++ b/instructor/oracle_data/evogan_instructor.py @@ -18,8 +18,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor from metrics.nll import NLL -from models.EvoGAN_D import EvoGAN_D -from models.EvoGAN_G import EvoGAN_G +from models.discriminators.EvoGAN_D import EvoGAN_D +from models.generators.EvoGAN_G import EvoGAN_G from utils.data_loader import GenDataIter from utils.gan_loss import GANLoss from utils.helpers import get_fixed_temperature, get_losses, create_oracle diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index f6c7bfca..239eb5e7 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -21,8 +21,8 @@ from metrics.nll import NLL from utils.create_embeddings import EmbeddingsTrainer, load_embedding from models.Oracle import Oracle -from models.FixemGAN_G import Generator -from models.FixemGAN_D import Discriminator +from 
models.generators.FixemGAN_G import Generator +from models.discriminators.FixemGAN_D import Discriminator # TO DO: diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index 4f48f3e4..1d182db3 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -14,7 +14,7 @@ import config as cfg from metrics.nll import NLL -from models.Oracle import Oracle +from models.generators.Oracle import Oracle from utils.data_loader import GenDataIter from utils.data_utils import create_multi_oracle from utils.helpers import Signal, create_logger, create_oracle, get_fixed_temperature diff --git a/instructor/oracle_data/jsdgan_instructor.py b/instructor/oracle_data/jsdgan_instructor.py index e2264d54..a207eb3a 100644 --- a/instructor/oracle_data/jsdgan_instructor.py +++ b/instructor/oracle_data/jsdgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : JSDGAN_instructor.py # @Time : Created at 2019/11/16 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import os import torch @@ -12,7 +12,7 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.JSDGAN_G import JSDGAN_G +from models.generators.JSDGAN_G import JSDGAN_G from utils.helpers import create_oracle diff --git a/instructor/oracle_data/leakgan_instructor.py b/instructor/oracle_data/leakgan_instructor.py index 45903f84..5e2d7109 100644 --- a/instructor/oracle_data/leakgan_instructor.py +++ b/instructor/oracle_data/leakgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : leakgan_instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.LeakGAN_D import LeakGAN_D -from models.LeakGAN_G import LeakGAN_G +from models.discriminators.LeakGAN_D import LeakGAN_D +from models.generators.LeakGAN_G import LeakGAN_G from utils import rollout from utils.data_loader import GenDataIter, DisDataIter from utils.text_process import write_tensor diff --git a/instructor/oracle_data/maligan_instructor.py b/instructor/oracle_data/maligan_instructor.py index c3bd5b39..8d7478f7 100644 --- a/instructor/oracle_data/maligan_instructor.py +++ b/instructor/oracle_data/maligan_instructor.py @@ -4,7 +4,7 @@ # @FileName : maligan_instructor.py # @Time : Created at 2019/10/17 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. @@ -14,8 +14,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.MaliGAN_D import MaliGAN_D -from models.MaliGAN_G import MaliGAN_G +from models.discriminators.MaliGAN_D import MaliGAN_D +from models.generators.MaliGAN_G import MaliGAN_G from utils.data_loader import GenDataIter, DisDataIter diff --git a/instructor/oracle_data/relgan_instructor.py b/instructor/oracle_data/relgan_instructor.py index 8c7b610b..932c51ee 100644 --- a/instructor/oracle_data/relgan_instructor.py +++ b/instructor/oracle_data/relgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : relgan_instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch import torch.nn.functional as F @@ -13,8 +13,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.RelGAN_D import RelGAN_D -from models.RelGAN_G import RelGAN_G +from models.discriminators.RelGAN_D import RelGAN_D +from models.generators.RelGAN_G import RelGAN_G from utils.helpers import get_fixed_temperature, get_losses diff --git a/instructor/oracle_data/sentigan_instructor.py b/instructor/oracle_data/sentigan_instructor.py index 17dd6631..42be59a3 100644 --- a/instructor/oracle_data/sentigan_instructor.py +++ b/instructor/oracle_data/sentigan_instructor.py @@ -4,7 +4,7 @@ # @FileName : sentigan_instructor.py # @Time : Created at 2019-07-26 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import os @@ -14,8 +14,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor from models.Oracle import Oracle -from models.SentiGAN_D import SentiGAN_D -from models.SentiGAN_G import SentiGAN_G +from models.discriminators.SentiGAN_D import SentiGAN_D +from models.generators.SentiGAN_G import SentiGAN_G from utils import rollout from utils.cat_data_loader import CatClasDataIter from utils.data_loader import GenDataIter diff --git a/instructor/oracle_data/seqgan_instructor.py b/instructor/oracle_data/seqgan_instructor.py index 00046b80..b7f64676 100644 --- a/instructor/oracle_data/seqgan_instructor.py +++ b/instructor/oracle_data/seqgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : seqgan_instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.SeqGAN_D import SeqGAN_D -from models.SeqGAN_G import SeqGAN_G +from models.discriminators.SeqGAN_D import SeqGAN_D +from models.generators.SeqGAN_G import SeqGAN_G from utils import rollout from utils.data_loader import GenDataIter, DisDataIter diff --git a/instructor/real_data/catgan_instructor.py b/instructor/real_data/catgan_instructor.py index 094c0c9c..e6bd3d9a 100644 --- a/instructor/real_data/catgan_instructor.py +++ b/instructor/real_data/catgan_instructor.py @@ -17,8 +17,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor from metrics.nll import NLL -from models.CatGAN_D import CatGAN_D, CatGAN_C -from models.CatGAN_G import CatGAN_G +from models.discriminators.CatGAN_D import CatGAN_D, CatGAN_C +from models.generators.CatGAN_G import CatGAN_G from utils.cat_data_loader import CatGenDataIter from utils.data_loader import GenDataIter from utils.gan_loss import GANLoss diff --git a/instructor/real_data/cot_instructor.py b/instructor/real_data/cot_instructor.py index b7ad3d24..e92fb9eb 100644 --- a/instructor/real_data/cot_instructor.py +++ b/instructor/real_data/cot_instructor.py @@ -4,7 +4,7 @@ # @FileName : cot_instructor.py # @Time : Created at 2020/4/21 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
@@ -15,8 +15,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.CoT_D import Cot_D -from models.CoT_G import CoT_G +from models.discriminators.CoT_D import Cot_D +from models.generators.CoT_G import CoT_G from utils.data_loader import GenDataIter diff --git a/instructor/real_data/dgsan_instructor.py b/instructor/real_data/dgsan_instructor.py index 5b018c87..7018b452 100644 --- a/instructor/real_data/dgsan_instructor.py +++ b/instructor/real_data/dgsan_instructor.py @@ -4,7 +4,7 @@ # @FileName : dgsan_instructor.py # @Time : Created at 2020/4/16 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import copy @@ -16,7 +16,7 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.DGSAN_G import DGSAN_G +from models.generators.DGSAN_G import DGSAN_G from utils.data_loader import GenDataIter diff --git a/instructor/real_data/dpgan_instructor.py b/instructor/real_data/dpgan_instructor.py index 40c95e15..071a382c 100644 --- a/instructor/real_data/dpgan_instructor.py +++ b/instructor/real_data/dpgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : dpgan_instructor.py # @Time : Created at 2019/12/21 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.DPGAN_D import DPGAN_D -from models.DPGAN_G import DPGAN_G +from models.discriminators.DPGAN_D import DPGAN_D +from models.generators.DPGAN_G import DPGAN_G class DPGANInstructor(BasicInstructor): diff --git a/instructor/real_data/evogan_instructor.py b/instructor/real_data/evogan_instructor.py index 6e410d47..f8aa1211 100644 --- a/instructor/real_data/evogan_instructor.py +++ b/instructor/real_data/evogan_instructor.py @@ -17,8 +17,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor from metrics.nll import NLL -from models.EvoGAN_D import EvoGAN_D -from models.EvoGAN_G import EvoGAN_G +from models.discriminators.EvoGAN_D import EvoGAN_D +from models.generators.EvoGAN_G import EvoGAN_G from utils.data_loader import GenDataIter from utils.gan_loss import GANLoss from utils.helpers import get_fixed_temperature, get_losses diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 5abe120b..526ae940 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -18,8 +18,8 @@ from utils.cat_data_loader import CatClasDataIter from utils.nn_helpers import create_noise, number_of_parameters from utils.create_embeddings import EmbeddingsTrainer, load_embedding -from models.FixemGAN_G import Generator -from models.FixemGAN_D import Discriminator +from models.generators.FixemGAN_G import Generator +from models.discriminators.FixemGAN_D import Discriminator class FixemGANInstructor(BasicInstructor): diff --git a/instructor/real_data/jsdgan_instructor.py b/instructor/real_data/jsdgan_instructor.py index 6c3e58b0..be9f782f 100644 --- a/instructor/real_data/jsdgan_instructor.py +++ b/instructor/real_data/jsdgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : JSDGAN_instructor.py # @Time : Created at 2019/11/25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch @@ -12,7 +12,7 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.JSDGAN_G import JSDGAN_G +from models.generators.JSDGAN_G import JSDGAN_G class JSDGANInstructor(BasicInstructor): diff --git a/instructor/real_data/leakgan_instructor.py b/instructor/real_data/leakgan_instructor.py index 922a203f..bfff5993 100644 --- a/instructor/real_data/leakgan_instructor.py +++ b/instructor/real_data/leakgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : leakgan_instructor.py # @Time : Created at 2019-06-05 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.LeakGAN_D import LeakGAN_D -from models.LeakGAN_G import LeakGAN_G +from models.discriminators.LeakGAN_D import LeakGAN_D +from models.generators.LeakGAN_G import LeakGAN_G from utils import rollout from utils.data_loader import GenDataIter, DisDataIter from utils.text_process import tensor_to_tokens, write_tokens diff --git a/instructor/real_data/maligan_instructor.py b/instructor/real_data/maligan_instructor.py index 65258201..a27356f8 100644 --- a/instructor/real_data/maligan_instructor.py +++ b/instructor/real_data/maligan_instructor.py @@ -4,7 +4,7 @@ # @FileName : maligan_instructor.py # @Time : Created at 2019/11/29 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. @@ -14,8 +14,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.MaliGAN_D import MaliGAN_D -from models.MaliGAN_G import MaliGAN_G +from models.discriminators.MaliGAN_D import MaliGAN_D +from models.generators.MaliGAN_G import MaliGAN_G from utils.data_loader import GenDataIter, DisDataIter diff --git a/instructor/real_data/relgan_instructor.py b/instructor/real_data/relgan_instructor.py index 17df9ecc..391b961b 100644 --- a/instructor/real_data/relgan_instructor.py +++ b/instructor/real_data/relgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : relgan_instructor.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch @@ -14,8 +14,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.RelGAN_D import RelGAN_D -from models.RelGAN_G import RelGAN_G +from models.discriminators.RelGAN_D import RelGAN_D +from models.generators.RelGAN_G import RelGAN_G from utils.helpers import get_fixed_temperature, get_losses diff --git a/instructor/real_data/sentigan_instructor.py b/instructor/real_data/sentigan_instructor.py index 3d7ecd50..952df9da 100644 --- a/instructor/real_data/sentigan_instructor.py +++ b/instructor/real_data/sentigan_instructor.py @@ -12,8 +12,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.SentiGAN_D import SentiGAN_D, SentiGAN_C -from models.SentiGAN_G import SentiGAN_G +from models.discriminators.SentiGAN_D import SentiGAN_D, SentiGAN_C +from models.generators.SentiGAN_G import SentiGAN_G from utils import rollout from utils.cat_data_loader import CatClasDataIter from utils.data_loader import GenDataIter diff --git a/instructor/real_data/seqgan_instructor.py b/instructor/real_data/seqgan_instructor.py index 061a2607..dccb4a2f 100644 --- a/instructor/real_data/seqgan_instructor.py +++ b/instructor/real_data/seqgan_instructor.py @@ -4,7 +4,7 @@ # @FileName : seqgan_instructor.py # @Time : Created at 2019-06-05 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,8 +12,8 @@ import config as cfg from instructor.real_data.instructor import BasicInstructor -from models.SeqGAN_D import SeqGAN_D -from models.SeqGAN_G import SeqGAN_G +from models.discriminators.SeqGAN_D import SeqGAN_D +from models.generators.SeqGAN_G import SeqGAN_G from utils import rollout from utils.data_loader import GenDataIter, DisDataIter diff --git a/models/discriminators/CatGAN_D.py b/models/discriminators/CatGAN_D.py index d9f5310b..2e841b71 100644 --- a/models/discriminators/CatGAN_D.py +++ b/models/discriminators/CatGAN_D.py @@ -4,14 +4,14 @@ # @FileName : CatGAN_D.py # @Time : Created at 2019-05-28 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn as nn import torch.nn.functional as F -from models.discriminator import CNNDiscriminator, CNNClassifier +from models.discriminators.discriminator import CNNDiscriminator, CNNClassifier dis_filter_sizes = [2, 3, 4, 5] dis_num_filters = [300, 300, 300, 300] diff --git a/models/discriminators/CoT_D.py b/models/discriminators/CoT_D.py index 325c3adc..c008a39d 100644 --- a/models/discriminators/CoT_D.py +++ b/models/discriminators/CoT_D.py @@ -4,13 +4,13 @@ # @FileName : CoT_Medicator.py # @Time : Created at 2020/4/20 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class Cot_D(LSTMGenerator): diff --git a/models/discriminators/DPGAN_D.py b/models/discriminators/DPGAN_D.py index 4b8827af..ace54904 100644 --- a/models/discriminators/DPGAN_D.py +++ b/models/discriminators/DPGAN_D.py @@ -11,7 +11,7 @@ import torch.nn.functional as F import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator from utils.data_loader import GenDataIter diff --git a/models/discriminators/EvoGAN_D.py b/models/discriminators/EvoGAN_D.py index 6904ca7c..7d13ed95 100644 --- a/models/discriminators/EvoGAN_D.py +++ b/models/discriminators/EvoGAN_D.py @@ -4,14 +4,14 @@ # @FileName : EvoGAN_D.py # @Time : Created at 2019-07-09 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn as nn import torch.nn.functional as F -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator dis_filter_sizes = [2, 3, 4, 5] dis_num_filters = [300, 300, 300, 300] diff --git a/models/discriminators/FixemGAN_D.py b/models/discriminators/FixemGAN_D.py index 6f609f76..6eea6ae8 100644 --- a/models/discriminators/FixemGAN_D.py +++ b/models/discriminators/FixemGAN_D.py @@ -2,7 +2,7 @@ import config as cfg from utils.nn_helpers import get_optimizer, MyConvLayer, MyTransformerEncoderLayer, Flatten, Dummy -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator class Discriminator(nn.Module): def __init__(self, complexity): diff --git a/models/discriminators/LeakGAN_D.py b/models/discriminators/LeakGAN_D.py index afa398db..86c46800 100644 --- a/models/discriminators/LeakGAN_D.py +++ b/models/discriminators/LeakGAN_D.py @@ -4,10 +4,10 @@ # @FileName : LeakGAN_D.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] diff --git a/models/discriminators/MaliGAN_D.py b/models/discriminators/MaliGAN_D.py index a23eb8a0..b132c8d1 100644 --- a/models/discriminators/MaliGAN_D.py +++ b/models/discriminators/MaliGAN_D.py @@ -4,10 +4,10 @@ # @FileName : MaliGAN_D.py # @Time : Created at 2019/10/17 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] diff --git a/models/discriminators/RelGAN_D.py b/models/discriminators/RelGAN_D.py index b62334d6..88435d5a 100644 --- a/models/discriminators/RelGAN_D.py +++ b/models/discriminators/RelGAN_D.py @@ -4,14 +4,14 @@ # @FileName : RelGAN_D.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch import torch.nn as nn import torch.nn.functional as F -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator dis_filter_sizes = [2, 3, 4, 5] dis_num_filters = [300, 300, 300, 300] diff --git a/models/discriminators/SentiGAN_D.py b/models/discriminators/SentiGAN_D.py index 7daccfb7..f6668ec4 100644 --- a/models/discriminators/SentiGAN_D.py +++ b/models/discriminators/SentiGAN_D.py @@ -4,12 +4,12 @@ # @FileName : SentiGAN_D.py # @Time : Created at 2019-07-26 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch.nn as nn -from models.discriminator import CNNDiscriminator, CNNClassifier +from models.discriminators.discriminator import CNNDiscriminator, CNNClassifier dis_filter_sizes = [2, 3, 4, 5] dis_num_filters = [200, 200, 200, 200] diff --git a/models/discriminators/SeqGAN_D.py b/models/discriminators/SeqGAN_D.py index 9e63b823..70de10d1 100644 --- a/models/discriminators/SeqGAN_D.py +++ b/models/discriminators/SeqGAN_D.py @@ -4,10 +4,10 @@ # @FileName : SeqGAN_D.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. -from models.discriminator import CNNDiscriminator +from models.discriminators.discriminator import CNNDiscriminator dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] diff --git a/models/generators/CatGAN_G.py b/models/generators/CatGAN_G.py index 60493e81..c4e2f7a6 100644 --- a/models/generators/CatGAN_G.py +++ b/models/generators/CatGAN_G.py @@ -4,7 +4,7 @@ # @FileName : CatGAN_G.py # @Time : Created at 2019-07-18 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,7 +12,7 @@ import torch.nn.functional as F import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator from models.relational_rnn_general import RelationalMemory diff --git a/models/generators/CoT_G.py b/models/generators/CoT_G.py index 357ac73c..50999e8f 100644 --- a/models/generators/CoT_G.py +++ b/models/generators/CoT_G.py @@ -4,13 +4,13 @@ # @FileName : CoT_G.py # @Time : Created at 2020/4/20 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class CoT_G(LSTMGenerator): diff --git a/models/generators/DGSAN_G.py b/models/generators/DGSAN_G.py index 40646e0b..5deb53e8 100644 --- a/models/generators/DGSAN_G.py +++ b/models/generators/DGSAN_G.py @@ -4,10 +4,10 @@ # @FileName : DGSAN_G.py # @Time : Created at 2020/4/12 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
-from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class DGSAN_G(LSTMGenerator): diff --git a/models/generators/DPGAN_G.py b/models/generators/DPGAN_G.py index 121b62a3..f9d415b7 100644 --- a/models/generators/DPGAN_G.py +++ b/models/generators/DPGAN_G.py @@ -10,7 +10,7 @@ import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class DPGAN_G(LSTMGenerator): diff --git a/models/generators/EvoGAN_G.py b/models/generators/EvoGAN_G.py index b6ad52ea..8bb2a358 100644 --- a/models/generators/EvoGAN_G.py +++ b/models/generators/EvoGAN_G.py @@ -4,7 +4,7 @@ # @FileName : EvoGAN_G.py # @Time : Created at 2019-07-09 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -12,7 +12,7 @@ import torch.nn.functional as F import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator from models.relational_rnn_general import RelationalMemory diff --git a/models/generators/FixemGAN_G.py b/models/generators/FixemGAN_G.py index ffab1f31..5b0277a5 100644 --- a/models/generators/FixemGAN_G.py +++ b/models/generators/FixemGAN_G.py @@ -15,7 +15,7 @@ ) import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator diff --git a/models/generators/JSDGAN_G.py b/models/generators/JSDGAN_G.py index a35cf35b..32c6432c 100644 --- a/models/generators/JSDGAN_G.py +++ b/models/generators/JSDGAN_G.py @@ -4,14 +4,14 @@ # @FileName : JSDGAN_G.py # @Time : Created at 2019/11/17 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class JSDGAN_G(LSTMGenerator): diff --git a/models/generators/MaliGAN_G.py b/models/generators/MaliGAN_G.py index 7baf0be6..35b64591 100644 --- a/models/generators/MaliGAN_G.py +++ b/models/generators/MaliGAN_G.py @@ -4,13 +4,13 @@ # @FileName : MaliGAN_G.py # @Time : Created at 2019/10/17 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class MaliGAN_G(LSTMGenerator): diff --git a/models/generators/Oracle.py b/models/generators/Oracle.py index 89c4c0e5..d637559a 100644 --- a/models/generators/Oracle.py +++ b/models/generators/Oracle.py @@ -7,7 +7,7 @@ # @Description : # Copyrights (C) 2018. All Rights Reserved. -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class Oracle(LSTMGenerator): diff --git a/models/generators/RelGAN_G.py b/models/generators/RelGAN_G.py index c8294269..1044b988 100644 --- a/models/generators/RelGAN_G.py +++ b/models/generators/RelGAN_G.py @@ -4,14 +4,14 @@ # @FileName : RelGAN_G.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import torch import torch.nn as nn import torch.nn.functional as F import config as cfg -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator from models.relational_rnn_general import RelationalMemory diff --git a/models/generators/SentiGAN_G.py b/models/generators/SentiGAN_G.py index f0fa0e45..a7c8f360 100644 --- a/models/generators/SentiGAN_G.py +++ b/models/generators/SentiGAN_G.py @@ -4,14 +4,14 @@ # @FileName : SentiGAN_G.py # @Time : Created at 2019-07-26 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class SentiGAN_G(LSTMGenerator): diff --git a/models/generators/SeqGAN_G.py b/models/generators/SeqGAN_G.py index 86dd7c86..0414d816 100644 --- a/models/generators/SeqGAN_G.py +++ b/models/generators/SeqGAN_G.py @@ -4,13 +4,13 @@ # @FileName : SeqGAN_G.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch import torch.nn.functional as F -from models.generator import LSTMGenerator +from models.generators.generator import LSTMGenerator class SeqGAN_G(LSTMGenerator): diff --git a/models/generators/relational_rnn_general.py b/models/relational_rnn_general.py similarity index 100% rename from models/generators/relational_rnn_general.py rename to models/relational_rnn_general.py From 825dd20a76fe9e08c992ede317559fe158070a5f Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 8 Feb 2023 10:17:38 +0000 Subject: [PATCH 40/81] typo --- config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config.py b/config.py index 9ad74804..1396454b 100644 --- a/config.py +++ b/config.py @@ -195,8 +195,8 @@ pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) -emebedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/real_data/' -pretrain_embedding_path = emebedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size) +embedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/oracle_data/' +pretrain_embedding_path = embedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size) texts_pile = 'dataset/' # do not include testdata signal_file = 'run_signal.txt' @@ -360,8 +360,8 @@ def init_param(opt): samples_num) pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, samples_num) - emebedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/oracle_data/' - pretrain_embedding_path = emebedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size) + embedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/oracle_data/' + pretrain_embedding_path = embedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size) # Assertion assert k_label >= 2, 'Error: k_label = {}, which should be >=2!'.format(k_label) assert eval_b_num >= n_parent * ADV_d_step, 'Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!'.format( From 38861322b5a7e07a8045902afe5ae47910520953 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 8 Feb 2023 10:33:32 +0000 Subject: [PATCH 41/81] deleted intel from metrics and created single small sample hyperparam --- config.py | 1 + instructor/real_data/fixem_instructor.py | 4
++-- instructor/real_data/instructor.py | 4 ++-- metrics/bleu.py | 2 +- metrics/gpt_nll.py | 2 +- metrics/nll.py | 6 +++--- models/generators/FixemGAN_G.py | 2 +- 7 files changed, 11 insertions(+), 10 deletions(-) diff --git a/config.py b/config.py index 1396454b..151f5b9d 100644 --- a/config.py +++ b/config.py @@ -75,6 +75,7 @@ # ===Basic Train=== samples_num = 100 # 10000, mr15: 2000, +small_sample_num = 20 # used for self-BLEU MLE_train_epoch = 150 # SeqGAN-80, LeakGAN-8, RelGAN-150 PRE_clas_epoch = 10 inter_epoch = 15 # LeakGAN-10 diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 526ae940..3f4b412f 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -147,13 +147,13 @@ def one_more_batch_for_generator( def sample_for_metrics(self): gen_tokens = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) gen_tokens = [sample.split() for sample in gen_tokens] - gen_tokens_s = self.gen.sample(200, 200) + gen_tokens_s = self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size) gen_tokens_s = [sample.split() for sample in gen_tokens_s] return GenDataIter(gen_tokens), gen_tokens, gen_tokens_s def sample_for_metrics_with_label(self, label_i): gen_tokens = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) gen_tokens = [sample.split() for sample in gen_tokens] - gen_tokens_s = self.gen.sample(200, 200, label_i=label_i) + gen_tokens_s = self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i) gen_tokens_s = [sample.split() for sample in gen_tokens_s] return GenDataIter(gen_tokens), gen_tokens, gen_tokens_s, CatClasDataIter([gen_tokens], label_i) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index f3978ff9..8a771162 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -210,14 +210,14 @@ def sample_for_metrics(self): eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200), self.idx2word_dict) + gen_tokens_s = tensor_to_tokens(self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size), self.idx2word_dict) return gen_data, gen_tokens, gen_tokens_s def sample_for_metrics_with_label(self, label_i): eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(200, 200, label_i=label_i), self.idx2word_dict) + gen_tokens_s = tensor_to_tokens(self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i), self.idx2word_dict) clas_data = CatClasDataIter([eval_samples], label_i) return gen_data, gen_tokens, gen_tokens_s, clas_data diff --git a/metrics/bleu.py b/metrics/bleu.py index 019e841e..6153ea64 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -98,7 +98,7 @@ def get_bleu_parallel(self, ngram, reference): weight = tuple((1. 
/ ngram for _ in range(ngram))) pool = Pool(os.cpu_count()) result = list() - for idx, hypothesis in enumerate(tqdm(self.test_text[:self.sample_size], desc=self.name)): + for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): result.append(pool.apply_async(self.cal_bleu, args=(reference, hypothesis, weight))) score = 0.0 cnt = 0 diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index 79973308..72772bac 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -42,7 +42,7 @@ def get_NLL(self, messages, baseline=0): messages = [' '.join(msg) for msg in messages] all_logits = [] - for message in tqdm(messages, desc=self.name): + for message in messages: message = self.tokenizer.eos_token + message + self.tokenizer.eos_token inputs = self.tokenizer(message, return_tensors="pt") logits = self.model(**inputs)[0][0] diff --git a/metrics/nll.py b/metrics/nll.py index 03cc09e0..528d53e3 100644 --- a/metrics/nll.py +++ b/metrics/nll.py @@ -50,7 +50,7 @@ def cal_nll(model, data_loader, criterion, gpu=cfg.CUDA): """NLL score for general text generation model.""" total_loss = 0 with torch.no_grad(): - for i, data in enumerate(tqdm(data_loader, desc=self.name)): + for i, data in enumerate(data_loader): inp, target = data['input'], data['target'] if gpu: inp, target = inp.cuda(), target.cuda() @@ -67,7 +67,7 @@ def cal_nll_with_label(model, data_loader, label_i, criterion, gpu=cfg.CUDA): assert type(label_i) == int, 'missing label' total_loss = 0 with torch.no_grad(): - for i, data in enumerate(tqdm(data_loader, desc=self.name)): + for i, data in enumerate(data_loader): inp, target = data['input'], data['target'] label = torch.LongTensor([label_i] * data_loader.batch_size) if gpu: @@ -87,7 +87,7 @@ def cal_nll_with_leak_dis(model, data_loader, leak_dis, gpu=cfg.CUDA): """NLL score for LeakGAN.""" total_loss = 0 with torch.no_grad(): - for i, data in enumerate(tqdm(data_loader, desc=self.name)): + for i, data in enumerate(data_loader): inp, target = data['input'], data['target'] if gpu: inp, target = inp.cuda(), target.cuda() diff --git a/models/generators/FixemGAN_G.py b/models/generators/FixemGAN_G.py index 5b0277a5..eb3c0b4f 100644 --- a/models/generators/FixemGAN_G.py +++ b/models/generators/FixemGAN_G.py @@ -142,7 +142,7 @@ def sample(self, num_samples, batch_size, label_i = 'random', start_letter=cfg.s fakes = self.forward(*noise) fakes = fakes.detach().cpu().numpy() assert len(fakes.shape) == 3 - return [self.recover_sentence(fake) for fake in tqdm(fakes, desc='recovering messages')] + return [self.recover_sentence(fake) for fake in fakes] def recover_sentence(self, fake): fake = fake.T From e133589455cf6bbd639c85391ede63199ddd4d69 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 22 Feb 2023 13:27:40 +0000 Subject: [PATCH 42/81] fixed README.md --- README.md | 33 +++++++++++++++++++++++---------- assets/model_fixem.png | Bin 0 -> 51852 bytes 2 files changed, 23 insertions(+), 10 deletions(-) create mode 100644 assets/model_fixem.png diff --git a/README.md b/README.md index 8dc28b37..7513d7e4 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ To install, run `pip install -r requirements.txt`. 
In case of CUDA problems, con ### General Text Generation +- **FixemGAN** - [FixemGAN: Continuous Space Text GAN on Fixed Embeddings](https://www.com) - **SeqGAN** - [SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient](https://arxiv.org/abs/1609.05473) - **LeakGAN** - [Long Text Generation via Adversarial Training with Leaked Information](https://arxiv.org/abs/1709.08624) - **MaliGAN** - [Maximum-Likelihood Augmented Discrete Generative Adversarial Networks](https://arxiv.org/abs/1702.07983) @@ -116,13 +117,25 @@ python3 run_seqgan.py 0 0 ## Implementation Details +### FixemGAN + +- run file: [run_fixem.py](run/run_fixem.py) + +- Instructors: [oracle_data](instructor/oracle_data/fixem_instructor.py), [real_data](instructor/real_data/fixem_instructor.py) + +- Models: [generator](models/generators/FixemGAN_G.py), [discriminator](models/discriminators/FixemGAN_D.py) + +- Structure (from [FixemGAN](https://www.com)) + + ![model_fixem](./assets/model_fixem.png) + ### SeqGAN - run file: [run_seqgan.py](run/run_seqgan.py) - Instructors: [oracle_data](instructor/oracle_data/seqgan_instructor.py), [real_data](instructor/real_data/seqgan_instructor.py) -- Models: [generator](models/SeqGAN_G.py), [discriminator](models/SeqGAN_D.py) +- Models: [generator](models/generators/SeqGAN_G.py), [discriminator](models/discriminators/SeqGAN_D.py) - Structure (from [SeqGAN](https://arxiv.org/pdf/1609.05473.pdf)) @@ -134,7 +147,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/leakgan_instructor.py), [real_data](instructor/real_data/leakgan_instructor.py) -- Models: [generator](models/LeakGAN_G.py), [discriminator](models/LeakGAN_D.py) +- Models: [generator](models/generators/LeakGAN_G.py), [discriminator](models/discriminators/LeakGAN_D.py) - Structure (from [LeakGAN](https://arxiv.org/pdf/1709.08624.pdf)) @@ -146,7 +159,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/maligan_instructor.py), [real_data](instructor/real_data/maligan_instructor.py) -- Models: [generator](models/MaliGAN_G.py), [discriminator](models/MaliGAN_D.py) +- Models: [generator](models/generators/MaliGAN_G.py), [discriminator](models/discriminators/MaliGAN_D.py) - Structure (from my understanding) @@ -158,7 +171,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/jsdgan_instructor.py), [real_data](instructor/real_data/jsdgan_instructor.py) -- Models: [generator](models/JSDGAN_G.py) (No discriminator) +- Models: [generator](models/generators/JSDGAN_G.py) (No discriminator) - Structure (from my understanding) @@ -170,7 +183,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/relgan_instructor.py), [real_data](instructor/real_data/relgan_instructor.py) -- Models: [generator](models/RelGAN_G.py), [discriminator](models/RelGAN_D.py) +- Models: [generator](models/generators/RelGAN_G.py), [discriminator](models/discriminators/RelGAN_D.py) - Structure (from my understanding) @@ -182,7 +195,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/dpgan_instructor.py), [real_data](instructor/real_data/dpgan_instructor.py) -- Models: [generator](models/DPGAN_G.py), [discriminator](models/DPGAN_D.py) +- Models: [generator](models/generators/DPGAN_G.py), [discriminator](models/discriminators/DPGAN_D.py) - Structure (from [DPGAN](https://arxiv.org/abs/1802.01345)) @@ -194,7 +207,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/dgsan_instructor.py), 
[real_data](instructor/real_data/dgsan_instructor.py) -- Models: [generator](models/DGSAN_G.py), [discriminator](models/DGSAN_D.py) +- Models: [generator](models/generators/DGSAN_G.py), [discriminator](models/discriminators/DGSAN_D.py) ### CoT @@ -202,7 +215,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/cot_instructor.py), [real_data](instructor/real_data/cot_instructor.py) -- Models: [generator](models/CoT_G.py), [discriminator](models/CoT_D.py) +- Models: [generator](models/generators/CoT_G.py), [discriminator](models/discriminators/CoT_D.py) - Structure (from [CoT](https://arxiv.org/abs/1804.03782)) @@ -214,7 +227,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/sentigan_instructor.py), [real_data](instructor/real_data/sentigan_instructor.py) -- Models: [generator](models/SentiGAN_G.py), [discriminator](models/SentiGAN_D.py) +- Models: [generator](models/generators/SentiGAN_G.py), [discriminator](models/discriminators/SentiGAN_D.py) - Structure (from [SentiGAN](https://www.ijcai.org/proceedings/2018/0618.pdf)) @@ -226,7 +239,7 @@ python3 run_seqgan.py 0 0 - Instructors: [oracle_data](instructor/oracle_data/catgan_instructor.py), [real_data](instructor/real_data/catgan_instructor.py) -- Models: [generator](models/CatGAN_G.py), [discriminator](models/CatGAN_D.py) +- Models: [generator](models/generators/CatGAN_G.py), [discriminator](models/discriminators/CatGAN_D.py) - Structure (from [CatGAN](https://arxiv.org/abs/1911.06641)) diff --git a/assets/model_fixem.png b/assets/model_fixem.png new file mode 100644 index 0000000000000000000000000000000000000000..9c3b6892d304de192b30d0d7ab9419e1a1f0a378 GIT binary patch literal 51852 zcmeFZby$>J+czwsfPzxe4H7a)mmnbBDac62&>$cpf&v2q0t$%K3?V8= zryxqBbiHc;-TQgI`+48*JHCJ4ZSW}?i|xO75UqDyiJ#K&U=$9(RLgYk8zQ6 zYLAVpRpM&g_u|rhYS_C!6nG*|qZG;C6v8+@pxC9OD1cNOEUO#ec85SSz%-=7Uu$?c0KZwpkrJ7@?XP4m39sW zEgU1H5mL5S!LKBi-G^GbQpk^9xjf^w4&Ny253&;cYvk&I-D-#$HjQ_F3<;?1c~U=A zHYaV9k-^G>ZzvN(~$i$XW)j|5vxbE6PnPAcI`KSZ@KZY{s87?O}X^Q-@ zlQdjiW$@vW!bm!??WitMrDW*Rrd40=UjsTp@&qD5MO2ev&1`RMGd=7|>1#i)US;dv zGI1vL&9nsCh*sEmfJjDFM*A*;_>*z{Lwne>JRg2dQ!0Y9Ee&(d!G>?N)=EDqb)lryv2tr~oJSJiC8 z_`qtZqKayN_QsH{g;1CgYjBF6cFfYJ=+{N`Bnc|2{Q1H)8E8!wl%fhvZew)?w%@8E z|Be&g-b-1?`iDe7#4!h2w%)0DTPHnEysR~Eri1bzHiJX)k83ml4yz^^mSo7_<=j2J z%*mSqvzm|5Njmj?L8(B zHk>3j8W>mAn{-9VSDeBthT7T|clezL3YqEZAAtDdWo06OK7~UfDQ(v$WiL3Rg+Bz2 z4&EF+f+Uldj#-ge`Eu2Ctw$mZJNceK(AqP%{_2(k;DqqwxiuBy_mEdlo-iroIv?|V zgC3Bk5i3@yRUf*?*hVivVQB_pg6Ul7-j6eH|GxSPqn0WNR@I0^73Pm3mDR`N?ETZQ zkwL`r;w!e1;l3fZA@+lkgjn&@`C{8wjwqM6=p!d%PfYV@u z<%PLS*ym^Q!O3TeV-LiI3;m3`Y#OHsk|pBt_XTyilNWl-F^~2Cd3?Ig>eSq&fM-4b zu(j%htvvz;lkj9TpL*8Kf6N-Q4D=Z!%ckN1b%ecMTk{{YQVkR*1@`N*DdxN@|L{Bf zgx`ot%z1mr|DN~P$$2#-9=)*{t}glKycned*L=ZoM+ThZKV3BD@qggkZb+WJd&ggL zASdKh#SU!Gd=nu2k1R5duo$_+|HpRn@c>bT4L63SVa_Y@_q;Zz=Y^GF&MWxOc^gkQ z_#AKg>OYGejg&Sxhu%LF37x(`Xn=YA&pLnk)g zFIJN1QL0HqrOz*$p)heiKJC!lVcfr}Q+uii5`!?<@K%s)I^JbEeCFA!bJ<^Hevo+Ne+8bSt1!`)R}0#YW^MsM9N>A)jN@PC4xhLI}c<=-`A5ws#s zQM^qDx_vI=H_$bgF{@Vpvnu(?s!n<14g)Wk^`5CyNqClR%;Df&HXo|6*jcLT<8z&m zq`UbLpWC3Ud3!Q2L9bXdUpD+AmG}cT-6GZSAD@e7Cj*b$Y9|9qmF^l z=r(^2wC=3WPS4Jg+3>yf+a*2R-|ld2I*gEyCgrLcBBBv$7M=79MW^vvfw#PW&Bw9S zhhMy^`;I~9nB|#L0>kuFOX$!WcaG}@Wl;oV?40jidC?7CtBMH>w^DWv4-ffW#>CiE z9^-DWjO&?Z8&%tpJZ9H;A$fP;g}yNT!M#4nC|*aXx{$_I+CQT48HmOrC)%jiMWIBu zDDj#ag@SdKuw#Ul^x+*6W9EQlDXr<)j-tJ7NwbA{@8XGM;@o(7eU+JJ2s^Xld>-1UlhL_-!=0O z&kNWikUrXBMcVYl)Hsh8fR%Rk7R%*e&`#GFa;}dhdK_O$^zbY*x>_TzRKm~{vDAVW 
zxiG{im_@)r^{dvY!GlCJ0G?cJd#Zk~>{y!b_ejtyy$XGT;X}3`Ja$0$3JfB)C~?|| zKFa`$5{F-ug0nz6uwYIxn?};;{lRgjPx7167+1%zYY5YX>mD@Qd=Ceg=*@UEyrI)` zrGC=}FxN%wws09Qc9wYXis)}P^dPO-+B0$E{^x$W+cry*<*V{(4*o><7u6Mn`HIJ( z`iyuEx)38ow}2GYb(7? zWqX$z-UE;4y+m^oR9JxaVR&HMhXE3|-sxdz_9cMFwj=5+wv*cD$26b8XynZpmZUOI zAAO-|7lzGR zY<5GLv3{{Ei^E8y5UPj=uZJ|Dvly682a3)275t`ZH=adAoJmMbyo-FTU8v1fM74ho*P6hy8$zE`Rr>^IH$J>gHS%q$)jiAJtwU-cF4+3oQGB`< zB@B|MRuGK+!;ml25ZvK~9>`b#lCMdZ&{U0HX1kfsmXN0l;+NC!IvhWby0EWTn#+`{ zs_@WhLeLqlsDvofX#at`bXpYBe91RlNTXAOp}0}i=1)?@m2`*)s#fz>n_j5oc${FC zX{9Gr2!T=3eOHty-TDYuT-(M9HVvLSf_FErQ#FSM*6JrJjF}0L(RDO zSorx#^5Uz{oq2w76RN{fV6|4>B|(mqq|7wnoiVTs&z?TDT|M~3|142Zk*ztTTc}fO zPqEHEkLoE0HF=nq?fT1HVA^QG{X;DWaWOZ^uTLTk^ovN_^ zZbWo^)XNdcYErds3xvFeb^f1rF)xh|5-=x3|BIEIJyMU@(DEg-1l1hssrPSG%)K-1 zFvdx%^0@7>Zi>An9I+E@+;9-2!)3XCa45x*09_$99>H!W-6rU`VG>!m^TL}Ot+q_B`^n*N>hW11`SNw3*ihmFr;;~&(;_(b>Iw|4*gh* zSTH#z4jSe^26@Ujgep+9NUEp*^$tjnLjtSuDn983=qv!PM=LQaSuJ$R)o-4Ae*;Vt zwZ1RPk5_YlxG`F%k5wDykYy`m( zv4i1?bM|!kO{3|bAXE&Q@S*Mr1P=JYza% z!3OGpvIc+X>5En|KpcT*qcr^>YaaaB@*^C2aRa9>Q=mM+V}2nvpFtdXFbui$AO{ zl`nd@wo-O3{(yDeW0UkG@G|g_f(TD zpO6y47Mzyvz{~eSUr?xkix9*Xepqa%NzK@SD8^&0ee#W<FH4-jU@Yo`Pm{b=t(j zk=;b>E9*FK0zx7aB7x@h(B%47|KeyZn-%PP-%meF2M?_hwRtct>sO~1>fMD4_uiP+ zC4)aEK8SBPSy-b3Zv~C~4`X-_hby>x4TBXzje^SGBnL-ug&c~kzV-;yq3zTxFCky` z!#WN>?FDZ_NCZ0yo=WrC>e5mW*u>!D+p9tGBKc+2wJ&CXVP{mjtQOs5B`?88Qtq00RDah-GUyN?2-m6bqV^R z5cDPg`{Dohf?x`laxxosV7)Hx-%@WUI#d_0qRk3*-~lUd1;8${;C7uj6)f?GtEXfj z5rBTGc8LW86RMT_{h5-31(P;}qR0Ib68T$%h5FfFc>a98NLa{Ns@nAX`$+DR`C)IV zJUspD`$=R;{jgb%3_icUU-IQ0Nw~J9bKvjKEVm@#3VHAO|2$F@k!j_f_j-Rms}q@u zStj}bsls29MNQcm6*mc;KinhxQBiiPFAAQ-Kfk1>WV)wL{)bK4Kc6E#bwYGAP5!(x t_t9o_dPe!5&;MHj;{Pqde^`QV=N^hFXZv@fED69rMLAX3d}-sS{|EDc+P44z literal 0 HcmV?d00001 From 84ea669b603ea1b96e0e6ba1552d16becece4f7f Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 22 Feb 2023 16:04:09 +0000 Subject: [PATCH 43/81] metric used BLEU 3 and sefl BLEU 3 and returned target and input to GenDataIter --- instructor/real_data/instructor.py | 4 ++-- utils/data_loader.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 8a771162..cf8cd12f 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -69,10 +69,10 @@ def __init__(self, opt): self.clas_opt = None # Metrics - self.bleu = BLEU('BLEU', gram=[2, 3, 4, 5], if_use=cfg.use_bleu) + self.bleu = BLEU('BLEU', gram=3, if_use=cfg.use_bleu) self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA) self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA) - self.self_bleu = BLEU('Self-BLEU', gram=[2, 3, 4], if_use=cfg.use_self_bleu) + self.self_bleu = BLEU('Self-BLEU', gram=3, if_use=cfg.use_self_bleu) self.clas_acc = ACC(if_use=cfg.use_clas_acc) self.ioc = IOC(if_use=cfg.use_ioc, real_text=self.test_data.tokens) self.nll_oracle = GPTNLL(if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) diff --git a/utils/data_loader.py b/utils/data_loader.py index ee44bc93..a774b3b6 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -52,8 +52,8 @@ def __init__(self, samples, if_test_data=False, shuffle=None): shuffle=self.shuffle, drop_last=True) - # self.input = self._all_data_('input') - # self.target = self._all_data_('target') + self.input = self._all_data_('input') + self.target = self._all_data_('target') def __read_data__(self, samples): """ From bc099e3802fbbaa38e54bf7650d2163a10cc0365 Mon Sep 17 00:00:00 2001 From: Ildar 
Salakhiev Date: Fri, 24 Feb 2023 10:27:39 +0000 Subject: [PATCH 44/81] try to fix issue with None, inf or 0 value in binomial --- models/generators/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/generators/generator.py b/models/generators/generator.py index e1ab3fa8..e9b45fbd 100644 --- a/models/generators/generator.py +++ b/models/generators/generator.py @@ -76,7 +76,7 @@ def sample(self, num_samples, batch_size, start_letter=cfg.start_letter): for i in range(self.max_seq_len): out, hidden = self.forward(inp, hidden, need_hidden=True) # out: batch_size * vocab_size - next_token = torch.multinomial(torch.exp(out), 1) # batch_size * 1 (sampling from each row) + next_token = torch.multinomial(torch.exp(out), 1, replacement=True) # batch_size * 1 (sampling from each row) samples[b * batch_size:(b + 1) * batch_size, i] = next_token.view(-1) inp = next_token.view(-1) samples = samples[:num_samples] From bbf776691196f4c3ce1fa92bfeafd7fe335879c0 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 24 Feb 2023 10:27:58 +0000 Subject: [PATCH 45/81] improoving metrics calculation --- instructor/oracle_data/instructor.py | 35 +++++++++++++++----- instructor/oracle_data/leakgan_instructor.py | 2 +- instructor/real_data/instructor.py | 27 ++++++++------- instructor/real_data/leakgan_instructor.py | 2 +- metrics/basic.py | 29 +++++++++++----- metrics/bleu.py | 31 +++++------------ metrics/clas_acc.py | 20 ++++------- metrics/gpt_nll.py | 9 ++--- metrics/ioc.py | 17 ++++------ metrics/nll.py | 12 +++---- metrics/ppl.py | 12 +++---- 11 files changed, 99 insertions(+), 97 deletions(-) diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index 1d182db3..452626c5 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -56,7 +56,8 @@ def __init__(self, opt): self.nll_oracle = NLL('NLL_oracle', if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA) self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA) - self.all_metrics = [self.nll_oracle, self.nll_gen, self.nll_div] + self.self_bleu = BLEU('Self-BLEU', gram=3, if_use=cfg.use_self_bleu) + self.all_metrics = [self.nll_oracle, self.nll_gen, self.nll_div, self.self_bleu] def _run(self): print('Nothing to run in Basic Instructor!') @@ -163,6 +164,20 @@ def show_config(self): self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg))) self.log.info(100 * '=') + def sample_for_metrics(self): + eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) + gen_data = GenDataIter(eval_samples) + gen_tokens = eval_samples + gen_tokens_s = self.gen.sample(cfg.small_sample_num, 4 * cfg.batch_size) + return gen_data, gen_tokens, gen_tokens_s + + def sample_for_metrics_with_label(self, label_i): + eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size, label_i=label_i) + gen_data = GenDataIter(eval_samples) + gen_tokens = eval_samples + gen_tokens_s = self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i) + return gen_data, gen_tokens, gen_tokens_s + def cal_metrics(self, fmt_str=False): """ Calculate metrics @@ -170,30 +185,32 @@ def cal_metrics(self, fmt_str=False): """ with torch.no_grad(): # Prepare data for evaluation - gen_data = GenDataIter(self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)) + gen_data, gen_tokens, gen_tokens_s = sample_for_metrics() # Reset metrics self.nll_oracle.reset(self.oracle, gen_data.loader) 
self.nll_gen.reset(self.gen, self.oracle_data.loader) self.nll_div.reset(self.gen, gen_data.loader) + self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) - else: - return [metric.get_score() for metric in self.all_metrics] + return ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) + return [metric.get_score() for metric in self.all_metrics] - def cal_metrics_with_label(self, label_i): + def cal_metrics_with_label(self, label_i, fmt_str=False): assert type(label_i) == int, 'missing label' with torch.no_grad(): # Prepare data for evaluation - eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) - gen_data = GenDataIter(eval_samples) + gen_data, gen_tokens, gen_tokens_s = sample_for_metrics_with_label() # Reset metrics self.nll_oracle.reset(self.oracle_list[label_i], gen_data.loader, label_i) self.nll_gen.reset(self.gen, self.oracle_data_list[label_i].loader, label_i) self.nll_div.reset(self.gen, gen_data.loader, label_i) + self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) + if fmt_str: + return f'label: {label_i}' + ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): @@ -201,7 +218,7 @@ def comb_metrics(self, fmt_str=False): all_scores = np.array(all_scores).T.tolist() # each row for each metric if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), score) + return ', '.join(['%s = %s' % (metric.name, score) for (metric, score) in zip(self.all_metrics, all_scores)]) return all_scores diff --git a/instructor/oracle_data/leakgan_instructor.py b/instructor/oracle_data/leakgan_instructor.py index 5e2d7109..0ec40663 100644 --- a/instructor/oracle_data/leakgan_instructor.py +++ b/instructor/oracle_data/leakgan_instructor.py @@ -185,7 +185,7 @@ def cal_metrics(self, fmt_str=False): self.nll_div.reset(self.gen, gen_data.loader, leak_dis=self.dis) if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + return ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) else: return [metric.get_score() for metric in self.all_metrics] diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index cf8cd12f..c9c2f159 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -69,14 +69,14 @@ def __init__(self, opt): self.clas_opt = None # Metrics - self.bleu = BLEU('BLEU', gram=3, if_use=cfg.use_bleu) - self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA) - self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA) - self.self_bleu = BLEU('Self-BLEU', gram=3, if_use=cfg.use_self_bleu) - self.clas_acc = ACC(if_use=cfg.use_clas_acc) - self.ioc = IOC(if_use=cfg.use_ioc, real_text=self.test_data.tokens) - self.nll_oracle = GPTNLL(if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) - self.ppl = PPL(self.train_data, self.test_data, n_gram=5, if_use=cfg.use_ppl) + self.bleu = BLEU('BLEU', weitht=1, gram=3, if_use=cfg.use_bleu) + self.nll_gen = NLL('NLL_gen', weight=-1, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) + self.nll_div = NLL('NLL_div', weight=1, if_use=cfg.use_nll_div, gpu=cfg.CUDA) + self.self_bleu = BLEU('Self-BLEU', weight=-1, gram=3, if_use=cfg.use_self_bleu) + 
self.clas_acc = ACC(weight=1, if_use=cfg.use_clas_acc) + self.ioc = IOC(weight=-1, if_use=cfg.use_ioc, real_text=self.test_data.tokens) + self.nll_oracle = GPTNLL(weight=-1, if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) + self.ppl = PPL(self.train_data, self.test_data, weight=-1, n_gram=5, if_use=cfg.use_ppl) self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.nll_oracle, self.ppl] def _run(self): @@ -239,7 +239,9 @@ def cal_metrics(self, fmt_str=False): self.nll_oracle.reset(test_text=gen_tokens) if fmt_str: - return '\n'.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + return "\n" \ + "\n".join([f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics]) \ + f"\n Overal_score: {sum(metric.weight * metric.get_score() for metric in metrics)}" return [metric.get_score() for metric in self.all_metrics] def cal_metrics_with_label(self, label_i, fmt_str=False): @@ -257,7 +259,9 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.ppl.reset(gen_tokens) if fmt_str: - return '\n'.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + return f"label: {label_i}" \ + '\n'.join([f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics]) \ + f"\n Overal_score: {sum(metric.weight * metric.get_score() for metric in metrics)}" return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): @@ -265,8 +269,7 @@ def comb_metrics(self, fmt_str=False): all_scores = np.array(all_scores).T.tolist() # each row for each metric if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), score) - for (metric, score) in zip(self.all_metrics, all_scores)]) + return ', '.join([f"{metric.name} = {score}" for (metric, score) in zip(self.all_metrics, all_scores)]) return all_scores def _save(self, phase, epoch): diff --git a/instructor/real_data/leakgan_instructor.py b/instructor/real_data/leakgan_instructor.py index bfff5993..fc02e41a 100644 --- a/instructor/real_data/leakgan_instructor.py +++ b/instructor/real_data/leakgan_instructor.py @@ -185,7 +185,7 @@ def cal_metrics(self, fmt_str=False): self.ppl.reset(gen_tokens) if fmt_str: - return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics]) + return ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) else: return [metric.get_score() for metric in self.all_metrics] diff --git a/metrics/basic.py b/metrics/basic.py index b96377c0..5084c423 100644 --- a/metrics/basic.py +++ b/metrics/basic.py @@ -4,26 +4,39 @@ # @FileName : basic.py # @Time : Created at 2019-05-14 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
from abc import abstractmethod class Metrics: - def __init__(self, name='Metric'): + def __init__(self, name='Metric', wight=1): self.name = name + # represents effect on final score + # ex.: self-bleu has weight = -1 (less is better) + # bleu has weight = 1 (more is better) + # weights needed for combined metric evaluation + self.weight = weight + self.in_use = False + self.metric_value_with_current_state = None - def get_name(self): - return self.name + def get_score(self): + if not self.in_use: + return 0 - def set_name(self, name): - self.name = name + if self.metric_value_with_current_state is not None: + return self.metric_value_with_current_state + + self.metric_value_with_current_state = self.calculate_metric() + return self.metric_value_with_current_state @abstractmethod - def get_score(self): - pass + def calculate_metric(self) @abstractmethod def reset(self): pass + + def _reset(self): + self.metric_value_with_current_state = None diff --git a/metrics/bleu.py b/metrics/bleu.py index 6153ea64..4dd7545c 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -19,47 +19,34 @@ class BLEU(Metrics): - def __init__(self, name=None, test_text=None, real_text=None, gram=3, portion=1, if_use=False): + """ + Get BLEU scores. + :param is_fast: Fast mode + :param given_gram: Calculate specific n-gram BLEU score + """ + def __init__(self, weight, name=None, test_text=None, real_text=None, gram=3, portion=1, if_use=False): assert type(gram) == int or type(gram) == list, 'Gram format error!' - super(BLEU, self).__init__('%s-%s' % (name, gram)) + super(BLEU, self).__init__('%s-%s' % (name, gram), weight=weight) self.if_use = if_use self.test_text = test_text self.real_text = real_text self.gram = [gram] if type(gram) == int else gram self.sample_size = 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 - self.reference = None - self.is_first = True self.portion = portion # how many portions to use in the evaluation, default to use the whole test dataset - def get_score(self, is_fast=True, given_gram=None): - """ - Get BLEU scores. 
- :param is_fast: Fast mode - :param given_gram: Calculate specific n-gram BLEU score - """ - if not self.if_use: - return 0 - if self.is_first: - self.get_reference() - self.is_first = False - if is_fast: - return self.get_bleu_fast(given_gram) - return self.get_bleu(given_gram) - def reset(self, test_text=None, real_text=None): + self._reset() self.test_text = test_text if test_text else self.test_text self.real_text = real_text if real_text else self.real_text def get_reference(self): reference = self.real_text.copy() - # randomly choose a portion of test data # In-place shuffle random.shuffle(reference) len_ref = len(reference) reference = reference[:int(self.portion * len_ref)] - self.reference = reference return reference def get_bleu(self, given_gram=None): @@ -84,7 +71,7 @@ def cal_bleu(reference, hypothesis, weight): return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight, smoothing_function=SmoothingFunction().method1) - def get_bleu_fast(self, given_gram=None): + def calculate_metric(self, given_gram=None): reference = self.get_reference() if given_gram is not None: # for single gram return self.get_bleu_parallel(ngram=given_gram, reference=reference) diff --git a/metrics/clas_acc.py b/metrics/clas_acc.py index 01877055..72c8645f 100644 --- a/metrics/clas_acc.py +++ b/metrics/clas_acc.py @@ -4,7 +4,7 @@ # @FileName : clas_acc.py # @Time : Created at 2019/12/4 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import torch @@ -13,35 +13,29 @@ class ACC(Metrics): - def __init__(self, if_use=True, gpu=True): - super(ACC, self).__init__('clas_acc') + def __init__(self, weight, if_use=True, gpu=True): + super(ACC, self).__init__('clas_acc', weight=weight) self.if_use = if_use self.model = None self.data_loader = None self.gpu = gpu - def get_score(self): - if not self.if_use: - return 0 - assert self.model and self.data_loader, 'Need to reset() before get_score()!' 
- - return self.cal_acc(self.model, self.data_loader) - def reset(self, model=None, data_loader=None): + self._reset() self.model = model self.data_loader = data_loader - def cal_acc(self, model, data_loader): + def calculate_metric(self, model, data_loader): total_acc = 0 total_num = 0 with torch.no_grad(): - for i, data in enumerate(data_loader): + for i, data in enumerate(self.data_loader): inp, target = data['input'], data['target'] if self.gpu: inp, target = inp.cuda(), target.cuda() - pred = model.forward(inp) + pred = self.model.forward(inp) total_acc += torch.sum((pred.argmax(dim=-1) == target)).item() total_num += inp.size(0) return round(total_acc / total_num, 4) diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index 72772bac..743ed546 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -13,8 +13,8 @@ class GPTNLL(Metrics): - def __init__(self, name=None, test_text=None, real_text=None, if_use=True): - super(GPTNLL, self).__init__('GPT2 as oracle') + def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=True): + super(GPTNLL, self).__init__('GPT2 as oracle', weight=weight) self.if_use = if_use self.test_text = test_text @@ -34,10 +34,11 @@ def get_score(self): return self.get_NLL(self.test_text) - self.real_text_nll def reset(self, test_text=None, real_text=None): + self._reset() self.test_text = test_text if test_text else self.test_text self.real_text_nll = self.get_NLL(real_text) if real_text else self.real_text_nll - def get_NLL(self, messages, baseline=0): + def calculate_metric(self, messages, baseline=0): if type(messages[0]) == list: #we received list of tokens messages = [' '.join(msg) for msg in messages] @@ -51,4 +52,4 @@ def get_NLL(self, messages, baseline=0): all_logits.append( self.NLLloss(logits[:-1], inputs["input_ids"][0][1:]).detach().numpy() ) - return np.mean(all_logits) + return np.mean(all_logits) - self.real_text_nll diff --git a/metrics/ioc.py b/metrics/ioc.py index 5ad98aff..eaad76a6 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -9,8 +9,8 @@ class IOC(Metrics): - def __init__(self, name=None, test_text=None, real_text=None, if_use=True): - super(IOC, self).__init__('Index of Coincidence') + def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=True): + super(IOC, self).__init__('Index of Coincidence', weight) self.if_use = if_use self.test_text = test_text @@ -19,20 +19,15 @@ def __init__(self, name=None, test_text=None, real_text=None, if_use=True): self.reference = None self.is_first = True - def get_score(self): - """Get IOC score.""" - if not self.if_use: - return 0 - return self.get_ioc(self.test_text) / self.real_text_ioc - def reset(self, test_text=None, real_text=None): + self._reset() self.test_text = test_text if test_text else self.test_text self.real_text_ioc = self.get_ioc(real_text) if real_text else self.real_text_ioc - def get_ioc(self, list_tokens): + def calculate_metric(self): """Index Of coincidence: probability of 2 random tokens in text to equal.""" - tokens = list(chain(*list_tokens)) + tokens = list(chain(*self.test_text)) counts = Counter(tokens) total = sum(ni * (ni - 1) for ni in counts.values()) N = len(tokens) - return total / N / (N - 1) + return total / N / (N - 1) / self.real_text_ioc diff --git a/metrics/nll.py b/metrics/nll.py index 528d53e3..f5012631 100644 --- a/metrics/nll.py +++ b/metrics/nll.py @@ -15,9 +15,8 @@ class NLL(Metrics): - def __init__(self, name, if_use=False, gpu=False): - super(NLL, self).__init__(name) - + def __init__(self, name, weight, 
if_use=False, gpu=False): + super(NLL, self).__init__(name, weight) self.if_use = if_use self.model = None self.data_loader = None @@ -26,12 +25,8 @@ def __init__(self, name, if_use=False, gpu=False): self.gpu = gpu self.criterion = nn.NLLLoss() - def get_score(self): + def calculate_metric(self): """note that NLL score need the updated model and data loader each time, use reset() before get_score()""" - if not self.if_use: - return 0 - assert self.model and self.data_loader, 'Need to reset() before get_score()!' - if self.leak_dis is not None: # For LeakGAN return self.cal_nll_with_leak_dis(self.model, self.data_loader, self.leak_dis, self.gpu) if self.label_i is not None: # For category text generation @@ -40,6 +35,7 @@ def get_score(self): return self.cal_nll(self.model, self.data_loader, self.criterion, self.gpu) def reset(self, model=None, data_loader=None, label_i=None, leak_dis=None): + self._reset() self.model = model self.data_loader = data_loader self.label_i = label_i diff --git a/metrics/ppl.py b/metrics/ppl.py index 764d0a7e..96a71121 100644 --- a/metrics/ppl.py +++ b/metrics/ppl.py @@ -21,7 +21,7 @@ class PPL(Metrics): - def __init__(self, train_data, test_data, n_gram=5, if_use=False): + def __init__(self, train_data, test_data, weight, n_gram=5, if_use=False, weight=1): """ Calculate Perplexity scores, including forward and reverse. PPL-F: PPL_forward, PPL-R: PPL_reverse @@ -30,7 +30,7 @@ def __init__(self, train_data, test_data, n_gram=5, if_use=False): @param n_gram: calculate with n-gram @param if_use: if use """ - super(PPL, self).__init__('[PPL-F, PPL-R]') + super(PPL, self).__init__('[PPL-F, PPL-R]', weight) self.n_gram = n_gram self.if_use = if_use @@ -39,15 +39,11 @@ def __init__(self, train_data, test_data, n_gram=5, if_use=False): self.train_data = train_data self.test_data = test_data - def get_score(self): - if not self.if_use: - return 0 - return self.cal_ppl() - def reset(self, gen_tokens=None): + self._reset() self.gen_tokens = gen_tokens - def cal_ppl(self): + def calculate_metric(self): save_path = os.path.join("/tmp", ''.join(random.choice( string.ascii_uppercase + string.digits) for _ in range(6))) output_path = save_path + ".arpa" From dd187782450da41857d3dfae4171848fff3beff1 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 24 Feb 2023 10:57:29 +0000 Subject: [PATCH 46/81] metrics typos --- instructor/real_data/instructor.py | 13 ++++++------- instructor/real_data/maligan_instructor.py | 2 +- metrics/basic.py | 9 +++++---- metrics/bleu.py | 4 ++-- metrics/clas_acc.py | 2 +- metrics/gpt_nll.py | 20 +++++++++----------- metrics/ioc.py | 11 +++++++---- metrics/nll.py | 2 +- metrics/ppl.py | 4 ++-- 9 files changed, 34 insertions(+), 33 deletions(-) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index c9c2f159..c49d7694 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -69,7 +69,7 @@ def __init__(self, opt): self.clas_opt = None # Metrics - self.bleu = BLEU('BLEU', weitht=1, gram=3, if_use=cfg.use_bleu) + self.bleu = BLEU('BLEU', weight=1, gram=3, if_use=cfg.use_bleu) self.nll_gen = NLL('NLL_gen', weight=-1, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) self.nll_div = NLL('NLL_div', weight=1, if_use=cfg.use_nll_div, gpu=cfg.CUDA) self.self_bleu = BLEU('Self-BLEU', weight=-1, gram=3, if_use=cfg.use_self_bleu) @@ -239,9 +239,8 @@ def cal_metrics(self, fmt_str=False): self.nll_oracle.reset(test_text=gen_tokens) if fmt_str: - return "\n" \ - "\n".join([f"{metric.name} = 
{metric.get_score()}" for metric in self.all_metrics]) \ - f"\n Overal_score: {sum(metric.weight * metric.get_score() for metric in metrics)}" + return "\n".join([f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics]) + "" \ + f"\nOveral_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}" return [metric.get_score() for metric in self.all_metrics] def cal_metrics_with_label(self, label_i, fmt_str=False): @@ -249,7 +248,7 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): with torch.no_grad(): # Prepare data for evaluation - gen_data, gen_tokens, gen_tokens_s, clas_data = sample_for_metrics_with_label(label_i) + gen_data, gen_tokens, gen_tokens_s, clas_data = self.sample_for_metrics_with_label(label_i) # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i) @@ -260,8 +259,8 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): if fmt_str: return f"label: {label_i}" \ - '\n'.join([f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics]) \ - f"\n Overal_score: {sum(metric.weight * metric.get_score() for metric in metrics)}" + "\n".join([f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics]) + "" \ + f"\nOveral_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}" return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): diff --git a/instructor/real_data/maligan_instructor.py b/instructor/real_data/maligan_instructor.py index a27356f8..5224b734 100644 --- a/instructor/real_data/maligan_instructor.py +++ b/instructor/real_data/maligan_instructor.py @@ -89,7 +89,7 @@ def pretrain_generator(self, epochs): # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + '[MLE-GEN] epoch %d : pre_loss = %.4f\n%s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) if cfg.if_save and not cfg.if_test: self._save('MLE', epoch) else: diff --git a/metrics/basic.py b/metrics/basic.py index 5084c423..b5a4b8af 100644 --- a/metrics/basic.py +++ b/metrics/basic.py @@ -11,18 +11,18 @@ class Metrics: - def __init__(self, name='Metric', wight=1): + def __init__(self, name, weight, if_use): self.name = name # represents effect on final score # ex.: self-bleu has weight = -1 (less is better) # bleu has weight = 1 (more is better) # weights needed for combined metric evaluation self.weight = weight - self.in_use = False + self.if_use = if_use self.metric_value_with_current_state = None def get_score(self): - if not self.in_use: + if not self.if_use: return 0 if self.metric_value_with_current_state is not None: @@ -32,7 +32,8 @@ def get_score(self): return self.metric_value_with_current_state @abstractmethod - def calculate_metric(self) + def calculate_metric(self): + pass @abstractmethod def reset(self): diff --git a/metrics/bleu.py b/metrics/bleu.py index 4dd7545c..3a174397 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -24,9 +24,9 @@ class BLEU(Metrics): :param is_fast: Fast mode :param given_gram: Calculate specific n-gram BLEU score """ - def __init__(self, weight, name=None, test_text=None, real_text=None, gram=3, portion=1, if_use=False): + def __init__(self, name=None, weight=1, test_text=None, real_text=None, gram=3, portion=1, if_use=False): assert type(gram) == int or type(gram) == list, 'Gram 
format error!' - super(BLEU, self).__init__('%s-%s' % (name, gram), weight=weight) + super(BLEU, self).__init__('%s-%s' % (name, gram), weight, if_use) self.if_use = if_use self.test_text = test_text diff --git a/metrics/clas_acc.py b/metrics/clas_acc.py index 72c8645f..c2b5c229 100644 --- a/metrics/clas_acc.py +++ b/metrics/clas_acc.py @@ -14,7 +14,7 @@ class ACC(Metrics): def __init__(self, weight, if_use=True, gpu=True): - super(ACC, self).__init__('clas_acc', weight=weight) + super(ACC, self).__init__('clas_acc', weight, if_use) self.if_use = if_use self.model = None diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index 743ed546..33acb33c 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -14,7 +14,7 @@ class GPTNLL(Metrics): def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=True): - super(GPTNLL, self).__init__('GPT2 as oracle', weight=weight) + super(GPTNLL, self).__init__('GPT2 as oracle', weight, if_use) self.if_use = if_use self.test_text = test_text @@ -23,22 +23,20 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") self.model = GPT2LMHeadModel.from_pretrained("gpt2") print('Calculating dataset NLL') - self.real_text_nll = self.get_NLL(random.sample(real_text, 500)) if real_text else None + self.real_text_nll = self.calcualte_NLL(random.sample(real_text, 500)) if real_text else None print(f'dataset NLL based on GPT2 is {self.real_text_nll}') print('GPT2 as oracle metric will be calculated relative to this value') - def get_score(self): - """Get gpt2 NLL score.""" - if not self.if_use: - return 0 - return self.get_NLL(self.test_text) - self.real_text_nll - def reset(self, test_text=None, real_text=None): self._reset() self.test_text = test_text if test_text else self.test_text - self.real_text_nll = self.get_NLL(real_text) if real_text else self.real_text_nll + self.real_text_nll = self.calcualte_NLL(real_text) if real_text else self.real_text_nll + + def calculate_metric(self): + """Get gpt2 NLL score difference with dataset NLL.""" + return self.calcualte_NLL(self.test_text) - self.real_text_nll - def calculate_metric(self, messages, baseline=0): + def calcualte_NLL(self, messages): if type(messages[0]) == list: #we received list of tokens messages = [' '.join(msg) for msg in messages] @@ -52,4 +50,4 @@ def calculate_metric(self, messages, baseline=0): all_logits.append( self.NLLloss(logits[:-1], inputs["input_ids"][0][1:]).detach().numpy() ) - return np.mean(all_logits) - self.real_text_nll + return np.mean(all_logits) diff --git a/metrics/ioc.py b/metrics/ioc.py index eaad76a6..eb995f3e 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -10,11 +10,11 @@ class IOC(Metrics): def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=True): - super(IOC, self).__init__('Index of Coincidence', weight) + super(IOC, self).__init__('Index of Coincidence', weight, if_use) self.if_use = if_use self.test_text = test_text - self.real_text_ioc = self.get_ioc(real_text) if real_text else None + self.real_text_ioc = self.calculate_ioc(real_text) if real_text else None print(f'Dataset Index of coincidence: {self.real_text_ioc}') self.reference = None self.is_first = True @@ -25,9 +25,12 @@ def reset(self, test_text=None, real_text=None): self.real_text_ioc = self.get_ioc(real_text) if real_text else self.real_text_ioc def calculate_metric(self): + return self.calculate_ioc(self.test_text) / self.real_text_ioc + + def calculate_ioc(self, 
tokenized_text): """Index Of coincidence: probability of 2 random tokens in text to equal.""" - tokens = list(chain(*self.test_text)) + tokens = list(chain(*tokenized_text)) counts = Counter(tokens) total = sum(ni * (ni - 1) for ni in counts.values()) N = len(tokens) - return total / N / (N - 1) / self.real_text_ioc + return total / N / (N - 1) diff --git a/metrics/nll.py b/metrics/nll.py index f5012631..f8323ed2 100644 --- a/metrics/nll.py +++ b/metrics/nll.py @@ -16,7 +16,7 @@ class NLL(Metrics): def __init__(self, name, weight, if_use=False, gpu=False): - super(NLL, self).__init__(name, weight) + super(NLL, self).__init__(name, weight, if_use) self.if_use = if_use self.model = None self.data_loader = None diff --git a/metrics/ppl.py b/metrics/ppl.py index 96a71121..16e24b6f 100644 --- a/metrics/ppl.py +++ b/metrics/ppl.py @@ -21,7 +21,7 @@ class PPL(Metrics): - def __init__(self, train_data, test_data, weight, n_gram=5, if_use=False, weight=1): + def __init__(self, train_data, test_data, weight, n_gram=5, if_use=False): """ Calculate Perplexity scores, including forward and reverse. PPL-F: PPL_forward, PPL-R: PPL_reverse @@ -30,7 +30,7 @@ def __init__(self, train_data, test_data, weight, n_gram=5, if_use=False, weight @param n_gram: calculate with n-gram @param if_use: if use """ - super(PPL, self).__init__('[PPL-F, PPL-R]', weight) + super(PPL, self).__init__('[PPL-F, PPL-R]', weight, if_use) self.n_gram = n_gram self.if_use = if_use From d2b741d9bfa143b56556a3f64839abc4264642ae Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 24 Feb 2023 12:21:03 +0000 Subject: [PATCH 47/81] metrics typos --- instructor/real_data/instructor.py | 11 +++++----- instructor/real_data/maligan_instructor.py | 4 ++-- metrics/bleu.py | 24 +++++++++++----------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index c49d7694..8fd06282 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -239,8 +239,9 @@ def cal_metrics(self, fmt_str=False): self.nll_oracle.reset(test_text=gen_tokens) if fmt_str: - return "\n".join([f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics]) + "" \ - f"\nOveral_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}" + pp_metrics = [f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics] + pp_metrics.append(f"Overal_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}") + return "\n".join(pp_metrics) return [metric.get_score() for metric in self.all_metrics] def cal_metrics_with_label(self, label_i, fmt_str=False): @@ -258,9 +259,9 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.ppl.reset(gen_tokens) if fmt_str: - return f"label: {label_i}" \ - "\n".join([f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics]) + "" \ - f"\nOveral_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}" + pp_metrics = [f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics] + pp_metrics.append(f"Overal_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}") + return f"label: {label_i}\n" + "\n".join(pp_metrics) return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): diff --git a/instructor/real_data/maligan_instructor.py b/instructor/real_data/maligan_instructor.py index 5224b734..de323e96 100644 --- a/instructor/real_data/maligan_instructor.py 
+++ b/instructor/real_data/maligan_instructor.py @@ -55,7 +55,7 @@ def _run(self): # ===ADVERSARIAL TRAINING=== self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info('Initial generator: %s\n' % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) @@ -111,7 +111,7 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info('[ADV-GEN]: g_loss = %.4f\n%s' % (total_g_loss, self.cal_metrics(fmt_str=True))) def train_discriminator(self, d_step, d_epoch, phase='MLE'): """ diff --git a/metrics/bleu.py b/metrics/bleu.py index 3a174397..9e04d1ea 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -31,7 +31,7 @@ def __init__(self, name=None, weight=1, test_text=None, real_text=None, gram=3, self.if_use = if_use self.test_text = test_text self.real_text = real_text - self.gram = [gram] if type(gram) == int else gram + self.gram = gram if type(gram) == int else gram self.sample_size = 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 self.portion = portion # how many portions to use in the evaluation, default to use the whole test dataset @@ -50,8 +50,8 @@ def get_reference(self): return reference def get_bleu(self, given_gram=None): - if given_gram is not None: # for single gram - return self.get_blue_for_single_gram(given_gram) + if type(self.gram) == int: # for single gram + return self.get_blue_for_single_gram(self.gram) # for multiple gram all_bleu = [] for ngram in self.gram: @@ -62,7 +62,7 @@ def get_blue_for_single_gram(self, ngram): bleu = list() reference = self.get_reference() weight = tuple((1. / ngram for _ in range(ngram))) - for idx, hypothesis in enumerate(tqdm(self.test_text[:self.sample_size], desc=self.name)): + for idx, hypothesis in enumerate(self.test_text[:self.sample_size], desc=self.name): bleu.append(self.cal_bleu(reference, hypothesis, weight)) return round(sum(bleu) / len(bleu), 3) @@ -71,15 +71,15 @@ def cal_bleu(reference, hypothesis, weight): return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight, smoothing_function=SmoothingFunction().method1) - def calculate_metric(self, given_gram=None): + def calculate_metric(self): + if type(self.gram) == int: # for single gram + return self.get_blue_for_single_gram(self.gram) + # for multiple gram reference = self.get_reference() - if given_gram is not None: # for single gram - return self.get_bleu_parallel(ngram=given_gram, reference=reference) - else: # for multiple gram - all_bleu = [] - for ngram in self.gram: - all_bleu.append(self.get_bleu_parallel(ngram=ngram, reference=reference)) - return all_bleu + all_bleu = [] + for ngram in self.gram: + all_bleu.append(self.get_bleu_parallel(ngram=ngram, reference=reference)) + return all_bleu def get_bleu_parallel(self, ngram, reference): weight = tuple((1. 
/ ngram for _ in range(ngram))) From 27d27164e519b1e07976cd666eb35ad564757179 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 24 Feb 2023 12:36:47 +0000 Subject: [PATCH 48/81] some prettyfying on metric print --- instructor/real_data/evogan_instructor.py | 5 ----- instructor/real_data/fixem_instructor.py | 4 ---- instructor/real_data/instructor.py | 4 ++-- instructor/real_data/maligan_instructor.py | 6 +++--- metrics/bleu.py | 2 +- 5 files changed, 6 insertions(+), 15 deletions(-) diff --git a/instructor/real_data/evogan_instructor.py b/instructor/real_data/evogan_instructor.py index f8aa1211..eef84bef 100644 --- a/instructor/real_data/evogan_instructor.py +++ b/instructor/real_data/evogan_instructor.py @@ -137,11 +137,6 @@ def pretrain_generator(self, epochs): self.log.info( '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) - eval_samples = self.gen.sample(20, 20) - gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - for sample in gen_tokens: - self.log.info(' '.join(sample)) - if cfg.if_save and not cfg.if_test: self._save('MLE', epoch) else: diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 3f4b412f..1c6cb434 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -123,10 +123,6 @@ def _run(self): while self.one_more_batch_for_generator(generator_acc): generator_acc = self.generator_train_one_batch() - - self.log.info('\n'.join(self.gen.sample(20, 20))) - - # if (i + 1) % 10 == 0: if cfg.run_model == 'fixemgan': scores = self.cal_metrics(fmt_str=True) if cfg.run_model == 'cat_fixemgan': diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 8fd06282..861b04cd 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -241,7 +241,7 @@ def cal_metrics(self, fmt_str=False): if fmt_str: pp_metrics = [f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics] pp_metrics.append(f"Overal_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}") - return "\n".join(pp_metrics) + return "\n" + "\n".join(pp_metrics) return [metric.get_score() for metric in self.all_metrics] def cal_metrics_with_label(self, label_i, fmt_str=False): @@ -261,7 +261,7 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): if fmt_str: pp_metrics = [f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics] pp_metrics.append(f"Overal_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}") - return f"label: {label_i}\n" + "\n".join(pp_metrics) + return "\n" + f"label: {label_i}\n" + "\n".join(pp_metrics) return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): diff --git a/instructor/real_data/maligan_instructor.py b/instructor/real_data/maligan_instructor.py index de323e96..a27356f8 100644 --- a/instructor/real_data/maligan_instructor.py +++ b/instructor/real_data/maligan_instructor.py @@ -55,7 +55,7 @@ def _run(self): # ===ADVERSARIAL TRAINING=== self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s\n' % (self.cal_metrics(fmt_str=True))) + self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) @@ -89,7 +89,7 @@ def pretrain_generator(self, epochs): # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == 
epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f\n%s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) if cfg.if_save and not cfg.if_test: self._save('MLE', epoch) else: @@ -111,7 +111,7 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f\n%s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) def train_discriminator(self, d_step, d_epoch, phase='MLE'): """ diff --git a/metrics/bleu.py b/metrics/bleu.py index 9e04d1ea..432ea08a 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -62,7 +62,7 @@ def get_blue_for_single_gram(self, ngram): bleu = list() reference = self.get_reference() weight = tuple((1. / ngram for _ in range(ngram))) - for idx, hypothesis in enumerate(self.test_text[:self.sample_size], desc=self.name): + for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): bleu.append(self.cal_bleu(reference, hypothesis, weight)) return round(sum(bleu) / len(bleu), 3) From cac3ec0e7feca13032de3a9317ba1b1a29effd33 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sat, 25 Feb 2023 06:22:10 +0000 Subject: [PATCH 49/81] added weighted metrics to more representitive Overall score --- instructor/real_data/instructor.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 861b04cd..13db8e30 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -69,14 +69,22 @@ def __init__(self, opt): self.clas_opt = None # Metrics - self.bleu = BLEU('BLEU', weight=1, gram=3, if_use=cfg.use_bleu) - self.nll_gen = NLL('NLL_gen', weight=-1, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) - self.nll_div = NLL('NLL_div', weight=1, if_use=cfg.use_nll_div, gpu=cfg.CUDA) - self.self_bleu = BLEU('Self-BLEU', weight=-1, gram=3, if_use=cfg.use_self_bleu) + # bleu, more-bettter, changes in range 0.3 - 0.4, will have relatively high weight + self.bleu = BLEU('BLEU', weight=3, gram=3, if_use=cfg.use_bleu) + # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight + self.nll_gen = NLL('NLL_gen', weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) + # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight + self.nll_div = NLL('NLL_div', weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) + # self-bleu, less-bettter, changes in range 0.7 - 0.9, will have relatively high weight + self.self_bleu = BLEU('Self-BLEU', weight=-2, gram=3, if_use=cfg.use_self_bleu) + # class-acc, more-bettter, changes in range 0.7 - 1.0, moderate weight self.clas_acc = ACC(weight=1, if_use=cfg.use_clas_acc) - self.ioc = IOC(weight=-1, if_use=cfg.use_ioc, real_text=self.test_data.tokens) + # IOC, less-bettter, changes in range 0.8 - 2.0, smaller weight + self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.test_data.tokens) + # nll_oracle, less-bettter, changes in range -0.1 - 0.5, moderate weight self.nll_oracle = GPTNLL(weight=-1, if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) - self.ppl = PPL(self.train_data, self.test_data, weight=-1, n_gram=5, if_use=cfg.use_ppl) + # perplexity, less-bettter, changes in range 3 - 4, moderate weight + self.ppl = PPL(self.train_data, self.test_data, weight=0, n_gram=5, if_use=cfg.use_ppl) self.all_metrics = [self.bleu, 
self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.nll_oracle, self.ppl] def _run(self): From f54c174400d39665322bcded2e46204c6b6e134a Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 28 Feb 2023 17:42:08 +0000 Subject: [PATCH 50/81] fixed data loader for list of tockens --- utils/data_loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/data_loader.py b/utils/data_loader.py index a774b3b6..368ac35c 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -59,15 +59,15 @@ def __read_data__(self, samples): """ input: same as target, but start with start_letter. """ - # global all_data + all_data = None if isinstance(samples, torch.Tensor): # Tensor inp, target = self.prepare(samples) all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] elif isinstance(samples, str): # filename inp, target = self.load_data(samples) all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] - else: - all_data = None + elif isinstance(samples, list): # list of tockens + all_data = [{'input': i, 'target': t} for (i, t) in zip(samples[:-1], samples[1:])] return all_data def random_batch(self): From 12b8c28799e18adfbd68073bef78b09c4a1ac45d Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 28 Feb 2023 17:50:11 +0000 Subject: [PATCH 51/81] fixed data loader for list of tockens --- utils/data_loader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/utils/data_loader.py b/utils/data_loader.py index 368ac35c..29fea025 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -66,8 +66,11 @@ def __read_data__(self, samples): elif isinstance(samples, str): # filename inp, target = self.load_data(samples) all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] - elif isinstance(samples, list): # list of tockens - all_data = [{'input': i, 'target': t} for (i, t) in zip(samples[:-1], samples[1:])] + elif isinstance(samples, list): # list of tokens, required for generator NLL + all_data = [ + {'input': torch.zeros(1), 'target': torch.zeros(1)} + for (i, t) in zip(samples[:-1], samples[1:]) + ] return all_data def random_batch(self): From e433bf4a6fd9fd3b71d5afaa82ed369ea6dcf2b7 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Wed, 1 Mar 2023 10:56:15 +0000 Subject: [PATCH 52/81] some tweaks with metirc --- instructor/real_data/instructor.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 13db8e30..9fbc3d71 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -69,20 +69,20 @@ def __init__(self, opt): self.clas_opt = None # Metrics - # bleu, more-bettter, changes in range 0.3 - 0.4, will have relatively high weight + # bleu, more-better, changes in range 0.4 - 0.6, will have relatively high weight self.bleu = BLEU('BLEU', weight=3, gram=3, if_use=cfg.use_bleu) # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight self.nll_gen = NLL('NLL_gen', weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight self.nll_div = NLL('NLL_div', weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) # self-bleu, less-bettter, changes in range 0.7 - 0.9, will have relatively high weight - self.self_bleu = BLEU('Self-BLEU', weight=-2, gram=3, if_use=cfg.use_self_bleu) + self.self_bleu = BLEU('Self-BLEU', weight=-3, gram=3, if_use=cfg.use_self_bleu) # class-acc, more-bettter, changes 
in range 0.7 - 1.0, moderate weight self.clas_acc = ACC(weight=1, if_use=cfg.use_clas_acc) # IOC, less-bettter, changes in range 0.8 - 2.0, smaller weight self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.test_data.tokens) - # nll_oracle, less-bettter, changes in range -0.1 - 0.5, moderate weight - self.nll_oracle = GPTNLL(weight=-1, if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) + # nll_oracle, less-bettter, changes in range -0.1 - 0.6, moderate weight + self.nll_oracle = GPTNLL(weight=-2, if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) # perplexity, less-bettter, changes in range 3 - 4, moderate weight self.ppl = PPL(self.train_data, self.test_data, weight=0, n_gram=5, if_use=cfg.use_ppl) self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.nll_oracle, self.ppl] @@ -248,7 +248,8 @@ def cal_metrics(self, fmt_str=False): if fmt_str: pp_metrics = [f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics] - pp_metrics.append(f"Overal_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}") + # added magic number 3 to make overall score positive and more readable + pp_metrics.append(f"Overal_score: {3 + sum(metric.weight * metric.get_score() for metric in self.all_metrics)}") return "\n" + "\n".join(pp_metrics) return [metric.get_score() for metric in self.all_metrics] From 620dba102142aa8bc944b2291eb28b31a1391c28 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 2 Mar 2023 14:49:34 +0000 Subject: [PATCH 53/81] addin wandb logger for metrics --- instructor/real_data/instructor.py | 32 +++++++++++++++++++----------- main.py | 12 +++++++++++ metrics/dummy.py | 13 ++++++++++++ 3 files changed, 45 insertions(+), 12 deletions(-) create mode 100644 metrics/dummy.py diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 9fbc3d71..9c7f54d7 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -10,6 +10,7 @@ import numpy as np import torch import torch.nn as nn +import wandb import config as cfg from metrics.bleu import BLEU @@ -18,6 +19,7 @@ from metrics.gpt_nll import GPTNLL from metrics.nll import NLL from metrics.ppl import PPL +from metrics.dummy import Dummy from utils.cat_data_loader import CatClasDataIter from utils.data_loader import GenDataIter from utils.helpers import Signal, create_logger, get_fixed_temperature @@ -71,9 +73,9 @@ def __init__(self, opt): # Metrics # bleu, more-better, changes in range 0.4 - 0.6, will have relatively high weight self.bleu = BLEU('BLEU', weight=3, gram=3, if_use=cfg.use_bleu) - # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight + # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight (not in use) self.nll_gen = NLL('NLL_gen', weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) - # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight + # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight (not in use) self.nll_div = NLL('NLL_div', weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) # self-bleu, less-bettter, changes in range 0.7 - 0.9, will have relatively high weight self.self_bleu = BLEU('Self-BLEU', weight=-3, gram=3, if_use=cfg.use_self_bleu) @@ -82,10 +84,12 @@ def __init__(self, opt): # IOC, less-bettter, changes in range 0.8 - 2.0, smaller weight self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.test_data.tokens) # nll_oracle, less-bettter, changes in range -0.1 - 0.6, 
moderate weight - self.nll_oracle = GPTNLL(weight=-2, if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) - # perplexity, less-bettter, changes in range 3 - 4, moderate weight + self.nll_oracle = GPTNLL(weight=-3, if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) + # perplexity, less-bettter, changes in range 3 - 4, moderate weight (not in use) self.ppl = PPL(self.train_data, self.test_data, weight=0, n_gram=5, if_use=cfg.use_ppl) - self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.nll_oracle, self.ppl] + # dummy, add constant value to overall score + self.dummy = Dummy(weight=1, value=5, if_use=True) + self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.nll_oracle, self.ppl, self.dummy] def _run(self): print('Nothing to run in Basic Instructor!') @@ -246,11 +250,12 @@ def cal_metrics(self, fmt_str=False): self.ioc.reset(test_text=gen_tokens) self.nll_oracle.reset(test_text=gen_tokens) + metrics = {metric.name: metric.get_score() for metric in self.all_metrics} + metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + wandb.log(metrics) + if fmt_str: - pp_metrics = [f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics] - # added magic number 3 to make overall score positive and more readable - pp_metrics.append(f"Overal_score: {3 + sum(metric.weight * metric.get_score() for metric in self.all_metrics)}") - return "\n" + "\n".join(pp_metrics) + return "\n" + "\n".join([f"{name} = {score}" for name, score in metrics.items()]) return [metric.get_score() for metric in self.all_metrics] def cal_metrics_with_label(self, label_i, fmt_str=False): @@ -267,10 +272,13 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.clas_acc.reset(self.clas, clas_data.loader) self.ppl.reset(gen_tokens) + metrics = {"label_i": label_i}) + metrics.update({metric.name: metric.get_score() for metric in self.all_metrics}) + metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + wandb.log(metrics) + if fmt_str: - pp_metrics = [f"{metric.name} = {metric.get_score()}" for metric in self.all_metrics] - pp_metrics.append(f"Overal_score: {sum(metric.weight * metric.get_score() for metric in self.all_metrics)}") - return "\n" + f"label: {label_i}\n" + "\n".join(pp_metrics) + return "\n" + "\n".join([f"{name} = {score}" for name, score in metrics.items()]) return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): diff --git a/main.py b/main.py index e535374b..15e97abc 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ import argparse import torch import numpy as np +import wandb import config as cfg from utils.text_process import load_test_dict, text_process @@ -191,8 +192,19 @@ def program_config(parser): 'cat_fixemgan': FixemGANInstructor } + + # start a new wandb run to track this script + wandb.init( + # set the wandb project where this run will be logged + project="TextGAN", + # track hyperparameters and run metadata + config=vars(opt) + ) + inst = instruction_dict[cfg.run_model](opt) if not cfg.if_test: inst._run() else: inst._test() + + wandb.finish() diff --git a/metrics/dummy.py b/metrics/dummy.py new file mode 100644 index 00000000..30a82a68 --- /dev/null +++ b/metrics/dummy.py @@ -0,0 +1,13 @@ +from metrics.basic import Metrics + + +class Dummy(Metrics): + """ + Dummy score to make Overal score positive and easy to read + """ + def __init__(self, 
name=None, weight=1, value=5, if_use=True): + super(Dummy, self).__init__('Dummy', weight, if_use) + self.value = 5 + + def calculate_metric(self): + return self.value From f09a684500f4e9cac6f889401a6556490b6c14c2 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Mar 2023 09:44:16 +0000 Subject: [PATCH 54/81] refacotoring metrics --- metrics/basic.py | 11 ++++++----- metrics/bleu.py | 3 +-- metrics/clas_acc.py | 3 +-- metrics/gpt_nll.py | 3 +-- metrics/ioc.py | 3 +-- metrics/nll.py | 3 +-- metrics/ppl.py | 3 +-- 7 files changed, 12 insertions(+), 17 deletions(-) diff --git a/metrics/basic.py b/metrics/basic.py index b5a4b8af..6f69b195 100644 --- a/metrics/basic.py +++ b/metrics/basic.py @@ -31,13 +31,14 @@ def get_score(self): self.metric_value_with_current_state = self.calculate_metric() return self.metric_value_with_current_state - @abstractmethod + def reset(*args, **kwargs): + self.metric_value_with_current_state = None + self._reset(*args, **kwargs) + + @abstractmethod def calculate_metric(self): pass @abstractmethod - def reset(self): - pass - def _reset(self): - self.metric_value_with_current_state = None + pass diff --git a/metrics/bleu.py b/metrics/bleu.py index 432ea08a..5a720aba 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -35,8 +35,7 @@ def __init__(self, name=None, weight=1, test_text=None, real_text=None, gram=3, self.sample_size = 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 self.portion = portion # how many portions to use in the evaluation, default to use the whole test dataset - def reset(self, test_text=None, real_text=None): - self._reset() + def _reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text else self.test_text self.real_text = real_text if real_text else self.real_text diff --git a/metrics/clas_acc.py b/metrics/clas_acc.py index c2b5c229..b04e1b80 100644 --- a/metrics/clas_acc.py +++ b/metrics/clas_acc.py @@ -21,8 +21,7 @@ def __init__(self, weight, if_use=True, gpu=True): self.data_loader = None self.gpu = gpu - def reset(self, model=None, data_loader=None): - self._reset() + def _reset(self, model=None, data_loader=None): self.model = model self.data_loader = data_loader diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index 33acb33c..a3a254f1 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -27,8 +27,7 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru print(f'dataset NLL based on GPT2 is {self.real_text_nll}') print('GPT2 as oracle metric will be calculated relative to this value') - def reset(self, test_text=None, real_text=None): - self._reset() + def _reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text else self.test_text self.real_text_nll = self.calcualte_NLL(real_text) if real_text else self.real_text_nll diff --git a/metrics/ioc.py b/metrics/ioc.py index eb995f3e..a188cd8c 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -19,8 +19,7 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.reference = None self.is_first = True - def reset(self, test_text=None, real_text=None): - self._reset() + def _reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text else self.test_text self.real_text_ioc = self.get_ioc(real_text) if real_text else self.real_text_ioc diff --git a/metrics/nll.py b/metrics/nll.py index f8323ed2..c0874ac4 100644 --- a/metrics/nll.py +++ b/metrics/nll.py @@ -34,8 +34,7 @@ def calculate_metric(self): self.criterion, 
self.gpu) return self.cal_nll(self.model, self.data_loader, self.criterion, self.gpu) - def reset(self, model=None, data_loader=None, label_i=None, leak_dis=None): - self._reset() + def _reset(self, model=None, data_loader=None, label_i=None, leak_dis=None): self.model = model self.data_loader = data_loader self.label_i = label_i diff --git a/metrics/ppl.py b/metrics/ppl.py index 16e24b6f..7cafabdf 100644 --- a/metrics/ppl.py +++ b/metrics/ppl.py @@ -39,8 +39,7 @@ def __init__(self, train_data, test_data, weight, n_gram=5, if_use=False): self.train_data = train_data self.test_data = test_data - def reset(self, gen_tokens=None): - self._reset() + def _reset(self, gen_tokens=None): self.gen_tokens = gen_tokens def calculate_metric(self): From 7d87b0f90801a19e6077521a0a0809077f99bcea Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Mar 2023 09:44:27 +0000 Subject: [PATCH 55/81] typo --- instructor/real_data/instructor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 9c7f54d7..4f298e78 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -272,7 +272,7 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.clas_acc.reset(self.clas, clas_data.loader) self.ppl.reset(gen_tokens) - metrics = {"label_i": label_i}) + metrics = {"label_i": label_i} metrics.update({metric.name: metric.get_score() for metric in self.all_metrics}) metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) wandb.log(metrics) From 3772f069f09798cfde685c02698d9acaaea1f820 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Mar 2023 12:53:02 +0000 Subject: [PATCH 56/81] solving typos --- instructor/oracle_data/catgan_instructor.py | 2 +- instructor/oracle_data/dgsan_instructor.py | 2 +- instructor/oracle_data/fixem_instructor.py | 6 ------ instructor/oracle_data/instructor.py | 2 +- instructor/real_data/instructor.py | 2 ++ main.py | 1 + models/{generators => }/Oracle.py | 0 models/generators/FixemGAN_G.py | 3 +-- utils/cat_data_loader.py | 7 +++++-- 9 files changed, 12 insertions(+), 13 deletions(-) rename models/{generators => }/Oracle.py (100%) diff --git a/instructor/oracle_data/catgan_instructor.py b/instructor/oracle_data/catgan_instructor.py index efc2ae60..04d051b9 100644 --- a/instructor/oracle_data/catgan_instructor.py +++ b/instructor/oracle_data/catgan_instructor.py @@ -19,7 +19,7 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor from metrics.nll import NLL -from models.descriminators.CatGAN_D import CatGAN_D +from models.discriminators.CatGAN_D import CatGAN_D from models.generators.CatGAN_G import CatGAN_G from models.Oracle import Oracle from utils.cat_data_loader import CatGenDataIter diff --git a/instructor/oracle_data/dgsan_instructor.py b/instructor/oracle_data/dgsan_instructor.py index 8a810eb5..56d8a1aa 100644 --- a/instructor/oracle_data/dgsan_instructor.py +++ b/instructor/oracle_data/dgsan_instructor.py @@ -16,7 +16,7 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from models.generatos.DGSAN_G import DGSAN_G +from models.generators.DGSAN_G import DGSAN_G from utils.data_loader import GenDataIter from utils.helpers import create_oracle diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index 239eb5e7..dafb5ec5 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ 
b/instructor/oracle_data/fixem_instructor.py @@ -24,12 +24,6 @@ from models.generators.FixemGAN_G import Generator from models.discriminators.FixemGAN_D import Discriminator - -# TO DO: -# 4. save? or save each 10 epochs -# 10. cat_oracle -# 12. class accuracy - # afterwards: # check target real/fake to be right (Uniform or const) # random data portion generator - data supplier sample from randomint diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index 452626c5..0d1f36b9 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -14,7 +14,7 @@ import config as cfg from metrics.nll import NLL -from models.generators.Oracle import Oracle +from models.Oracle import Oracle from utils.data_loader import GenDataIter from utils.data_utils import create_multi_oracle from utils.helpers import Signal, create_logger, create_oracle, get_fixed_temperature diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 4f298e78..d1f730e5 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -271,6 +271,8 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) self.clas_acc.reset(self.clas, clas_data.loader) self.ppl.reset(gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.nll_oracle.reset(test_text=gen_tokens) metrics = {"label_i": label_i} metrics.update({metric.name: metric.get_score() for metric in self.all_metrics}) diff --git a/main.py b/main.py index 15e97abc..cae49a51 100644 --- a/main.py +++ b/main.py @@ -131,6 +131,7 @@ def program_config(parser): if __name__ == '__main__': #seed everything torch.manual_seed(0) + torch.use_deterministic_algorithms(True) random.seed(0) np.random.seed(0) diff --git a/models/generators/Oracle.py b/models/Oracle.py similarity index 100% rename from models/generators/Oracle.py rename to models/Oracle.py diff --git a/models/generators/FixemGAN_G.py b/models/generators/FixemGAN_G.py index eb3c0b4f..c83dc7c1 100644 --- a/models/generators/FixemGAN_G.py +++ b/models/generators/FixemGAN_G.py @@ -129,8 +129,7 @@ def __init__(self, complexity, noise_size, w2v, w2v_embedding_size): def forward(self, noise, target_labels): target_labels = torch.nn.functional.one_hot(target_labels, num_classes=cfg.k_label) - x = self.main([noise, target_labels]) - return x + return self.main([noise, target_labels]) def sample(self, num_samples, batch_size, label_i = 'random', start_letter=cfg.start_letter): noise = create_noise(num_samples, self.noise_size, cfg.k_label) diff --git a/utils/cat_data_loader.py b/utils/cat_data_loader.py index 63bd803c..07721db5 100644 --- a/utils/cat_data_loader.py +++ b/utils/cat_data_loader.py @@ -4,7 +4,7 @@ # @FileName : cat_data_loader.py # @Time : Created at 2019-05-31 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import random @@ -135,7 +135,10 @@ def prepare(samples_list, given_target=None, detach=True, gpu=False): - inp: sentences - target: label index, 0-label_0, 1-label_1, ..., k-label_k """ - if len(samples_list) == 1 and given_target is not None: + if type(samples_list[0][0][0]) == str: # directly generated text + inp = torch.zeros(1) + target = torch.zeros(1) + elif len(samples_list) == 1 and given_target is not None: inp = samples_list[0] if detach: inp = inp.detach() From 9d122e2626dc6638bdd896baa210f6efac91aa46 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Mar 2023 12:55:06 +0000 Subject: [PATCH 57/81] solving error in oracle model --- instructor/real_data/instructor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index d1f730e5..e298b6ff 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -44,6 +44,8 @@ def __init__(self, opt): self.word2idx_dict, self.idx2word_dict = {}, {} # Dataloader + self.train_data = None + self.test_data = None try: self.train_data = GenDataIter(cfg.train_data) self.test_data = GenDataIter(cfg.test_data, if_test_data=True) From beac393aaf319b2213a4c93caa8389430139c503 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Mar 2023 12:57:09 +0000 Subject: [PATCH 58/81] adding logging to oracle instructor --- instructor/oracle_data/instructor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index 0d1f36b9..1b9c1bdf 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -11,6 +11,7 @@ import os import torch import torch.nn as nn +import wandb import config as cfg from metrics.nll import NLL @@ -192,6 +193,12 @@ def cal_metrics(self, fmt_str=False): self.nll_gen.reset(self.gen, self.oracle_data.loader) self.nll_div.reset(self.gen, gen_data.loader) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.nll_oracle.reset(test_text=gen_tokens) + + metrics = {metric.name: metric.get_score() for metric in self.all_metrics} + metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + wandb.log(metrics) if fmt_str: return ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) @@ -208,6 +215,13 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.nll_gen.reset(self.gen, self.oracle_data_list[label_i].loader, label_i) self.nll_div.reset(self.gen, gen_data.loader, label_i) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) + self.ioc.reset(test_text=gen_tokens) + self.nll_oracle.reset(test_text=gen_tokens) + + metrics = {"label_i": label_i} + metrics.update({metric.name: metric.get_score() for metric in self.all_metrics}) + metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + wandb.log(metrics) if fmt_str: return f'label: {label_i}' + ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) From dfc7639b82ed6c485e125cd50e662e6a538da4bd Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Mar 2023 13:03:01 +0000 Subject: [PATCH 59/81] err in metrics --- instructor/real_data/instructor.py | 4 ++-- metrics/basic.py | 2 +- metrics/gpt_nll.py | 4 ++-- metrics/ioc.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git 
a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index e298b6ff..87233a1c 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -84,9 +84,9 @@ def __init__(self, opt): # class-acc, more-bettter, changes in range 0.7 - 1.0, moderate weight self.clas_acc = ACC(weight=1, if_use=cfg.use_clas_acc) # IOC, less-bettter, changes in range 0.8 - 2.0, smaller weight - self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.test_data.tokens) + self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.test_data) # nll_oracle, less-bettter, changes in range -0.1 - 0.6, moderate weight - self.nll_oracle = GPTNLL(weight=-3, if_use=cfg.use_nll_oracle, real_text=self.test_data.tokens) + self.nll_oracle = GPTNLL(weight=-3, if_use=cfg.use_nll_oracle, real_text=self.test_data) # perplexity, less-bettter, changes in range 3 - 4, moderate weight (not in use) self.ppl = PPL(self.train_data, self.test_data, weight=0, n_gram=5, if_use=cfg.use_ppl) # dummy, add constant value to overall score diff --git a/metrics/basic.py b/metrics/basic.py index 6f69b195..a72a8488 100644 --- a/metrics/basic.py +++ b/metrics/basic.py @@ -35,7 +35,7 @@ def reset(*args, **kwargs): self.metric_value_with_current_state = None self._reset(*args, **kwargs) - @abstractmethod + @abstractmethod def calculate_metric(self): pass diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index a3a254f1..ebf1383b 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -23,13 +23,13 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") self.model = GPT2LMHeadModel.from_pretrained("gpt2") print('Calculating dataset NLL') - self.real_text_nll = self.calcualte_NLL(random.sample(real_text, 500)) if real_text else None + self.real_text_nll = self.calcualte_NLL(random.sample(real_text.tokens, 500)) if real_text else None print(f'dataset NLL based on GPT2 is {self.real_text_nll}') print('GPT2 as oracle metric will be calculated relative to this value') def _reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text else self.test_text - self.real_text_nll = self.calcualte_NLL(real_text) if real_text else self.real_text_nll + self.real_text_nll = self.calcualte_NLL(real_text.tokens) if real_text else self.real_text_nll def calculate_metric(self): """Get gpt2 NLL score difference with dataset NLL.""" diff --git a/metrics/ioc.py b/metrics/ioc.py index a188cd8c..f69bcae6 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -14,14 +14,14 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.if_use = if_use self.test_text = test_text - self.real_text_ioc = self.calculate_ioc(real_text) if real_text else None + self.real_text_ioc = self.calculate_ioc(real_text.tokens) if real_text else None print(f'Dataset Index of coincidence: {self.real_text_ioc}') self.reference = None self.is_first = True def _reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text else self.test_text - self.real_text_ioc = self.get_ioc(real_text) if real_text else self.real_text_ioc + self.real_text_ioc = self.get_ioc(real_text.tokens) if real_text else self.real_text_ioc def calculate_metric(self): return self.calculate_ioc(self.test_text) / self.real_text_ioc From b252e22549294cb25e0e212071998981bf0c22fc Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Mar 2023 14:39:01 +0000 Subject: [PATCH 60/81] workaround oracle generator --- 
config.py | 4 -- instructor/oracle_data/fixem_instructor.py | 6 +-- instructor/oracle_data/instructor.py | 24 ++++++++-- instructor/real_data/instructor.py | 10 ++-- metrics/gpt_nll.py | 7 +-- metrics/ioc.py | 3 +- models/Oracle.py | 5 ++ run/run_fixem.py | 10 ++-- utils/data_loader.py | 55 +++++++++------------- utils/helpers.py | 7 --- 10 files changed, 62 insertions(+), 69 deletions(-) diff --git a/config.py b/config.py index 151f5b9d..51dec6c6 100644 --- a/config.py +++ b/config.py @@ -204,10 +204,6 @@ tips = '' -if samples_num == 5000 or samples_num == 2000: - assert 'c' in run_model, 'warning: samples_num={}, run_model={}'.format(samples_num, run_model) - - # Init settings according to parser def init_param(opt): global run_model, model_type, loss_type, CUDA, device, data_shuffle, samples_num, vocab_size, \ diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index dafb5ec5..dba43bcc 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -28,7 +28,7 @@ # check target real/fake to be right (Uniform or const) # random data portion generator - data supplier sample from randomint -class FixemGANInstructor(RealDataFixemGANInstructor, BasicInstructor): +class FixemGANInstructor(BasicInstructor, RealDataFixemGANInstructor): def __init__(self, opt): self.oracle = Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len,cfg.padding_idx, gpu=cfg.CUDA) if cfg.oracle_pretrain: @@ -42,10 +42,6 @@ def __init__(self, opt): super().__init__(opt) - - # Metrics - self.nll_oracle = NLL('NLL_oracle', if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) - def build_embedding(self): # train embedding on available dataset or oracle self.log.info(f"Didn't find embeddings in {cfg.pretrain_embedding_path}") diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index 1b9c1bdf..a60732af 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -14,7 +14,13 @@ import wandb import config as cfg +from metrics.bleu import BLEU +from metrics.clas_acc import ACC +from metrics.ioc import IOC +from metrics.gpt_nll import GPTNLL from metrics.nll import NLL +from metrics.ppl import PPL +from metrics.dummy import Dummy from models.Oracle import Oracle from utils.data_loader import GenDataIter from utils.data_utils import create_multi_oracle @@ -54,11 +60,19 @@ def __init__(self, opt): self.dis_criterion = nn.CrossEntropyLoss() # Metrics - self.nll_oracle = NLL('NLL_oracle', if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) - self.nll_gen = NLL('NLL_gen', if_use=cfg.use_nll_gen, gpu=cfg.CUDA) - self.nll_div = NLL('NLL_div', if_use=cfg.use_nll_div, gpu=cfg.CUDA) - self.self_bleu = BLEU('Self-BLEU', gram=3, if_use=cfg.use_self_bleu) - self.all_metrics = [self.nll_oracle, self.nll_gen, self.nll_div, self.self_bleu] + # nll_oracle, less-better, changes in range -0.1 - 0.6, moderate weight + self.nll_oracle = NLL('NLL_oracle', weight=1, if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) + # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight (not in use) + self.nll_gen = NLL('NLL_gen', weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) + # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight (not in use) + self.nll_div = NLL('NLL_div', weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) + # self-bleu, less-better, changes in range 0.7 - 0.9, will have relatively high weight + self.self_bleu = BLEU('Self-BLEU', weight=-3, gram=3, if_use=cfg.use_self_bleu) + # IOC, 
less-better, changes in range 0.8 - 2.0, smaller weight + self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.oracle_data) + # dummy, add constant value to overall score + self.dummy = Dummy(weight=1, value=5, if_use=True) + self.all_metrics = [self.nll_oracle, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.dummy] def _run(self): print('Nothing to run in Basic Instructor!') diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 87233a1c..923ebc76 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -79,15 +79,15 @@ def __init__(self, opt): self.nll_gen = NLL('NLL_gen', weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight (not in use) self.nll_div = NLL('NLL_div', weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) - # self-bleu, less-bettter, changes in range 0.7 - 0.9, will have relatively high weight + # self-bleu, less-better, changes in range 0.7 - 0.9, will have relatively high weight self.self_bleu = BLEU('Self-BLEU', weight=-3, gram=3, if_use=cfg.use_self_bleu) - # class-acc, more-bettter, changes in range 0.7 - 1.0, moderate weight + # class-acc, more-better, changes in range 0.7 - 1.0, moderate weight self.clas_acc = ACC(weight=1, if_use=cfg.use_clas_acc) - # IOC, less-bettter, changes in range 0.8 - 2.0, smaller weight + # IOC, less-better, changes in range 0.8 - 2.0, smaller weight self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.test_data) - # nll_oracle, less-bettter, changes in range -0.1 - 0.6, moderate weight + # nll_oracle, less-better, changes in range -0.1 - 0.6, moderate weight self.nll_oracle = GPTNLL(weight=-3, if_use=cfg.use_nll_oracle, real_text=self.test_data) - # perplexity, less-bettter, changes in range 3 - 4, moderate weight (not in use) + # perplexity, less-better, changes in range 3 - 4, moderate weight (not in use) self.ppl = PPL(self.train_data, self.test_data, weight=0, n_gram=5, if_use=cfg.use_ppl) # dummy, add constant value to overall score self.dummy = Dummy(weight=1, value=5, if_use=True) diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index ebf1383b..6d2b8b28 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -24,8 +24,9 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.model = GPT2LMHeadModel.from_pretrained("gpt2") print('Calculating dataset NLL') self.real_text_nll = self.calcualte_NLL(random.sample(real_text.tokens, 500)) if real_text else None - print(f'dataset NLL based on GPT2 is {self.real_text_nll}') - print('GPT2 as oracle metric will be calculated relative to this value') + if real_text_nll: + print(f'dataset NLL based on GPT2 is {self.real_text_nll}') + print('GPT2 as oracle metric will be calculated relative to this value') def _reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text else self.test_text @@ -36,7 +37,7 @@ def calculate_metric(self): return self.calcualte_NLL(self.test_text) - self.real_text_nll def calcualte_NLL(self, messages): - if type(messages[0]) == list: #we received list of tokens + if type(messages[0]) == list: # we received list of tokens messages = [' '.join(msg) for msg in messages] all_logits = [] diff --git a/metrics/ioc.py b/metrics/ioc.py index f69bcae6..3a5e0ed8 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -15,7 +15,8 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.if_use = if_use self.test_text = 
test_text self.real_text_ioc = self.calculate_ioc(real_text.tokens) if real_text else None - print(f'Dataset Index of coincidence: {self.real_text_ioc}') + if real_text_ioc: + print(f'Dataset Index of coincidence: {self.real_text_ioc}') self.reference = None self.is_first = True diff --git a/models/Oracle.py b/models/Oracle.py index d637559a..901117c7 100644 --- a/models/Oracle.py +++ b/models/Oracle.py @@ -19,3 +19,8 @@ def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_i # initialise oracle network with N(0,1) # otherwise variance of initialisation is very small => high NLL for loader sampled from the same model self.init_oracle() + + def init_oracle(self): + for param in self.parameters(): + if param.requires_grad: + torch.nn.init.normal_(param, mean=0, std=1) diff --git a/run/run_fixem.py b/run/run_fixem.py index bc20c2ed..1d57d84f 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -31,15 +31,16 @@ scriptname = 'main.py' # ===Program=== -# EvoGAN: General text generation model +# FixemGAN: General text generation model if_test = int(False) -run_model = ['fixemgan', 'cat_fixemgan', 'fixemgan', 'cat_fixemgan', 'fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan', 'fixemgan'] +run_model = ['fixemgan', 'cat_fixemgan', 'fixemgan', 'cat_fixemgan', 'fixemgan', 'fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan'] k_label = 2 CUDA = int(True) batch_size = 32 noise_size = 1000 max_epochs = 20 batches_per_epoch = 200 +samples_num = 100_000 # training samples size tips = '{} experiments' # ===Oracle or Real=== @@ -52,13 +53,10 @@ w2v_samples_num = 5_000_000 vocab_size = 5000 -# ===CatGAN Param=== -loss_type = 'fixem' -oracle_train_samples_num = 100_000 - # === Basic Param === data_shuffle = int(False) model_type = 'fixem' +loss_type = 'fixem' gen_init = 'truncated_normal' dis_init = 'uniform' batch_size = 64 diff --git a/utils/data_loader.py b/utils/data_loader.py index 29fea025..7ae5a2ce 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -37,10 +37,9 @@ def __len__(self): class GenDataIter: def __init__(self, samples, if_test_data=False, shuffle=None): - self.batch_size = cfg.batch_size - self.max_seq_len = cfg.max_seq_len - self.start_letter = cfg.start_letter + self.samples = samples self.shuffle = cfg.data_shuffle if not shuffle else shuffle + if cfg.if_real_data: self.word2idx_dict, self.idx2word_dict = load_dict(cfg.dataset) if if_test_data: # used for the classifier @@ -48,7 +47,7 @@ def __init__(self, samples, if_test_data=False, shuffle=None): self.loader = DataLoader( dataset=GANDataset(self.__read_data__(samples)), - batch_size=self.batch_size, + batch_size=cfg.batch_size, shuffle=self.shuffle, drop_last=True) @@ -59,18 +58,12 @@ def __read_data__(self, samples): """ input: same as target, but start with start_letter. 
""" - all_data = None - if isinstance(samples, torch.Tensor): # Tensor - inp, target = self.prepare(samples) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] + if isinstance(samples, list): # list of strings + return [] elif isinstance(samples, str): # filename - inp, target = self.load_data(samples) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] - elif isinstance(samples, list): # list of tokens, required for generator NLL - all_data = [ - {'input': torch.zeros(1), 'target': torch.zeros(1)} - for (i, t) in zip(samples[:-1], samples[1:]) - ] + samples = self.load_file_indexed(samples) + inp, target = self.prepare_for_NLL(samples) + all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] return all_data def random_batch(self): @@ -81,8 +74,16 @@ def random_batch(self): def _all_data_(self, col): return torch.cat([data[col].unsqueeze(0) for data in self.loader.dataset.data], 0) + @property + def tokens(self): + """Returns samples in form of list of tensors, if input tensor, + or list of tokens in case if input string.""" + if type(self.samples[0]) == str: # we have list of strings + return [smpl.split() for smpl in self.samples] + return samples + @staticmethod - def prepare(samples, gpu=False): + def prepare_for_NLL(samples, gpu=False): """Add start_letter to samples as inp, target same as samples""" inp = torch.zeros(samples.size()).long() target = samples @@ -93,30 +94,23 @@ def prepare(samples, gpu=False): return inp.cuda(), target.cuda() return inp, target - def load_data(self, filename): + def load_file_indexed(self, filename): """Load real data from local file""" - self.tokens = get_tokenlized(filename) - samples_index = tokens_to_tensor(self.tokens, self.word2idx_dict) - return self.prepare(samples_index) + tokens = get_tokenlized(filename) + return tokens_to_tensor(tokens, self.word2idx_dict) class DisDataIter: def __init__(self, pos_samples, neg_samples, shuffle=None): - self.batch_size = cfg.batch_size - self.max_seq_len = cfg.max_seq_len - self.start_letter = cfg.start_letter self.shuffle = cfg.data_shuffle if not shuffle else shuffle self.loader = DataLoader( dataset=GANDataset(self.__read_data__(pos_samples, neg_samples)), - batch_size=self.batch_size, + batch_size=cfg.batch_size, shuffle=self.shuffle, drop_last=True) def __read_data__(self, pos_samples, neg_samples): - """ - input: same as target, but start with start_letter. 
- """ inp, target = self.prepare(pos_samples, neg_samples) all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] return all_data @@ -148,16 +142,11 @@ def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): for label, tokens in zip(labels, tokenized) if all(token in w2v.wv for token in tokens) ]) - self.labels = torch.tensor(labels, dtype=int) - self.tokenized = np.array(tokenized) - self.batches_per_epoch = batches_per_epoch self.batch_size = batch_size - self.w2v = w2v - self.texts = set(" ".join(tokens[-cfg.target_len:]) for tokens in tokenized) print('dataset random texts examples\n', '\n'.join([txt for txt in self.texts][:5])) @@ -200,5 +189,5 @@ def __iter__(self): def __len__(self): return len(self.tokenized) - def is_this_message_in_dataset(self, text): + def is_message_in_dataset(self, text): return text in self.texts diff --git a/utils/helpers.py b/utils/helpers.py index bd681bb7..7183e2c0 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -86,13 +86,6 @@ def create_oracle(): f.write(" ".join(str(int(idx)) for idx in sample)) f.write("\n") - # moderate for training for W2V - train_samples = oracle.sample(cfg.oracle_train_samples_num, 4 * cfg.batch_size) - with open(cfg.train_data, 'w') as f: - for sample in tqdm(train_samples): - f.write(" ".join(str(int(idx)) for idx in sample)) - f.write("\n") - oracle_data = GenDataIter(big_samples) mle_criterion = nn.NLLLoss() groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) From 41bfc4427958a00107041352206d26d741d50710 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Fri, 3 Mar 2023 14:57:22 +0000 Subject: [PATCH 61/81] workaround oracle metrics --- metrics/gpt_nll.py | 2 +- metrics/ioc.py | 2 +- run/run_fixem.py | 1 - utils/data_loader.py | 21 +++++++++++---------- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index 6d2b8b28..df12c6a3 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -24,7 +24,7 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.model = GPT2LMHeadModel.from_pretrained("gpt2") print('Calculating dataset NLL') self.real_text_nll = self.calcualte_NLL(random.sample(real_text.tokens, 500)) if real_text else None - if real_text_nll: + if self.real_text_nll: print(f'dataset NLL based on GPT2 is {self.real_text_nll}') print('GPT2 as oracle metric will be calculated relative to this value') diff --git a/metrics/ioc.py b/metrics/ioc.py index 3a5e0ed8..ad641e17 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -15,7 +15,7 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.if_use = if_use self.test_text = test_text self.real_text_ioc = self.calculate_ioc(real_text.tokens) if real_text else None - if real_text_ioc: + if self.real_text_ioc: print(f'Dataset Index of coincidence: {self.real_text_ioc}') self.reference = None self.is_first = True diff --git a/run/run_fixem.py b/run/run_fixem.py index 1d57d84f..b26cb7c2 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -104,7 +104,6 @@ '--batches_per_epoch', batches_per_epoch, '--noise_size', noise_size, '--target_len', target_len[job_id], - '--oracle_train_samples_num', oracle_train_samples_num, # Generator '--generator_complexity', generator_complexity, diff --git a/utils/data_loader.py b/utils/data_loader.py index 7ae5a2ce..fd7598ea 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -37,7 +37,11 @@ def __len__(self): class GenDataIter: def __init__(self, samples, 
if_test_data=False, shuffle=None): - self.samples = samples + if type(samples) == str: # we received filename + self.samples = get_tokenlized(samples) + else: + self.samples = samples + self.shuffle = cfg.data_shuffle if not shuffle else shuffle if cfg.if_real_data: @@ -46,7 +50,7 @@ def __init__(self, samples, if_test_data=False, shuffle=None): self.word2idx_dict, self.idx2word_dict = load_test_dict(cfg.dataset) self.loader = DataLoader( - dataset=GANDataset(self.__read_data__(samples)), + dataset=GANDataset(self.__read_data__(self.samples)), batch_size=cfg.batch_size, shuffle=self.shuffle, drop_last=True) @@ -58,10 +62,12 @@ def __read_data__(self, samples): """ input: same as target, but start with start_letter. """ - if isinstance(samples, list): # list of strings + if isinstance(samples[0], str): # list of strings + # we directly generated string, skip NLL return [] - elif isinstance(samples, str): # filename - samples = self.load_file_indexed(samples) + if isinstance(samples[0], list) and isinstance(samples[0], str): + # need to transform to indexes + samples = tokens_to_tensor(tokens, self.word2idx_dict) inp, target = self.prepare_for_NLL(samples) all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] return all_data @@ -94,11 +100,6 @@ def prepare_for_NLL(samples, gpu=False): return inp.cuda(), target.cuda() return inp, target - def load_file_indexed(self, filename): - """Load real data from local file""" - tokens = get_tokenlized(filename) - return tokens_to_tensor(tokens, self.word2idx_dict) - class DisDataIter: def __init__(self, pos_samples, neg_samples, shuffle=None): From 26071abbba4cbbe46b74d369baaf2d365d53232d Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Sat, 4 Mar 2023 07:57:12 +0000 Subject: [PATCH 62/81] cleaning for oracle metrics --- config.py | 6 ++--- instructor/oracle_data/instructor.py | 37 ++++++++++++++-------------- instructor/real_data/instructor.py | 13 +++++----- main.py | 2 -- metrics/basic.py | 2 +- metrics/bleu.py | 4 +-- metrics/gpt_nll.py | 4 +-- metrics/ioc.py | 8 ++++-- models/Oracle.py | 2 ++ models/generators/generator.py | 5 ---- run/run_catgan.py | 2 +- utils/data_loader.py | 13 ++++++---- utils/text_process.py | 8 ++++-- 13 files changed, 55 insertions(+), 51 deletions(-) diff --git a/config.py b/config.py index 51dec6c6..ba348f1b 100644 --- a/config.py +++ b/config.py @@ -46,7 +46,6 @@ noise_size = 1000 max_epochs = 20 target_len = 40 -oracle_train_samples_num = 100_000 # ===Embedding=== w2v_embedding_size = 100 @@ -74,7 +73,7 @@ temperature = 1 # ===Basic Train=== -samples_num = 100 # 10000, mr15: 2000, +samples_num = 10000 #, mr15: 2000, small_sample_num = 20 # used for self-blue MLE_train_epoch = 150 # SeqGAN-80, LeakGAN-8, RelGAN-150 PRE_clas_epoch = 10 @@ -219,7 +218,7 @@ def init_param(opt): multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, \ - generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num, oracle_train_samples_num + generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -252,7 +251,6 @@ def init_param(opt): noise_size = opt.noise_size max_epochs = opt.max_epochs target_len = opt.target_len - 
oracle_train_samples_num = opt.oracle_train_samples_num samples_num = opt.samples_num vocab_size = opt.vocab_size diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index a60732af..655d5e93 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -25,7 +25,7 @@ from utils.data_loader import GenDataIter from utils.data_utils import create_multi_oracle from utils.helpers import Signal, create_logger, create_oracle, get_fixed_temperature -from utils.text_process import write_tensor +from utils.text_process import write_tensor, tensor_to_tokens class BasicInstructor: @@ -61,7 +61,7 @@ def __init__(self, opt): # Metrics # nll_oracle, less-better, changes in range -0.1 - 0.6, moderate weight - self.nll_oracle = NLL('NLL_oracle', weight=1, if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) + self.nll_oracle = NLL('NLL_oracle', weight=-3, if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight (not in use) self.nll_gen = NLL('NLL_gen', weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight (not in use) @@ -182,15 +182,15 @@ def show_config(self): def sample_for_metrics(self): eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) gen_data = GenDataIter(eval_samples) - gen_tokens = eval_samples - gen_tokens_s = self.gen.sample(cfg.small_sample_num, 4 * cfg.batch_size) + gen_tokens = tensor_to_tokens(eval_samples) + gen_tokens_s = tensor_to_tokens(self.gen.sample(cfg.small_sample_num, 4 * cfg.batch_size)) return gen_data, gen_tokens, gen_tokens_s def sample_for_metrics_with_label(self, label_i): eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size, label_i=label_i) gen_data = GenDataIter(eval_samples) - gen_tokens = eval_samples - gen_tokens_s = self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i) + gen_tokens = tensor_to_tokens(eval_samples) + gen_tokens_s = tensor_to_tokens(self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i)) return gen_data, gen_tokens, gen_tokens_s def cal_metrics(self, fmt_str=False): @@ -200,7 +200,7 @@ def cal_metrics(self, fmt_str=False): """ with torch.no_grad(): # Prepare data for evaluation - gen_data, gen_tokens, gen_tokens_s = sample_for_metrics() + gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics() # Reset metrics self.nll_oracle.reset(self.oracle, gen_data.loader) @@ -208,21 +208,20 @@ def cal_metrics(self, fmt_str=False): self.nll_div.reset(self.gen, gen_data.loader) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) self.ioc.reset(test_text=gen_tokens) - self.nll_oracle.reset(test_text=gen_tokens) metrics = {metric.name: metric.get_score() for metric in self.all_metrics} metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) wandb.log(metrics) if fmt_str: - return ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) + return "\n" + "\n".join([f"{name} = {score}" for name, score in metrics.items()]) return [metric.get_score() for metric in self.all_metrics] def cal_metrics_with_label(self, label_i, fmt_str=False): assert type(label_i) == int, 'missing label' with torch.no_grad(): # Prepare data for evaluation - gen_data, gen_tokens, gen_tokens_s = sample_for_metrics_with_label() + gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics_with_label() # Reset metrics 
self.nll_oracle.reset(self.oracle_list[label_i], gen_data.loader, label_i) @@ -232,23 +231,23 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.ioc.reset(test_text=gen_tokens) self.nll_oracle.reset(test_text=gen_tokens) - metrics = {"label_i": label_i} - metrics.update({metric.name: metric.get_score() for metric in self.all_metrics}) - metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + metrics = {f"label {label_i}_{metric.name}": metric.get_score() for metric in self.all_metrics} + metrics.update({f"label {label_i} Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) wandb.log(metrics) if fmt_str: - return f'label: {label_i}' + ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) - return [metric.get_score() for metric in self.all_metrics] + return "\n" + "\n".join([f"{name} = {score}" for name, score in metrics.items()]) + return metrics def comb_metrics(self, fmt_str=False): all_scores = [self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label)] - all_scores = np.array(all_scores).T.tolist() # each row for each metric if fmt_str: - return ', '.join(['%s = %s' % (metric.name, score) - for (metric, score) in zip(self.all_metrics, all_scores)]) - return all_scores + return ', '.join([ + f'{name} = {[scores[name] for scores in all_scores]}' + for name in all_scores[0] + ]) + return [scores.values() for scores in all_scores] def _save(self, phase, epoch): """Save model state dict and generator's samples""" diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 923ebc76..ef7eb218 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -276,9 +276,8 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.ioc.reset(test_text=gen_tokens) self.nll_oracle.reset(test_text=gen_tokens) - metrics = {"label_i": label_i} - metrics.update({metric.name: metric.get_score() for metric in self.all_metrics}) - metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + metrics = {f"label {label_i}_{metric.name}": metric.get_score() for metric in self.all_metrics} + metrics.update({f"label {label_i} Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) wandb.log(metrics) if fmt_str: @@ -287,11 +286,13 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): def comb_metrics(self, fmt_str=False): all_scores = [self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label)] - all_scores = np.array(all_scores).T.tolist() # each row for each metric if fmt_str: - return ', '.join([f"{metric.name} = {score}" for (metric, score) in zip(self.all_metrics, all_scores)]) - return all_scores + return ', '.join([ + f'{name} = {[scores[name] for scores in all_scores]}' + for name in all_scores[0] + ]) + return [scores.values() for scores in all_scores] def _save(self, phase, epoch): """Save model state dict and generator's samples""" diff --git a/main.py b/main.py index cae49a51..b38473d6 100644 --- a/main.py +++ b/main.py @@ -52,7 +52,6 @@ def program_config(parser): parser.add_argument('--noise_size', default=cfg.noise_size, type=int) parser.add_argument('--max_epochs', default=cfg.max_epochs, type=int) parser.add_argument('--target_len', default=cfg.target_len, type=int) - parser.add_argument('--oracle_train_samples_num', default=cfg.oracle_train_samples_num, type=int) # Basic Train 
parser.add_argument('--samples_num', default=cfg.samples_num, type=int) @@ -131,7 +130,6 @@ def program_config(parser): if __name__ == '__main__': #seed everything torch.manual_seed(0) - torch.use_deterministic_algorithms(True) random.seed(0) np.random.seed(0) diff --git a/metrics/basic.py b/metrics/basic.py index a72a8488..6df50188 100644 --- a/metrics/basic.py +++ b/metrics/basic.py @@ -31,7 +31,7 @@ def get_score(self): self.metric_value_with_current_state = self.calculate_metric() return self.metric_value_with_current_state - def reset(*args, **kwargs): + def reset(self, *args, **kwargs): self.metric_value_with_current_state = None self._reset(*args, **kwargs) diff --git a/metrics/bleu.py b/metrics/bleu.py index 5a720aba..d0ccc290 100644 --- a/metrics/bleu.py +++ b/metrics/bleu.py @@ -36,8 +36,8 @@ def __init__(self, name=None, weight=1, test_text=None, real_text=None, gram=3, self.portion = portion # how many portions to use in the evaluation, default to use the whole test dataset def _reset(self, test_text=None, real_text=None): - self.test_text = test_text if test_text else self.test_text - self.real_text = real_text if real_text else self.real_text + self.test_text = test_text if test_text is not None else self.test_text + self.real_text = real_text if real_text is not None else self.real_text def get_reference(self): reference = self.real_text.copy() diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index df12c6a3..64de2404 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -29,8 +29,8 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru print('GPT2 as oracle metric will be calculated relative to this value') def _reset(self, test_text=None, real_text=None): - self.test_text = test_text if test_text else self.test_text - self.real_text_nll = self.calcualte_NLL(real_text.tokens) if real_text else self.real_text_nll + self.test_text = test_text if test_text is not None else self.test_text + self.real_text_nll = self.calcualte_NLL(real_text.tokens) if real_text is not None else self.real_text_nll def calculate_metric(self): """Get gpt2 NLL score difference with dataset NLL.""" diff --git a/metrics/ioc.py b/metrics/ioc.py index ad641e17..7e95360c 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -21,14 +21,18 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.is_first = True def _reset(self, test_text=None, real_text=None): - self.test_text = test_text if test_text else self.test_text - self.real_text_ioc = self.get_ioc(real_text.tokens) if real_text else self.real_text_ioc + self.test_text = test_text if test_text is not None else self.test_text + self.real_text_ioc = self.get_ioc(real_text.tokens) if real_text is not None else self.real_text_ioc def calculate_metric(self): return self.calculate_ioc(self.test_text) / self.real_text_ioc def calculate_ioc(self, tokenized_text): """Index Of coincidence: probability of 2 random tokens in text to equal.""" + tokenized_text = [ + [str(token) for token in tokens] + for tokens in tokenized_text + ] tokens = list(chain(*tokenized_text)) counts = Counter(tokens) total = sum(ni * (ni - 1) for ni in counts.values()) diff --git a/models/Oracle.py b/models/Oracle.py index 901117c7..4c3cbb8b 100644 --- a/models/Oracle.py +++ b/models/Oracle.py @@ -7,6 +7,8 @@ # @Description : # Copyrights (C) 2018. All Rights Reserved. 
+import torch + from models.generators.generator import LSTMGenerator diff --git a/models/generators/generator.py b/models/generators/generator.py index e9b45fbd..8db32e26 100644 --- a/models/generators/generator.py +++ b/models/generators/generator.py @@ -94,11 +94,6 @@ def init_params(self): elif cfg.gen_init == 'truncated_normal': truncated_normal_(param, std=stddev) - def init_oracle(self): - for param in self.parameters(): - if param.requires_grad: - torch.nn.init.normal_(param, mean=0, std=1) - def init_hidden(self, batch_size=cfg.batch_size): h = torch.zeros(1, batch_size, self.hidden_dim) c = torch.zeros(1, batch_size, self.hidden_dim) diff --git a/run/run_catgan.py b/run/run_catgan.py index 69d4f08f..e86b9619 100644 --- a/run/run_catgan.py +++ b/run/run_catgan.py @@ -34,7 +34,7 @@ # CatGAN: Catgory text generation model # EvoGAN: General text generation model if_test = int(False) -run_model = ['evogan', 'catgan', 'catgan', 'catgan', 'evogan', 'evogan'] +run_model = ['evogan', 'catgan', 'catgan', 'evogan', 'evogan', 'evogan'] k_label = 2 CUDA = int(True) ora_pretrain = int(True) diff --git a/utils/data_loader.py b/utils/data_loader.py index fd7598ea..030f959c 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -62,12 +62,15 @@ def __read_data__(self, samples): """ input: same as target, but start with start_letter. """ - if isinstance(samples[0], str): # list of strings + if isinstance(samples[0], str) or isinstance(samples[0][0], str): # list of strings # we directly generated string, skip NLL - return [] - if isinstance(samples[0], list) and isinstance(samples[0], str): + return [ + {'input': i, 'target': t} + for i, t in zip(torch.zeros(2), torch.zeros(2)) + ] + if isinstance(samples[0], list): # need to transform to indexes - samples = tokens_to_tensor(tokens, self.word2idx_dict) + samples = tokens_to_tensor(samples, self.word2idx_dict) inp, target = self.prepare_for_NLL(samples) all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] return all_data @@ -86,7 +89,7 @@ def tokens(self): or list of tokens in case if input string.""" if type(self.samples[0]) == str: # we have list of strings return [smpl.split() for smpl in self.samples] - return samples + return list(self.samples) @staticmethod def prepare_for_NLL(samples, gpu=False): diff --git a/utils/text_process.py b/utils/text_process.py index 80090173..83a3435c 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -6,6 +6,7 @@ # @Blog : http://zhiweil.ml/ # @Description : # Copyrights (C) 2018. All Rights Reserved. 
+from typing import Dict, Optional import numpy as np import os @@ -127,7 +128,7 @@ def load_test_dict(dataset): return word2idx_dict, idx2word_dict -def tensor_to_tokens(tensor, dictionary): +def tensor_to_tokens(tensor, dictionary: Optional[Dict[torch.Tensor, str]] = None): """transform Tensor to word tokens""" tokens = [] for sent in tensor: @@ -135,7 +136,10 @@ def tensor_to_tokens(tensor, dictionary): for word in sent.tolist(): if word == cfg.padding_idx: break - sent_token.append(dictionary[str(word)]) + word = str(word) + if dictionary: + word = dictionary[word] + sent_token.append(word) tokens.append(sent_token) return tokens From 27c29d7c092477fd6dbcaef4e5a7f5d3e546d241 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 Apr 2023 12:25:09 +0100 Subject: [PATCH 63/81] adding sweep.yml for parameters selection --- .gitignore | 6 +++++- main.py | 18 ++++++++++++++++-- sweep.yml | 21 +++++++++++++++++++++ utils/data_loader.py | 9 ++++----- 4 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 sweep.yml diff --git a/.gitignore b/.gitignore index e58a62b6..48a44234 100644 --- a/.gitignore +++ b/.gitignore @@ -109,4 +109,8 @@ venv.bak/ /site # mypy -.mypy_cache/ \ No newline at end of file +.mypy_cache/ + +# wandb logs +wandb/ +**/.watchman-cookie-* diff --git a/main.py b/main.py index b38473d6..a25c636a 100644 --- a/main.py +++ b/main.py @@ -9,9 +9,10 @@ from __future__ import print_function import random +import yaml import argparse -import torch +# import torch import numpy as np import wandb @@ -129,7 +130,7 @@ def program_config(parser): # MAIN if __name__ == '__main__': #seed everything - torch.manual_seed(0) + # torch.manual_seed(0) random.seed(0) np.random.seed(0) @@ -200,6 +201,19 @@ def program_config(parser): config=vars(opt) ) + import wandb + + # Example sweep configuration + with open('sweep.yml') as sweep_yml: + sweep_configuration = yaml.safe_load(sweep_yml) + print(sweep_configuration) + + sweep_id = wandb.sweep(sweep=sweep_configuration, project="project-name") + print(sweep_id) + print(opt) + + wandb.agent(sweep_id=sweep_id, function=function_name) + inst = instruction_dict[cfg.run_model](opt) if not cfg.if_test: inst._run() diff --git a/sweep.yml b/sweep.yml new file mode 100644 index 00000000..b02385f5 --- /dev/null +++ b/sweep.yml @@ -0,0 +1,21 @@ +program: train.py +method: bayes +metric: + goal: minimize + name: Overal_score +parameters: + oracle_train_samples_num: + value: 10000 + discriminator_complexity: + max: 1024 + min: 256 + distribution: int_uniform + generator_complexity: + max: 1536 + min: 384 + distribution: int_uniform + w2v_embedding_size: + values: + - 128 + - 256 + - 512 diff --git a/utils/data_loader.py b/utils/data_loader.py index 030f959c..fd2c6c15 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -37,10 +37,9 @@ def __len__(self): class GenDataIter: def __init__(self, samples, if_test_data=False, shuffle=None): - if type(samples) == str: # we received filename - self.samples = get_tokenlized(samples) - else: - self.samples = samples + self.samples = samples + if type(self.samples) == str: # we received filename + self.samples = get_tokenlized(self.samples) self.shuffle = cfg.data_shuffle if not shuffle else shuffle @@ -191,7 +190,7 @@ def __iter__(self): def __len__(self): - return len(self.tokenized) + return self.batches_per_epoch def is_message_in_dataset(self, text): return text in self.texts From 782a533738a8da34b19744818090ef3caeae1794 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 Apr 
2023 13:43:29 +0100 Subject: [PATCH 64/81] adding sweep.yml for parameters selection --- main.py | 41 ++++++++++++++++++++--------------------- sweep.yml | 10 ++++------ 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/main.py b/main.py index a25c636a..9bb4a349 100644 --- a/main.py +++ b/main.py @@ -193,31 +193,30 @@ def program_config(parser): } - # start a new wandb run to track this script - wandb.init( - # set the wandb project where this run will be logged - project="TextGAN", - # track hyperparameters and run metadata - config=vars(opt) - ) - - import wandb - # Example sweep configuration with open('sweep.yml') as sweep_yml: sweep_configuration = yaml.safe_load(sweep_yml) - print(sweep_configuration) + print('sweep_configuration', sweep_configuration) - sweep_id = wandb.sweep(sweep=sweep_configuration, project="project-name") - print(sweep_id) - print(opt) + # sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") + sweep_id = "qdpnjvhf" + print('sweep_id', sweep_id) - wandb.agent(sweep_id=sweep_id, function=function_name) + def function_for_parameters_choice(): + run = wandb.init() # Initialize a new wandb run + config = run.config # Get the config dictionary for the current run + print('config', config) - inst = instruction_dict[cfg.run_model](opt) - if not cfg.if_test: - inst._run() - else: - inst._test() + # Update 'opt' with the hyperparameters from 'config' + for name, value in config.items(): + setattr(opt, name, value) + + inst = instruction_dict[cfg.run_model](opt) + if not cfg.if_test: + inst._run() + else: + inst._test() + + run.finish() # Make sure to finish the run - wandb.finish() + wandb.agent(sweep_id=sweep_id, function=function_for_parameters_choice) diff --git a/sweep.yml b/sweep.yml index b02385f5..3d77d567 100644 --- a/sweep.yml +++ b/sweep.yml @@ -4,15 +4,13 @@ metric: goal: minimize name: Overal_score parameters: - oracle_train_samples_num: - value: 10000 discriminator_complexity: - max: 1024 - min: 256 + max: 256 # must be multiplied by 4 + min: 64 distribution: int_uniform generator_complexity: - max: 1536 - min: 384 + max: 256 # must be multiplied by 4 + min: 64 distribution: int_uniform w2v_embedding_size: values: From c09da10a75430496f6e422da8fc11314678645aa Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 Apr 2023 13:53:09 +0100 Subject: [PATCH 65/81] adding sweep.yml for parameters selection --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 9bb4a349..c75480cb 100644 --- a/main.py +++ b/main.py @@ -198,7 +198,7 @@ def program_config(parser): sweep_configuration = yaml.safe_load(sweep_yml) print('sweep_configuration', sweep_configuration) - # sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") + sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") sweep_id = "qdpnjvhf" print('sweep_id', sweep_id) From 542e8fc40ef3ea6e5cdb52a402dd5ed01c293d9d Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 Apr 2023 13:53:55 +0100 Subject: [PATCH 66/81] adding sweep.yml for parameters selection --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index c75480cb..7aae1b2d 100644 --- a/main.py +++ b/main.py @@ -199,7 +199,7 @@ def program_config(parser): print('sweep_configuration', sweep_configuration) sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") - sweep_id = "qdpnjvhf" + # sweep_id = "qdpnjvhf" print('sweep_id', sweep_id) def 
function_for_parameters_choice(): From 73edf152ec47db84c62b9c011301fc2418ae7439 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 Apr 2023 15:30:06 +0100 Subject: [PATCH 67/81] adding sweep.yml for parameters selection --- instructor/real_data/fixem_instructor.py | 3 ++- instructor/real_data/instructor.py | 2 ++ utils/data_loader.py | 1 - 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index 1c6cb434..a08ac598 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -113,7 +113,7 @@ def discriminator_train_one_batch(self, real_vector, labels): def _run(self): for i in trange(cfg.max_epochs): - for labels, text_vector in self.train_data_supplier: + for labels, text_vector in tqdm(self.train_data_supplier, leave=False): if cfg.CUDA: labels, text_vector = labels.cuda(), text_vector.cuda() discriminator_acc = self.discriminator_train_one_batch(text_vector, labels) @@ -124,6 +124,7 @@ def _run(self): generator_acc = self.generator_train_one_batch() if cfg.run_model == 'fixemgan': + print('calculating_metrics') scores = self.cal_metrics(fmt_str=True) if cfg.run_model == 'cat_fixemgan': scores = '\n\n'.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)]) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index ef7eb218..9db84af7 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -252,7 +252,9 @@ def cal_metrics(self, fmt_str=False): self.ioc.reset(test_text=gen_tokens) self.nll_oracle.reset(test_text=gen_tokens) + print('all reset') metrics = {metric.name: metric.get_score() for metric in self.all_metrics} + print('get_score called') metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) wandb.log(metrics) diff --git a/utils/data_loader.py b/utils/data_loader.py index fd2c6c15..e2e9333a 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -188,7 +188,6 @@ def __iter__(self): else: yield self.labels[index - self.batch_size: index], self.vectorize_batch(self.tokenized[index - self.batch_size: index]) - def __len__(self): return self.batches_per_epoch From fca70b92034155846f9b06c26569e75d8cc74e32 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 Apr 2023 15:35:20 +0100 Subject: [PATCH 68/81] adding sweep.yml for parameters selection --- instructor/real_data/fixem_instructor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index a08ac598..f823c32c 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -7,7 +7,7 @@ import torch from torch.utils.data import Dataset, DataLoader import torchtext -from tqdm import trange +from tqdm import trange, tqdm import config as cfg @@ -110,7 +110,6 @@ def discriminator_train_one_batch(self, real_vector, labels): ) return discriminator_acc - def _run(self): for i in trange(cfg.max_epochs): for labels, text_vector in tqdm(self.train_data_supplier, leave=False): @@ -131,7 +130,6 @@ def _run(self): self.log.info(f'epoch: {i}') self.log.info(f'{scores}') - def one_more_batch_for_generator( self, generator_acc, leave_in_generator_min=0.1, leave_in_generator_max=0.9 ): From cea83c05b4962ce44fcc9887eb3f6daf483129f1 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 
Apr 2023 15:58:09 +0100 Subject: [PATCH 69/81] adding sweep.yml for parameters selection --- config.py | 2 +- run/run_fixem.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/config.py b/config.py index ba348f1b..452bd8aa 100644 --- a/config.py +++ b/config.py @@ -73,7 +73,7 @@ temperature = 1 # ===Basic Train=== -samples_num = 10000 #, mr15: 2000, +samples_num = 1000 #, mr15: 2000, small_sample_num = 20 # used for self-blue MLE_train_epoch = 150 # SeqGAN-80, LeakGAN-8, RelGAN-150 PRE_clas_epoch = 10 diff --git a/run/run_fixem.py b/run/run_fixem.py index b26cb7c2..7aa77de0 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -40,7 +40,7 @@ noise_size = 1000 max_epochs = 20 batches_per_epoch = 200 -samples_num = 100_000 # training samples size +samples_num = 1000 # training samples size tips = '{} experiments' # ===Oracle or Real=== From 6d84e58b6a15b3b0248ea14ff0b41d758bad612a Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 Apr 2023 16:04:28 +0100 Subject: [PATCH 70/81] speeding up train process --- instructor/real_data/instructor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py index 9db84af7..6a6d9c75 100644 --- a/instructor/real_data/instructor.py +++ b/instructor/real_data/instructor.py @@ -243,6 +243,7 @@ def cal_metrics(self, fmt_str=False): with torch.no_grad(): # Prepare data for evaluation gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics() + print('sampled') # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) self.nll_gen.reset(self.gen, self.train_data.loader) From d1e52f994f699f35ac0419a98ad673ecc4bf902d Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 18 Apr 2023 16:19:31 +0100 Subject: [PATCH 71/81] speeding up train process --- run/run_fixem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run/run_fixem.py b/run/run_fixem.py index 7aa77de0..8f0ba9f9 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -40,7 +40,7 @@ noise_size = 1000 max_epochs = 20 batches_per_epoch = 200 -samples_num = 1000 # training samples size +samples_num = 100 # sample for metrics tips = '{} experiments' # ===Oracle or Real=== From ee55d29c576f650e0e1261daed5cac6a7c234ddc Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 20 Apr 2023 11:10:58 +0100 Subject: [PATCH 72/81] deleted unnecessary requirement --- utils/text_process.py | 102 +++++++++++++++++++++--------------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/utils/text_process.py b/utils/text_process.py index 83a3435c..9d6053d3 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -379,54 +379,54 @@ def vectorize_sentence(tokens, w2v, target_len: int = 52, embedding_size: int = return vectorized -if __name__ == '__main__': - os.chdir('../') - # process_cat_text() - # load_test_dict('mr15') - # extend_clas_train_data() - - # dataset preprocess and saving - import torchtext - import os - import nltk - nltk.download('punkt') - from tqdm.notebook import tqdm - from pathlib import Path - - def tokenize_and_save(source, path, filename): - with open(Path(path) / filename, 'w') as f: - for _, line in tqdm(source, desc=filename): - line = line.strip().lower() - line = ' '.join(nltk.tokenize.word_tokenize(line)) - line = ' '.join(line.split('\n')) - line = ' '.join(line.split('\\n')) - line = ' '.join(line.split('\\')) - f.write(line) - f.write('\n') - - AGNEWS_train, AGNEWS_test = torchtext.datasets.AG_NEWS( - 
root="./data", split=("train", "test") - ) - DBpedia_train, DBpedia_test = torchtext.datasets.DBpedia( - root="./data", split=("train", "test") - ) - WikiText103_train, WikiText103_valid, WikiText103_test = torchtext.datasets.WikiText103( - root="./data", split=("train", "valid", "test") - ) - YahooAnswers_train, YahooAnswers_test = torchtext.datasets.YahooAnswers( - root="./data", split=("train", "test") - ) - YelpReviewFull_train, YelpReviewFull_test = torchtext.datasets.YelpReviewFull( - root="./data", split=("train", "test") - ) - tokenize_and_save(AGNEWS_train, './dataset/', 'agnews_train.txt') - tokenize_and_save(AGNEWS_test, './dataset/testdata/', 'agnews_test.txt') - tokenize_and_save(DBpedia_train, './dataset/', 'dbpedia_train.txt') - tokenize_and_save(DBpedia_test, './dataset/testdata/', 'dbpedia_test.txt') - tokenize_and_save(enumerate(WikiText103_train), './dataset/', 'wikitext103_train.txt') - tokenize_and_save(enumerate(WikiText103_valid), './dataset/', 'wikitext103_valid.txt') - tokenize_and_save(enumerate(WikiText103_test), './dataset/testdata/', 'wikitext103_test.txt') - tokenize_and_save(YahooAnswers_train, './dataset/', 'yahooanswers_train.txt') - tokenize_and_save(YahooAnswers_test, './dataset/testdata/', 'yahooanswers_test.txt') - tokenize_and_save(YelpReviewFull_train, './dataset/', 'yelpreviewfull_train.txt') - tokenize_and_save(YelpReviewFull_test, './dataset/testdata/', 'yelpreviewfull_test.txt') +# if __name__ == '__main__': +# os.chdir('../') +# # process_cat_text() +# # load_test_dict('mr15') +# # extend_clas_train_data() + +# # dataset preprocess and saving +# import torchtext +# import os +# import nltk +# nltk.download('punkt') +# from tqdm.notebook import tqdm +# from pathlib import Path + +# def tokenize_and_save(source, path, filename): +# with open(Path(path) / filename, 'w') as f: +# for _, line in tqdm(source, desc=filename): +# line = line.strip().lower() +# line = ' '.join(nltk.tokenize.word_tokenize(line)) +# line = ' '.join(line.split('\n')) +# line = ' '.join(line.split('\\n')) +# line = ' '.join(line.split('\\')) +# f.write(line) +# f.write('\n') + +# AGNEWS_train, AGNEWS_test = torchtext.datasets.AG_NEWS( +# root="./data", split=("train", "test") +# ) +# DBpedia_train, DBpedia_test = torchtext.datasets.DBpedia( +# root="./data", split=("train", "test") +# ) +# WikiText103_train, WikiText103_valid, WikiText103_test = torchtext.datasets.WikiText103( +# root="./data", split=("train", "valid", "test") +# ) +# YahooAnswers_train, YahooAnswers_test = torchtext.datasets.YahooAnswers( +# root="./data", split=("train", "test") +# ) +# YelpReviewFull_train, YelpReviewFull_test = torchtext.datasets.YelpReviewFull( +# root="./data", split=("train", "test") +# ) +# tokenize_and_save(AGNEWS_train, './dataset/', 'agnews_train.txt') +# tokenize_and_save(AGNEWS_test, './dataset/testdata/', 'agnews_test.txt') +# tokenize_and_save(DBpedia_train, './dataset/', 'dbpedia_train.txt') +# tokenize_and_save(DBpedia_test, './dataset/testdata/', 'dbpedia_test.txt') +# tokenize_and_save(enumerate(WikiText103_train), './dataset/', 'wikitext103_train.txt') +# tokenize_and_save(enumerate(WikiText103_valid), './dataset/', 'wikitext103_valid.txt') +# tokenize_and_save(enumerate(WikiText103_test), './dataset/testdata/', 'wikitext103_test.txt') +# tokenize_and_save(YahooAnswers_train, './dataset/', 'yahooanswers_train.txt') +# tokenize_and_save(YahooAnswers_test, './dataset/testdata/', 'yahooanswers_test.txt') +# tokenize_and_save(YelpReviewFull_train, './dataset/', 
'yelpreviewfull_train.txt') +# tokenize_and_save(YelpReviewFull_test, './dataset/testdata/', 'yelpreviewfull_test.txt') From 3835b2c48fa63232134251e0032a9a783b8295bf Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 20 Apr 2023 11:11:44 +0100 Subject: [PATCH 73/81] deleted unnecessary requirement --- instructor/oracle_data/fixem_instructor.py | 1 - instructor/real_data/fixem_instructor.py | 1 - 2 files changed, 2 deletions(-) diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index dba43bcc..6f9907ff 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -6,7 +6,6 @@ import numpy as np import torch from torch.utils.data import Dataset, DataLoader -import torchtext from tqdm import tqdm, trange diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py index f823c32c..d37253c9 100644 --- a/instructor/real_data/fixem_instructor.py +++ b/instructor/real_data/fixem_instructor.py @@ -6,7 +6,6 @@ import numpy as np import torch from torch.utils.data import Dataset, DataLoader -import torchtext from tqdm import trange, tqdm From ac2622bbd46bac3806d5ad4ad919be01cb13594d Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 20 Apr 2023 11:29:51 +0100 Subject: [PATCH 74/81] adding new sweep parameters --- config.py | 9 ++++++++- run/run_fixem.py | 8 +++++++- sweep.yml | 12 ++++++++++++ utils/gan_loss.py | 10 +++++----- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/config.py b/config.py index 452bd8aa..0cf66023 100644 --- a/config.py +++ b/config.py @@ -46,6 +46,9 @@ noise_size = 1000 max_epochs = 20 target_len = 40 +real_fake_coeff = 1.0 +labels_coeff = 1.0 +diversity_coeff = 1.0 # ===Embedding=== w2v_embedding_size = 100 @@ -218,7 +221,8 @@ def init_param(opt): multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, \ - generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num + generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num, \ + real_fake_coeff, labels_coeff, diversity_coeff if_test = True if opt.if_test == 1 else False run_model = opt.run_model @@ -251,6 +255,9 @@ def init_param(opt): noise_size = opt.noise_size max_epochs = opt.max_epochs target_len = opt.target_len + real_fake_coeff = opt.real_fake_coeff + labels_coeff = opt.labels_coeff + diversity_coeff = opt.diversity_coeff samples_num = opt.samples_num vocab_size = opt.vocab_size diff --git a/run/run_fixem.py b/run/run_fixem.py index 8f0ba9f9..defa416a 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -36,7 +36,6 @@ run_model = ['fixemgan', 'cat_fixemgan', 'fixemgan', 'cat_fixemgan', 'fixemgan', 'fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan'] k_label = 2 CUDA = int(True) -batch_size = 32 noise_size = 1000 max_epochs = 20 batches_per_epoch = 200 @@ -61,6 +60,9 @@ dis_init = 'uniform' batch_size = 64 target_len = [16, 20, 20, 16, 16, 40, 48, 20, 20] # architechture requires to be divisible by 4 +real_fake_coeff = 1.0 +labels_coeff = 1.0 +diversity_coeff = 1.0 # ===Generator=== generator_complexity = 768 #hyperparam @@ -104,6 +106,10 @@ '--batches_per_epoch', batches_per_epoch, '--noise_size', noise_size, '--target_len', target_len[job_id], + 
'--batch_size', batch_size, + '--real_fake_coeff', real_fake_coeff, + '--labels_coeff', labels_coeff, + '--diversity_coeff', diversity_coeff, # Generator '--generator_complexity', generator_complexity, diff --git a/sweep.yml b/sweep.yml index 3d77d567..8aadd2fc 100644 --- a/sweep.yml +++ b/sweep.yml @@ -17,3 +17,15 @@ parameters: - 128 - 256 - 512 + real_fake_coeff: + max: 4.0 + min: 0.1 + distribution: uniform + labels_coeff: + max: 4.0 + min: 0.0 + distribution: uniform + diversity_coeff: + max: 4.0 + min: 0.0 + distribution: uniform diff --git a/utils/gan_loss.py b/utils/gan_loss.py index bb373839..e799c17d 100644 --- a/utils/gan_loss.py +++ b/utils/gan_loss.py @@ -145,9 +145,9 @@ def D_loss(self, Dreal, Dfake): def G_loss_fixem(self, real_fake_predicts, label_predicts, target_labels, fakes): target_fake = self.get_target_tensor(real_fake_predicts, target_is_real=True) - real_fake_loss = self.real_fake_criterion(real_fake_predicts, target_fake) - labels_loss = self.label_criterion(label_predicts, target_labels) - diversity_loss = self.diversity_criterion(fakes) + real_fake_loss = cfg.real_fake_coeff * self.real_fake_criterion(real_fake_predicts, target_fake) + labels_loss = cfg.labels_coeff * self.label_criterion(label_predicts, target_labels) + diversity_loss = cfg.diversity_coeff * self.diversity_criterion(fakes) loss = real_fake_loss + diversity_loss loss = loss + labels_loss if cfg.run_model == 'cat_fixemgan' else loss return loss @@ -156,8 +156,8 @@ def D_loss_fixem(self, real_fake_predicts, label_predicts, target_labels): target_real = self.get_target_tensor(real_fake_predicts.chunk(2)[0], target_is_real=True) target_fake = self.get_target_tensor(real_fake_predicts.chunk(2)[1], target_is_real=False) target_real_fake = torch.cat((target_real, target_fake)) - real_fake_loss = self.real_fake_criterion(real_fake_predicts, target_real_fake) - labels_loss = self.label_criterion(label_predicts, target_labels) + real_fake_loss = cfg.real_fake_coeff * self.real_fake_criterion(real_fake_predicts, target_real_fake) + labels_loss = cfg.labels_coeff * self.label_criterion(label_predicts, target_labels) loss = real_fake_loss loss = loss + labels_loss if cfg.run_model == 'cat_fixemgan' else loss return loss From 734abc9eccfd0d98e3378e0b9a5868779ed8ff5c Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 20 Apr 2023 11:32:24 +0100 Subject: [PATCH 75/81] adding new sweep parameters --- main.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/main.py b/main.py index 7aae1b2d..869e662d 100644 --- a/main.py +++ b/main.py @@ -124,6 +124,11 @@ def program_config(parser): parser.add_argument('--signal_file', default=cfg.signal_file, type=str) parser.add_argument('--tips', default=cfg.tips, type=str) + # Loss coefficients + parser.add_argument('--real_fake_coeff', default=1.0, type=float) + parser.add_argument('--labels_coeff', default=1.0, type=float) + parser.add_argument('--diversity_coeff', default=1.0, type=float) + return parser From c5826b41943a2f80cc53faf5699853f213430560 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 27 Apr 2023 10:13:14 +0100 Subject: [PATCH 76/81] launching last sweep with smaller amount of parameters --- main.py | 19 +++++++++---------- sweep.yml | 14 ++++++-------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 869e662d..81b59ba9 100644 --- a/main.py +++ b/main.py @@ -128,7 +128,6 @@ def program_config(parser): parser.add_argument('--real_fake_coeff', default=1.0, type=float) 
parser.add_argument('--labels_coeff', default=1.0, type=float) parser.add_argument('--diversity_coeff', default=1.0, type=float) - return parser @@ -207,7 +206,12 @@ def program_config(parser): # sweep_id = "qdpnjvhf" print('sweep_id', sweep_id) - def function_for_parameters_choice(): + + def full_train_run(opt): + inst = instruction_dict[cfg.run_model](opt) + inst._run() + + def function_for_parameters_sweep(): run = wandb.init() # Initialize a new wandb run config = run.config # Get the config dictionary for the current run print('config', config) @@ -215,13 +219,8 @@ def function_for_parameters_choice(): # Update 'opt' with the hyperparameters from 'config' for name, value in config.items(): setattr(opt, name, value) - - inst = instruction_dict[cfg.run_model](opt) - if not cfg.if_test: - inst._run() - else: - inst._test() - + full_train_run(opt) run.finish() # Make sure to finish the run - wandb.agent(sweep_id=sweep_id, function=function_for_parameters_choice) + + # wandb.agent(sweep_id=sweep_id, function=function_for_parameters_sweep) diff --git a/sweep.yml b/sweep.yml index 8aadd2fc..19af8714 100644 --- a/sweep.yml +++ b/sweep.yml @@ -14,18 +14,16 @@ parameters: distribution: int_uniform w2v_embedding_size: values: - - 128 - - 256 - 512 real_fake_coeff: - max: 4.0 - min: 0.1 + max: 1.0 + min: 1.0 distribution: uniform labels_coeff: - max: 4.0 - min: 0.0 + max: 2.5 + min: 2.0 distribution: uniform diversity_coeff: - max: 4.0 - min: 0.0 + max: 2.5 + min: 2.0 distribution: uniform From 61f242e7834aec692f52c712e4ad9b7a9bff23ca Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 27 Apr 2023 10:21:21 +0100 Subject: [PATCH 77/81] setting correct min max --- sweep.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sweep.yml b/sweep.yml index 19af8714..4738dd30 100644 --- a/sweep.yml +++ b/sweep.yml @@ -16,7 +16,7 @@ parameters: values: - 512 real_fake_coeff: - max: 1.0 + max: 1.5 min: 1.0 distribution: uniform labels_coeff: From f1149e99a4effcd21a7bcf8b688f306426f97835 Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 27 Apr 2023 10:24:32 +0100 Subject: [PATCH 78/81] correcting launch script --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 81b59ba9..fe55214e 100644 --- a/main.py +++ b/main.py @@ -202,8 +202,8 @@ def program_config(parser): sweep_configuration = yaml.safe_load(sweep_yml) print('sweep_configuration', sweep_configuration) - sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") - # sweep_id = "qdpnjvhf" + # sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") + sweep_id = "7g6po2bd" print('sweep_id', sweep_id) @@ -223,4 +223,4 @@ def function_for_parameters_sweep(): run.finish() # Make sure to finish the run - # wandb.agent(sweep_id=sweep_id, function=function_for_parameters_sweep) + wandb.agent(sweep_id=sweep_id, function=function_for_parameters_sweep) From fca2696a295cdc608cf9b46e147dfdd7b57bb5ce Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Thu, 27 Apr 2023 10:40:44 +0100 Subject: [PATCH 79/81] generating sweep_id --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index fe55214e..5534b287 100644 --- a/main.py +++ b/main.py @@ -202,8 +202,8 @@ def program_config(parser): sweep_configuration = yaml.safe_load(sweep_yml) print('sweep_configuration', sweep_configuration) - # sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") - sweep_id = "7g6po2bd" + sweep_id = 
wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem") + # sweep_id = "7g6po2bd" print('sweep_id', sweep_id) From b9c2fa2d3988a10df90df53e94fc821855f68fcc Mon Sep 17 00:00:00 2001 From: salaxieb Date: Thu, 6 Jul 2023 13:56:04 +0100 Subject: [PATCH 80/81] run black on whole repo --- config.py | 261 ++++++----- instructor/oracle_data/catgan_instructor.py | 424 +++++++++++++----- instructor/oracle_data/cot_instructor.py | 62 ++- instructor/oracle_data/dgsan_instructor.py | 88 +++- instructor/oracle_data/dpgan_instructor.py | 93 ++-- instructor/oracle_data/evogan_instructor.py | 283 +++++++++--- instructor/oracle_data/fixem_instructor.py | 15 +- instructor/oracle_data/instructor.py | 189 +++++--- instructor/oracle_data/jsdgan_instructor.py | 61 ++- instructor/oracle_data/leakgan_instructor.py | 141 ++++-- instructor/oracle_data/maligan_instructor.py | 87 ++-- instructor/oracle_data/relgan_instructor.py | 80 +++- instructor/oracle_data/sentigan_instructor.py | 176 ++++++-- instructor/oracle_data/seqgan_instructor.py | 87 ++-- instructor/real_data/catgan_instructor.py | 358 +++++++++++---- instructor/real_data/cot_instructor.py | 62 ++- instructor/real_data/dgsan_instructor.py | 82 +++- instructor/real_data/dpgan_instructor.py | 97 ++-- instructor/real_data/evogan_instructor.py | 257 ++++++++--- instructor/real_data/fixem_instructor.py | 103 +++-- instructor/real_data/instructor.py | 217 ++++++--- instructor/real_data/jsdgan_instructor.py | 55 ++- instructor/real_data/leakgan_instructor.py | 133 ++++-- instructor/real_data/maligan_instructor.py | 83 ++-- instructor/real_data/relgan_instructor.py | 78 +++- instructor/real_data/sentigan_instructor.py | 198 +++++--- instructor/real_data/seqgan_instructor.py | 83 ++-- main.py | 236 +++++----- metrics/bleu.py | 46 +- metrics/clas_acc.py | 4 +- metrics/dummy.py | 3 +- metrics/gpt_nll.py | 24 +- metrics/ioc.py | 15 +- metrics/nll.py | 19 +- metrics/ppl.py | 57 ++- models/Oracle.py | 11 +- models/discriminators/CatGAN_D.py | 81 +++- models/discriminators/CoT_D.py | 8 +- models/discriminators/DPGAN_D.py | 14 +- models/discriminators/EvoGAN_D.py | 55 ++- models/discriminators/FixemGAN_D.py | 16 +- models/discriminators/LeakGAN_D.py | 11 +- models/discriminators/MaliGAN_D.py | 11 +- models/discriminators/RelGAN_D.py | 55 ++- models/discriminators/SentiGAN_D.py | 41 +- models/discriminators/SeqGAN_D.py | 11 +- models/discriminators/discriminator.py | 126 ++++-- models/generators/CatGAN_G.py | 93 ++-- models/generators/CoT_G.py | 14 +- models/generators/DGSAN_G.py | 10 +- models/generators/DPGAN_G.py | 14 +- models/generators/EvoGAN_G.py | 51 ++- models/generators/FixemGAN_G.py | 36 +- models/generators/JSDGAN_G.py | 43 +- models/generators/LeakGAN_G.py | 212 ++++++--- models/generators/MaliGAN_G.py | 18 +- models/generators/RelGAN_G.py | 43 +- models/generators/SentiGAN_G.py | 22 +- models/generators/SeqGAN_G.py | 18 +- models/generators/generator.py | 31 +- models/relational_rnn_general.py | 74 ++- run/run_catgan.py | 198 +++++--- run/run_cot.py | 126 ++++-- run/run_dgsan.py | 114 +++-- run/run_dpgan.py | 150 ++++--- run/run_fixem.py | 181 +++++--- run/run_jsdgan.py | 114 +++-- run/run_leakgan.py | 156 ++++--- run/run_maligan.py | 147 +++--- run/run_relgan.py | 168 ++++--- run/run_sentigan.py | 156 ++++--- run/run_seqgan.py | 147 +++--- utils/cat_data_loader.py | 41 +- utils/create_embeddings.py | 2 +- utils/data_loader.py | 82 ++-- utils/data_utils.py | 133 ++++-- utils/gan_loss.py | 116 +++-- utils/helpers.py | 80 ++-- utils/nn_helpers.py | 11 
+- utils/rollout.py | 38 +- utils/text_process.py | 186 ++++---- utils/visualization.py | 76 ++-- visual/visual_human.py | 130 ++++-- visual/visual_metric.py | 38 +- visual/visual_temp_appendix.py | 59 +-- visual/visual_temp_compare.py | 59 +-- 86 files changed, 5522 insertions(+), 2562 deletions(-) diff --git a/config.py b/config.py index 0cf66023..86052e5c 100644 --- a/config.py +++ b/config.py @@ -24,10 +24,10 @@ dis_pretrain = False clas_pretrain = False -run_model = 'catgan' # seqgan, leakgan, maligan, jsdgan, relgan, evogan, sentigan, catgan, dpgan, dgsan, cot +run_model = "catgan" # seqgan, leakgan, maligan, jsdgan, relgan, evogan, sentigan, catgan, dpgan, dgsan, cot k_label = 2 # num of labels, >=2 -gen_init = 'truncated_normal' # normal, uniform, truncated_normal -dis_init = 'uniform' # normal, uniform, truncated_normal +gen_init = "truncated_normal" # normal, uniform, truncated_normal +dis_init = "uniform" # normal, uniform, truncated_normal # ===CatGAN=== n_parent = 1 @@ -59,33 +59,37 @@ # ===Oracle or Real, type=== if_real_data = False # if use real data -dataset = 'oracle' # oracle, image_coco, emnlp_news, amazon_app_book, amazon_app_movie, mr15 -model_type = 'vanilla' # vanilla, RMC (custom) -loss_type = 'rsgan' # rsgan lsgan ragan vanilla wgan hinge, for Discriminator (CatGAN) -mu_type = 'ragan' # rsgan lsgan ragan vanilla wgan hinge -eval_type = 'Ra' # standard, rsgan, nll, nll-f1, Ra, bleu3, bleu-f1 -d_type = 'Ra' # S (Standard), Ra (Relativistic_average) -vocab_size = 5000 # oracle: 5000, coco: 4683, emnlp: 5256, amazon_app_book: 6418, mr15: 6289 +dataset = ( + "oracle" # oracle, image_coco, emnlp_news, amazon_app_book, amazon_app_movie, mr15 +) +model_type = "vanilla" # vanilla, RMC (custom) +loss_type = "rsgan" # rsgan lsgan ragan vanilla wgan hinge, for Discriminator (CatGAN) +mu_type = "ragan" # rsgan lsgan ragan vanilla wgan hinge +eval_type = "Ra" # standard, rsgan, nll, nll-f1, Ra, bleu3, bleu-f1 +d_type = "Ra" # S (Standard), Ra (Relativistic_average) +vocab_size = ( + 5000 # oracle: 5000, coco: 4683, emnlp: 5256, amazon_app_book: 6418, mr15: 6289 +) max_seq_len = 20 # oracle: 20, coco: 37, emnlp: 51, amazon_app_book: 40 ADV_train_epoch = 2000 # SeqGAN, LeakGAN-200, RelGAN-3000 extend_vocab_size = 0 # plus test data, only used for Classifier -temp_adpt = 'exp' # no, lin, exp, log, sigmoid, quad, sqrt -mu_temp = 'exp' # lin exp log sigmoid quad sqrt +temp_adpt = "exp" # no, lin, exp, log, sigmoid, quad, sqrt +mu_temp = "exp" # lin exp log sigmoid quad sqrt evo_temp_step = 1 temperature = 1 # ===Basic Train=== -samples_num = 1000 #, mr15: 2000, -small_sample_num = 20 # used for self-blue +samples_num = 1000 # , mr15: 2000, +small_sample_num = 20 # used for self-blue MLE_train_epoch = 150 # SeqGAN-80, LeakGAN-8, RelGAN-150 PRE_clas_epoch = 10 inter_epoch = 15 # LeakGAN-10 batch_size = 64 # 64 start_letter = 1 padding_idx = 0 -start_token = 'BOS' -padding_token = 'EOS' +start_token = "BOS" +padding_token = "EOS" gen_lr = 0.01 # 0.01 gen_adv_lr = 1e-4 # RelGAN-1e-4 dis_lr = 1e-4 # SeqGAN,LeakGAN-1e-2, RelGAN-1e-4 @@ -95,10 +99,10 @@ pre_log_step = 10 adv_log_step = 20 -train_data = 'dataset/' + dataset + '.txt' -test_data = 'dataset/testdata/' + dataset + '_test.txt' -cat_train_data = 'dataset/' + dataset + '_cat{}.txt' -cat_test_data = 'dataset/testdata/' + dataset + '_cat{}_test.txt' +train_data = "dataset/" + dataset + ".txt" +test_data = "dataset/testdata/" + dataset + "_test.txt" +cat_train_data = "dataset/" + dataset + "_cat{}.txt" +cat_test_data = 
"dataset/testdata/" + dataset + "_cat{}_test.txt" # ===Metrics=== use_nll_oracle = True @@ -137,21 +141,21 @@ # ===log=== log_time_str = strftime("%m%d_%H%M_%S", localtime()) log_filename = strftime("log/log_%s" % log_time_str) -if os.path.exists(log_filename + '.txt'): +if os.path.exists(log_filename + ".txt"): i = 2 while True: - if not os.path.exists(log_filename + '_%d' % i + '.txt'): - log_filename = log_filename + '_%d' % i + if not os.path.exists(log_filename + "_%d" % i + ".txt"): + log_filename = log_filename + "_%d" % i break i += 1 -log_filename = log_filename + '.txt' +log_filename = log_filename + ".txt" # Automatically choose GPU or CPU if torch.cuda.is_available() and torch.cuda.device_count() > 0: - os.system('nvidia-smi -q -d Utilization > gpu') - with open('gpu', 'r') as _tmpfile: - util_gpu = list(map(int, re.findall(r'Gpu\s+:\s*(\d+)\s*%', _tmpfile.read()))) - os.remove('gpu') + os.system("nvidia-smi -q -d Utilization > gpu") + with open("gpu", "r") as _tmpfile: + util_gpu = list(map(int, re.findall(r"Gpu\s+:\s*(\d+)\s*%", _tmpfile.read()))) + os.remove("gpu") if len(util_gpu): device = util_gpu.index(min(util_gpu)) else: @@ -162,67 +166,62 @@ # print('device: ', device) if multi_gpu: - devices = '0,1' - devices = list(map(int, devices.split(','))) + devices = "0,1" + devices = list(map(int, devices.split(","))) device = devices[0] torch.cuda.set_device(device) - os.environ['CUDA_VISIBLE_DIVICES'] = ','.join(map(str, devices)) + os.environ["CUDA_VISIBLE_DIVICES"] = ",".join(map(str, devices)) else: devices = str(device) torch.cuda.set_device(device) # ===Save Model and samples=== -save_root = 'save/{}/{}/{}_{}_dt-{}_lt-{}_mt-{}_et-{}_sl{}_temp{}_lfd{}_T{}/'.format(time.strftime("%Y%m%d"), - dataset, run_model, model_type, - d_type, - loss_type, - '+'.join( - [m[:2] for m in - mu_type.split()]), - eval_type, max_seq_len, - temperature, lambda_fd, - log_time_str) -save_samples_root = save_root + 'samples/' -save_model_root = save_root + 'models/' - -oracle_state_dict_path = 'pretrain/oracle_data/oracle_lstm.pt' -oracle_samples_path = 'pretrain/oracle_data/oracle_lstm_samples_{}.pt' -multi_oracle_state_dict_path = 'pretrain/oracle_data/oracle{}_lstm.pt' -multi_oracle_samples_path = 'pretrain/oracle_data/oracle{}_lstm_samples_{}.pt' - -pretrain_root = 'pretrain/{}/'.format(dataset if if_real_data else 'oracle_data') -pretrained_gen_path = pretrain_root + 'gen_MLE_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, - samples_num) -pretrained_dis_path = pretrain_root + 'dis_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, - samples_num) -pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len, - samples_num) - -embedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/oracle_data/' -pretrain_embedding_path = embedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size) -texts_pile = 'dataset/' # do not include testdata - -signal_file = 'run_signal.txt' - -tips = '' +save_root = "save/{}/{}/{}_{}_dt-{}_lt-{}_mt-{}_et-{}_sl{}_temp{}_lfd{}_T{}/".format( + time.strftime("%Y%m%d"), + dataset, + run_model, + model_type, + d_type, + loss_type, + "+".join([m[:2] for m in mu_type.split()]), + eval_type, + max_seq_len, + temperature, + lambda_fd, + log_time_str, +) +save_samples_root = save_root + "samples/" +save_model_root = save_root + "models/" + +oracle_state_dict_path = "pretrain/oracle_data/oracle_lstm.pt" +oracle_samples_path = 
"pretrain/oracle_data/oracle_lstm_samples_{}.pt" +multi_oracle_state_dict_path = "pretrain/oracle_data/oracle{}_lstm.pt" +multi_oracle_samples_path = "pretrain/oracle_data/oracle{}_lstm_samples_{}.pt" + +pretrain_root = "pretrain/{}/".format(dataset if if_real_data else "oracle_data") +pretrained_gen_path = pretrain_root + "gen_MLE_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num +) +pretrained_dis_path = pretrain_root + "dis_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num +) +pretrained_clas_path = pretrain_root + "clas_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num +) + +embedding_root = "pretrain/real_data/" if if_real_data else "pretrain/oracle_data/" +pretrain_embedding_path = embedding_root + "w2v_embedding_size_{}.model".format( + w2v_embedding_size +) +texts_pile = "dataset/" # do not include testdata + +signal_file = "run_signal.txt" + +tips = "" # Init settings according to parser def init_param(opt): - global run_model, model_type, loss_type, CUDA, device, data_shuffle, samples_num, vocab_size, \ - MLE_train_epoch, ADV_train_epoch, inter_epoch, batch_size, max_seq_len, start_letter, padding_idx, \ - gen_lr, gen_adv_lr, dis_lr, clip_norm, pre_log_step, adv_log_step, train_data, test_data, temp_adpt, \ - temperature, oracle_pretrain, gen_pretrain, dis_pretrain, ADV_g_step, rollout_num, gen_embed_dim, \ - gen_hidden_dim, goal_size, step_size, mem_slots, num_heads, head_size, d_step, d_epoch, \ - ADV_d_step, ADV_d_epoch, dis_embed_dim, dis_hidden_dim, num_rep, log_filename, save_root, \ - signal_file, tips, save_samples_root, save_model_root, if_real_data, pretrained_gen_path, \ - pretrained_dis_path, pretrain_root, if_test, dataset, PRE_clas_epoch, oracle_samples_path, \ - pretrained_clas_path, n_parent, mu_type, eval_type, d_type, eval_b_num, lambda_fd, d_out_mean, \ - lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, \ - multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, \ - use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, use_self_bleu, use_clas_acc, use_ppl, \ - w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, \ - generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num, \ - real_fake_coeff, labels_coeff, diversity_coeff + global run_model, model_type, loss_type, CUDA, device, data_shuffle, samples_num, vocab_size, MLE_train_epoch, ADV_train_epoch, inter_epoch, batch_size, max_seq_len, start_letter, padding_idx, gen_lr, gen_adv_lr, dis_lr, clip_norm, pre_log_step, adv_log_step, train_data, test_data, temp_adpt, temperature, oracle_pretrain, gen_pretrain, dis_pretrain, ADV_g_step, rollout_num, gen_embed_dim, gen_hidden_dim, goal_size, step_size, mem_slots, num_heads, head_size, d_step, d_epoch, ADV_d_step, ADV_d_epoch, dis_embed_dim, dis_hidden_dim, num_rep, log_filename, save_root, signal_file, tips, save_samples_root, save_model_root, if_real_data, pretrained_gen_path, pretrained_dis_path, pretrain_root, if_test, dataset, PRE_clas_epoch, oracle_samples_path, pretrained_clas_path, n_parent, mu_type, eval_type, d_type, eval_b_num, lambda_fd, d_out_mean, lambda_fq, freeze_dis, freeze_clas, use_all_real_fake, use_population, gen_init, dis_init, multi_oracle_samples_path, k_label, cat_train_data, cat_test_data, evo_temp_step, devices, use_nll_oracle, use_nll_gen, use_nll_div, use_bleu, 
use_self_bleu, use_clas_acc, use_ppl, w2v_embedding_size, w2v_window, w2v_min_count, w2v_workers, pretrain_embedding_path, batches_per_epoch, generator_complexity, discriminator_complexity, noise_size, max_epochs, target_len, w2v_samples_num, real_fake_coeff, labels_coeff, diversity_coeff
 
     if_test = True if opt.if_test == 1 else False
     run_model = opt.run_model
@@ -323,55 +322,85 @@ def init_param(opt):
     # CUDA device
     if multi_gpu:
         if type(devices) == str:
-            devices = list(map(int, devices.split(',')))
+            devices = list(map(int, devices.split(",")))
             device = devices[0]
             torch.cuda.set_device(device)
-            os.environ['CUDA_VISIBLE_DIVICES'] = ','.join(map(str, devices))
+            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices))
     else:
         devices = str(device)
         torch.cuda.set_device(device)
 
     # Save path
-    save_root = 'save/{}/{}/{}_{}_dt-{}_lt-{}_mt-{}_et-{}_sl{}_temp{}_lfd{}_T{}/'.format(time.strftime("%Y%m%d"),
-                                                                                         dataset, run_model, model_type,
-                                                                                         d_type,
-                                                                                         loss_type,
-                                                                                         '+'.join(
-                                                                                             [m[:2] for m in
-                                                                                              mu_type.split()]),
-                                                                                         eval_type, max_seq_len,
-                                                                                         temperature, lambda_fd,
-                                                                                         log_time_str)
-
-    save_samples_root = save_root + 'samples/'
-    save_model_root = save_root + 'models/'
-
-    train_data = 'dataset/' + dataset + '.txt' if if_real_data else 'pretrain/oracle_data/' + dataset + '.txt'
-    test_data = 'dataset/testdata/' + dataset + '_test.txt'
-    cat_train_data = 'dataset/' + dataset + '_cat{}.txt' if if_real_data else 'pretrain/oracle_data/' + dataset + '_cat{}.txt'
-    cat_test_data = 'dataset/testdata/' + dataset + '_cat{}_test.txt'
+    save_root = (
+        "save/{}/{}/{}_{}_dt-{}_lt-{}_mt-{}_et-{}_sl{}_temp{}_lfd{}_T{}/".format(
+            time.strftime("%Y%m%d"),
+            dataset,
+            run_model,
+            model_type,
+            d_type,
+            loss_type,
+            "+".join([m[:2] for m in mu_type.split()]),
+            eval_type,
+            max_seq_len,
+            temperature,
+            lambda_fd,
+            log_time_str,
+        )
+    )
+
+    save_samples_root = save_root + "samples/"
+    save_model_root = save_root + "models/"
+
+    train_data = (
+        "dataset/" + dataset + ".txt"
+        if if_real_data
+        else "pretrain/oracle_data/" + dataset + ".txt"
+    )
+    test_data = "dataset/testdata/" + dataset + "_test.txt"
+    cat_train_data = (
+        "dataset/" + dataset + "_cat{}.txt"
+        if if_real_data
+        else "pretrain/oracle_data/" + dataset + "_cat{}.txt"
+    )
+    cat_test_data = "dataset/testdata/" + dataset + "_cat{}_test.txt"
 
     if max_seq_len == 40:
-        oracle_samples_path = 'pretrain/oracle_data/oracle_lstm_samples_{}_sl40.pt'
-        multi_oracle_samples_path = 'pretrain/oracle_data/oracle{}_lstm_samples_{}_sl40.pt'
-
-    pretrain_root = 'pretrain/{}/'.format(dataset if if_real_data else 'oracle_data')
-    pretrained_gen_path = pretrain_root + 'gen_MLE_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type,
-                                                                                       max_seq_len, samples_num)
-    pretrained_dis_path = pretrain_root + 'dis_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len,
-                                                                                   samples_num)
-    pretrained_clas_path = pretrain_root + 'clas_pretrain_{}_{}_sl{}_sn{}.pt'.format(run_model, model_type, max_seq_len,
-                                                                                     samples_num)
-    embedding_root = 'pretrain/real_data/' if if_real_data else 'pretrain/oracle_data/'
-    pretrain_embedding_path = embedding_root + 'w2v_embedding_size_{}.model'.format(w2v_embedding_size)
+        oracle_samples_path = "pretrain/oracle_data/oracle_lstm_samples_{}_sl40.pt"
+        multi_oracle_samples_path = (
+            "pretrain/oracle_data/oracle{}_lstm_samples_{}_sl40.pt"
+        )
+
+    pretrain_root = "pretrain/{}/".format(dataset if if_real_data else "oracle_data")
+    pretrained_gen_path = pretrain_root + "gen_MLE_pretrain_{}_{}_sl{}_sn{}.pt".format(
+        run_model,
model_type, max_seq_len, samples_num + ) + pretrained_dis_path = pretrain_root + "dis_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num + ) + pretrained_clas_path = pretrain_root + "clas_pretrain_{}_{}_sl{}_sn{}.pt".format( + run_model, model_type, max_seq_len, samples_num + ) + embedding_root = "pretrain/real_data/" if if_real_data else "pretrain/oracle_data/" + pretrain_embedding_path = embedding_root + "w2v_embedding_size_{}.model".format( + w2v_embedding_size + ) # Assertion - assert k_label >= 2, 'Error: k_label = {}, which should be >=2!'.format(k_label) - assert eval_b_num >= n_parent * ADV_d_step, 'Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!'.format( - eval_b_num, n_parent * ADV_d_step) + assert k_label >= 2, "Error: k_label = {}, which should be >=2!".format(k_label) + assert ( + eval_b_num >= n_parent * ADV_d_step + ), "Error: eval_b_num = {}, which should be >= n_parent * ADV_d_step ({})!".format( + eval_b_num, n_parent * ADV_d_step + ) # Create Directory - dir_list = ['save', 'savefig', 'log', 'pretrain', 'dataset', - 'pretrain/{}'.format(dataset if if_real_data else 'oracle_data')] + dir_list = [ + "save", + "savefig", + "log", + "pretrain", + "dataset", + "pretrain/{}".format(dataset if if_real_data else "oracle_data"), + ] if not if_test: dir_list.extend([save_root, save_samples_root, save_model_root]) for d in dir_list: diff --git a/instructor/oracle_data/catgan_instructor.py b/instructor/oracle_data/catgan_instructor.py index 04d051b9..2c960429 100644 --- a/instructor/oracle_data/catgan_instructor.py +++ b/instructor/oracle_data/catgan_instructor.py @@ -31,22 +31,57 @@ class CatGANInstructor(BasicInstructor): - def __init__(self, opt): super(CatGANInstructor, self).__init__(opt) # generator, discriminator - self.oracle_list = [Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] - - self.gen = CatGAN_G(cfg.k_label, cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, - cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.parents = [CatGAN_G(cfg.k_label, cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, - cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, - gpu=cfg.CUDA).state_dict() - for _ in range(cfg.n_parent)] # list of Generator state_dict - self.dis = CatGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.oracle_list = [ + Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + for _ in range(cfg.k_label) + ] + + self.gen = CatGAN_G( + cfg.k_label, + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.parents = [ + CatGAN_G( + cfg.k_label, + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ).state_dict() + for _ in range(cfg.n_parent) + ] # list of Generator state_dict + self.dis = CatGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -54,17 +89,24 @@ def __init__(self, opt): self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.gen_adv_opt = 
optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) - self.parent_mle_opts = [copy.deepcopy(self.gen_opt.state_dict()) - for _ in range(cfg.n_parent)] - self.parent_adv_opts = [copy.deepcopy(self.gen_adv_opt.state_dict()) - for _ in range(cfg.n_parent)] # list of optimizer state dict + self.parent_mle_opts = [ + copy.deepcopy(self.gen_opt.state_dict()) for _ in range(cfg.n_parent) + ] + self.parent_adv_opts = [ + copy.deepcopy(self.gen_adv_opt.state_dict()) for _ in range(cfg.n_parent) + ] # list of optimizer state dict # Criterion - self.G_criterion = [GANLoss(loss_mode, 'G', cfg.d_type, CUDA=cfg.CUDA) for loss_mode in cfg.mu_type.split()] - self.D_criterion = GANLoss(cfg.loss_type, 'D', cfg.d_type, CUDA=cfg.CUDA) + self.G_criterion = [ + GANLoss(loss_mode, "G", cfg.d_type, CUDA=cfg.CUDA) + for loss_mode in cfg.mu_type.split() + ] + self.D_criterion = GANLoss(cfg.loss_type, "D", cfg.d_type, CUDA=cfg.CUDA) # DataLoader - self.all_oracle_data = CatGenDataIter(self.oracle_samples_list) # Shuffled all oracle data + self.all_oracle_data = CatGenDataIter( + self.oracle_samples_list + ) # Shuffled all oracle data def init_model(self): if cfg.oracle_pretrain: @@ -72,12 +114,20 @@ def init_model(self): oracle_path = cfg.multi_oracle_state_dict_path.format(i) if not os.path.exists(oracle_path): create_multi_oracle(cfg.k_label) - self.oracle_list[i].load_state_dict(torch.load(oracle_path, map_location='cuda:%d' % cfg.device)) + self.oracle_list[i].load_state_dict( + torch.load(oracle_path, map_location="cuda:%d" % cfg.device) + ) if cfg.gen_pretrain: for i in range(cfg.n_parent): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) - self.parents[i] = torch.load(cfg.pretrained_gen_path + '%d' % 0, map_location='cpu') + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) + self.parents[i] = torch.load( + cfg.pretrained_gen_path + "%d" % 0, map_location="cpu" + ) if cfg.CUDA: for i in range(cfg.k_label): @@ -97,14 +147,24 @@ def load_gen(self, parent, parent_opt, mle=False): def _run(self): # ===Pre-train Generator=== if not cfg.gen_pretrain: - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_mle_opts)): - self.log.info('Starting Generator-{} MLE Training...'.format(i)) + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_mle_opts) + ): + self.log.info("Starting Generator-{} MLE Training...".format(i)) self.load_gen(parent, parent_opt, mle=True) # load state dict self.pretrain_generator(cfg.MLE_train_epoch) - self.parents[i] = copy.deepcopy(self.gen.state_dict()) # save state dict + self.parents[i] = copy.deepcopy( + self.gen.state_dict() + ) # save state dict if cfg.if_save and not cfg.if_test: - torch.save(self.gen.state_dict(), cfg.pretrained_gen_path + '%d' % i) - self.log.info('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen.state_dict(), cfg.pretrained_gen_path + "%d" % i + ) + self.log.info( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # ===Adv-train=== progress = tqdm(range(cfg.ADV_train_epoch)) @@ -112,27 +172,45 @@ def _run(self): if cfg.temperature == 1: score, fit_score, select_mu = self.evolve_generator(cfg.ADV_g_step) else: # evolve with temperature - score, fit_score, select_mu = self.evolve_generator_with_temp(adv_epoch, cfg.ADV_g_step) + score, fit_score, select_mu = 
self.evolve_generator_with_temp( + adv_epoch, cfg.ADV_g_step + ) d_loss = self.evolve_discriminator(cfg.ADV_d_step) best_id = int(np.argmax(score)) - progress.set_description('mu: %s, d_loss = %.4f, temp = %.4f' % ( - ' '.join(select_mu), d_loss, self.parents[best_id]['temperature'].item())) + progress.set_description( + "mu: %s, d_loss = %.4f, temp = %.4f" + % ( + " ".join(select_mu), + d_loss, + self.parents[best_id]["temperature"].item(), + ) + ) # ===Test=== - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): best_id = int(np.argmax(score)) self.load_gen(self.parents[best_id], self.parent_adv_opts[best_id]) - self.log.info('[ADV] epoch %d: temp = %.4f, d_loss: %.4f, %s' % ( - adv_epoch, self.gen.temperature.item(), d_loss, self.comb_metrics(fmt_str=True))) + self.log.info( + "[ADV] epoch %d: temp = %.4f, d_loss: %.4f, %s" + % ( + adv_epoch, + self.gen.temperature.item(), + d_loss, + self.comb_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: for label_i in range(cfg.k_label): - self._save('ADV', adv_epoch, label_i) + self._save("ADV", adv_epoch, label_i) def _test(self): - self.log.debug('>>> Begin test...') + self.log.debug(">>> Begin test...") self._run() pass @@ -143,17 +221,20 @@ def pretrain_generator(self, epochs): """ for epoch in range(epochs): # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.all_oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.all_oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.comb_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.comb_metrics(fmt_str=True)) + ) if not cfg.if_test and cfg.if_save: for label_i in range(cfg.k_label): - self._save('MLE', epoch, label_i) + self._save("MLE", epoch, label_i) def evolve_generator(self, evo_g_step): # evaluation real data @@ -169,13 +250,21 @@ def evolve_generator(self, evo_g_step): # all child share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.oracle_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.oracle_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): # Variation self.load_gen(parent, parent_opt) # load state dict to self.gen @@ -200,7 +289,9 @@ def evolve_generator(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = 
self.eval_fake_samples selected_mutation[id_replace] = criterionG.loss_mode count += 1 @@ -224,17 +315,25 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.oracle_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.oracle_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): all_temp = self.get_evo_temp(cur_adv_step) - temp_score = float('-inf') + temp_score = float("-inf") temp_fit = None temp_child = None temp_child_opt = None @@ -250,8 +349,10 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # Evaluation self.prepare_eval_fake_data() # evaluation fake data - _, _, t_score = self.evaluation('Ra') # for temp evolutionary - loss_Fq, loss_Fd, loss_score = self.evaluation(cfg.eval_type) # for loss evolutionary + _, _, t_score = self.evaluation("Ra") # for temp evolutionary + loss_Fq, loss_Fd, loss_score = self.evaluation( + cfg.eval_type + ) # for loss evolutionary if t_score > temp_score: temp_score = loss_score @@ -303,14 +404,22 @@ def evolve_generator_population(self, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.oracle_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.oracle_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label # evaluate all parents - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): self.load_gen(parent, parent_opt) self.prepare_eval_fake_data() Fq, Fd, score = self.evaluation(cfg.eval_type) @@ -324,7 +433,9 @@ def evolve_generator_population(self, evo_g_step): # randomly choose a parent, variation target_idx = random.randint(0, len(self.parents) - 1) for j, criterionG in enumerate(self.G_criterion): - self.load_gen(self.parents[target_idx], self.parent_adv_opts[target_idx]) # load generator + self.load_gen( + self.parents[target_idx], self.parent_adv_opts[target_idx] + ) # load generator # Variation self.variation(evo_g_step, criterionG) @@ -340,7 +451,9 @@ def evolve_generator_population(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = 
copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation.append(criterionG.loss_mode) @@ -353,15 +466,21 @@ def evolve_discriminator(self, evo_d_step): global dc_loss, dd_loss, d_loss total_loss = [] - all_gen_samples_list = list(map(self.merge, *self.best_fake_samples)) # merge each label of data - self.all_gen_samples_list = self.shuffle_eval_samples(all_gen_samples_list) # shuffle data + all_gen_samples_list = list( + map(self.merge, *self.best_fake_samples) + ) # merge each label of data + self.all_gen_samples_list = self.shuffle_eval_samples( + all_gen_samples_list + ) # shuffle data for step in range(evo_d_step): - dis_real_samples, dis_gen_samples = self.prepare_train_data('D', step) + dis_real_samples, dis_gen_samples = self.prepare_train_data("D", step) d_loss = 0 all_d_out_real = [] all_d_out_fake = [] - for (real_samples, fake_samples) in zip(dis_real_samples, dis_gen_samples): # for each label samples + for (real_samples, fake_samples) in zip( + dis_real_samples, dis_gen_samples + ): # for each label samples d_out_real = self.dis(real_samples) d_out_fake = self.dis(fake_samples) d_loss += self.D_criterion(d_out_real, d_out_fake) @@ -386,14 +505,16 @@ def variation(self, g_step, criterionG): """Optimize one child (Generator)""" total_loss = [] for step in range(g_step): - dis_real_samples, dis_gen_samples = self.prepare_train_data('G') + dis_real_samples, dis_gen_samples = self.prepare_train_data("G") # ===Train=== g_loss = 0 all_d_out_real = [] all_d_out_fake = [] # for i, (real_samples, fake_samples) in enumerate(zip(dis_real_samples, dis_gen_samples)): - for i, (d_out_real, fake_samples) in enumerate(zip(self.d_out_real, dis_gen_samples)): # share real + for i, (d_out_real, fake_samples) in enumerate( + zip(self.d_out_real, dis_gen_samples) + ): # share real # d_out_real = self.dis(real_samples) d_out_fake = self.dis(fake_samples) g_loss += criterionG(d_out_real, d_out_fake) @@ -416,50 +537,78 @@ def variation(self, g_step, criterionG): def evaluation(self, eval_type): """Evaluation all children, update child score. 
Note that the eval data should be the same""" - eval_samples = [self.gen.sample(cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size, label_i=i) - for i in range(cfg.k_label)] + eval_samples = [ + self.gen.sample( + cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size, label_i=i + ) + for i in range(cfg.k_label) + ] # Fd if cfg.lambda_fd != 0: nll_div = [] for label_i in range(cfg.k_label): gen_data = GenDataIter(eval_samples[label_i]) - nll_div.append(NLL.cal_nll_with_label(self.gen, gen_data.loader, label_i, self.mle_criterion)) - if 'f1' in eval_type: + nll_div.append( + NLL.cal_nll_with_label( + self.gen, gen_data.loader, label_i, self.mle_criterion + ) + ) + if "f1" in eval_type: if cfg.k_label == 1: Fd = nll_div[0] if len(nll_div) > 0 else 0 elif cfg.k_label == 2: - Fd = nll_div[0] * nll_div[1] / (nll_div[0] + nll_div[1]) if len(nll_div) > 0 else 0 + Fd = ( + nll_div[0] * nll_div[1] / (nll_div[0] + nll_div[1]) + if len(nll_div) > 0 + else 0 + ) else: - raise NotImplementedError("k_label = %d is not supported" % cfg.k_label) + raise NotImplementedError( + "k_label = %d is not supported" % cfg.k_label + ) else: Fd = sum(nll_div) else: Fd = 0 # Fq - if 'nll' in eval_type: + if "nll" in eval_type: nll_oracle = [] for label_i in range(cfg.k_label): gen_data = GenDataIter(eval_samples[label_i]) if cfg.lambda_fq != 0: - nll_oracle.append(-NLL.cal_nll_with_label(self.oracle_list[label_i], gen_data.loader, label_i, - self.mle_criterion)) - - if 'f1' in eval_type: + nll_oracle.append( + -NLL.cal_nll_with_label( + self.oracle_list[label_i], + gen_data.loader, + label_i, + self.mle_criterion, + ) + ) + + if "f1" in eval_type: if cfg.k_label == 1: Fq = nll_oracle[0] if len(nll_oracle) > 0 else 0 elif cfg.k_label == 2: - Fq = nll_oracle[0] * nll_oracle[1] / (nll_oracle[0] + nll_oracle[1]) if len(nll_oracle) > 0 else 0 + Fq = ( + nll_oracle[0] * nll_oracle[1] / (nll_oracle[0] + nll_oracle[1]) + if len(nll_oracle) > 0 + else 0 + ) else: - raise NotImplementedError("k_label = %d is not supported" % cfg.k_label) + raise NotImplementedError( + "k_label = %d is not supported" % cfg.k_label + ) else: # sum Fq = sum(nll_oracle) - elif eval_type == 'Ra': + elif eval_type == "Ra": g_loss = 0 for i in range(cfg.k_label): - g_loss += torch.sigmoid(self.eval_d_out_fake[i] - torch.mean(self.eval_d_out_real[i])).sum() + g_loss += torch.sigmoid( + self.eval_d_out_fake[i] - torch.mean(self.eval_d_out_real[i]) + ).sum() Fq = g_loss.item() else: raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type) @@ -470,7 +619,7 @@ def evaluation(self, eval_type): def train_gen_epoch(self, model, data_loader, criterion, optimizer): total_loss = 0 for i, data in enumerate(data_loader): - inp, target, label = data['input'], data['target'], data['label'] + inp, target, label = data["input"], data["target"], data["label"] if cfg.CUDA: inp, target, label = inp.cuda(), target.cuda(), label.cuda() @@ -483,8 +632,13 @@ def train_gen_epoch(self, model, data_loader, criterion, optimizer): def _save(self, phase, epoch, label_i=None): assert type(label_i) == int - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_c{}_{}_{:05d}.txt'.format(label_i, phase, epoch) + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_c{}_{}_{:05d}.txt".format( + label_i, phase, epoch + ) samples = 
self.gen.sample(cfg.batch_size, cfg.batch_size, label_i=label_i) write_tensor(save_sample_path, samples) @@ -495,50 +649,86 @@ def merge(*args): def shuffle_eval_samples(self, all_eval_samples): temp = [] for i in range(cfg.k_label): - temp.append(all_eval_samples[i][torch.randperm(all_eval_samples[i].size(0))]) + temp.append( + all_eval_samples[i][torch.randperm(all_eval_samples[i].size(0))] + ) return temp def prepare_train_data(self, which, step=None): """Prepare train data for both Generator and Discriminator, each samples_list contains k_label batches of data""" - assert which == 'D' or which == 'G', 'only support for D and G!!' + assert which == "D" or which == "G", "only support for D and G!!" real_samples_list = [ - F.one_hot(self.oracle_data_list[i].random_batch()['target'][:cfg.batch_size], - cfg.vocab_size).float().cuda() - for i in range(cfg.k_label)] - if which == 'D': - assert step is not None, 'missing step' - gen_samples_list = [self.all_gen_samples_list[i][step * cfg.batch_size:(step + 1) * cfg.batch_size] - for i in range(cfg.k_label)] # get a batch from each label + F.one_hot( + self.oracle_data_list[i].random_batch()["target"][: cfg.batch_size], + cfg.vocab_size, + ) + .float() + .cuda() + for i in range(cfg.k_label) + ] + if which == "D": + assert step is not None, "missing step" + gen_samples_list = [ + self.all_gen_samples_list[i][ + step * cfg.batch_size : (step + 1) * cfg.batch_size + ] + for i in range(cfg.k_label) + ] # get a batch from each label else: # 'G' gen_samples_list = [ self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True, label_i=i) - for i in range(cfg.k_label)] + for i in range(cfg.k_label) + ] return real_samples_list, gen_samples_list def prepare_eval_real_data(self): """Prepare evaluation real data, contains k_label batches of data""" with torch.no_grad(): - self.eval_real_samples = [torch.cat( - [F.one_hot(self.oracle_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for _ in range(cfg.eval_b_num)], dim=0) for i in range(cfg.k_label)] + self.eval_real_samples = [ + torch.cat( + [ + F.one_hot( + self.oracle_data_list[i].random_batch()["target"], + cfg.vocab_size, + ).float() + for _ in range(cfg.eval_b_num) + ], + dim=0, + ) + for i in range(cfg.k_label) + ] if cfg.CUDA: - self.eval_real_samples = [self.eval_real_samples[i].cuda() for i in range(cfg.k_label)] + self.eval_real_samples = [ + self.eval_real_samples[i].cuda() for i in range(cfg.k_label) + ] - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': - self.eval_d_out_real = [self.dis(self.eval_real_samples[i]) for i in range(cfg.k_label)] + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": + self.eval_d_out_real = [ + self.dis(self.eval_real_samples[i]) for i in range(cfg.k_label) + ] def prepare_eval_fake_data(self): """Prepare evaluation fake data, contains k_label batches of data""" with torch.no_grad(): - self.eval_fake_samples = [self.gen.sample(cfg.eval_b_num * cfg.batch_size, - cfg.eval_b_num * cfg.batch_size, one_hot=True, label_i=i) - for i in range(cfg.k_label)] + self.eval_fake_samples = [ + self.gen.sample( + cfg.eval_b_num * cfg.batch_size, + cfg.eval_b_num * cfg.batch_size, + one_hot=True, + label_i=i, + ) + for i in range(cfg.k_label) + ] if cfg.CUDA: - self.eval_fake_samples = [self.eval_fake_samples[i].cuda() for i in range(cfg.k_label)] + self.eval_fake_samples = [ + self.eval_fake_samples[i].cuda() for i in range(cfg.k_label) + ] - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': - self.eval_d_out_fake = 
[self.dis(self.eval_fake_samples[i]) for i in range(cfg.k_label)] + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": + self.eval_d_out_fake = [ + self.dis(self.eval_fake_samples[i]) for i in range(cfg.k_label) + ] @staticmethod def get_evo_temp(cur_step): @@ -547,14 +737,30 @@ def get_evo_temp(cur_step): all_temp = list() # all_temp.append(get_fixed_temperature(1.0, 0, 0, 'no')) # temp=1.0 - all_temp.append(get_fixed_temperature(cfg.temperature, cur_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) # current step all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step + cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) # current step + all_temp.append( + get_fixed_temperature( + cfg.temperature, + cur_step + cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) if cur_step > cfg.evo_temp_step: all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step - cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step - cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) return torch.Tensor(all_temp) diff --git a/instructor/oracle_data/cot_instructor.py b/instructor/oracle_data/cot_instructor.py index 0e860773..4cc7958e 100644 --- a/instructor/oracle_data/cot_instructor.py +++ b/instructor/oracle_data/cot_instructor.py @@ -24,10 +24,22 @@ def __init__(self, opt): super(CoTInstructor, self).__init__(opt) # generator, discriminator - self.gen = CoT_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = Cot_D(cfg.gen_embed_dim * 2, cfg.gen_hidden_dim * 2, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) # embed_dim and hidden_dim is larger + self.gen = CoT_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = Cot_D( + cfg.gen_embed_dim * 2, + cfg.gen_hidden_dim * 2, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) # embed_dim and hidden_dim is larger self.init_model() # Optimizer @@ -39,30 +51,32 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for epoch in progress: g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator d_loss = self.train_mediator(epoch, cfg.ADV_d_step) # Discriminator - progress.set_description('g_loss: %.4f, d_loss: %.4f' % (g_loss, d_loss)) + progress.set_description("g_loss: %.4f, d_loss: %.4f" % (g_loss, d_loss)) if epoch % cfg.adv_log_step == 0 or epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV]: epoch = %d, %s' % (epoch, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV]: epoch = %d, %s" % (epoch, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - 
self._save('ADV', epoch) + self._save("ADV", epoch) torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -74,16 +88,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -92,7 +110,9 @@ def adv_train_generator(self, g_step): """ g_loss = [] for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = self.dis(inp, self.dis.init_hidden(cfg.batch_size)) @@ -109,9 +129,13 @@ def train_mediator(self, cur_epoch, d_step): d_loss = [] for step in range(d_step): # prepare loader for training - real = list(self.oracle_data.loader)[cur_epoch % len(self.oracle_data.loader)] # traverse all real data - real_inp, real_tar = real['input'], real['target'] - fake_inp, fake_tar = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + real = list(self.oracle_data.loader)[ + cur_epoch % len(self.oracle_data.loader) + ] # traverse all real data + real_inp, real_tar = real["input"], real["target"] + fake_inp, fake_tar = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) if cfg.CUDA: real_inp, real_tar = real_inp.cuda(), real_tar.cuda() diff --git a/instructor/oracle_data/dgsan_instructor.py b/instructor/oracle_data/dgsan_instructor.py index 56d8a1aa..b23bc9ac 100644 --- a/instructor/oracle_data/dgsan_instructor.py +++ b/instructor/oracle_data/dgsan_instructor.py @@ -26,10 +26,22 @@ def __init__(self, opt): super(DGSANInstructor, self).__init__(opt) # generator - self.gen = DGSAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.old_gen = DGSAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = DGSAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.old_gen = DGSAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -41,11 +53,21 @@ def init_model(self): if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() self.oracle.load_state_dict( - torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.oracle_state_dict_path, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: 
{}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.oracle = self.oracle.cuda() @@ -55,14 +77,14 @@ def init_model(self): def _run(self): # ===PRE-TRAINING=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") self.old_gen.load_state_dict(copy.deepcopy(self.gen.state_dict())) progress = tqdm(range(cfg.ADV_train_epoch)) @@ -70,16 +92,21 @@ def _run(self): g_loss = self.adv_train_generator() self.old_gen.load_state_dict(copy.deepcopy(self.gen.state_dict())) - progress.set_description('g_loss: %.4f' % g_loss) + progress.set_description("g_loss: %.4f" % g_loss) - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): self.log.info( - '[ADV]: epoch: %d, g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True))) + "[ADV]: epoch: %d, g_loss = %.4f, %s" + % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -91,26 +118,35 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self): g_loss = [] gen_data = GenDataIter(self.old_gen.sample(cfg.samples_num, cfg.batch_size)) for (real, fake) in zip(self.oracle_data.loader, gen_data.loader): - real_inp, real_tar = real['input'], real['target'] - fake_inp, fake_tar = fake['input'], fake['target'] + real_inp, real_tar = real["input"], real["target"] + fake_inp, fake_tar = fake["input"], fake["target"] if cfg.CUDA: - real_inp, real_tar, fake_inp, fake_tar = real_inp.cuda(), real_tar.cuda(), fake_inp.cuda(), fake_tar.cuda() + real_inp, real_tar, fake_inp, fake_tar = ( + real_inp.cuda(), + real_tar.cuda(), + fake_inp.cuda(), + fake_tar.cuda(), + ) # ===Train=== real_new_pred = self.cal_pred(self.gen, real_inp, 
real_tar) @@ -119,8 +155,12 @@ def adv_train_generator(self): fake_old_pred = self.cal_pred(self.old_gen, fake_inp, fake_tar) eps = 0 - real_loss = -torch.sum(torch.log(1 / (1 + real_old_pred / (real_new_pred + eps) + eps) + eps)) - fake_loss = -torch.sum(torch.log(1 / (1 + fake_new_pred / (fake_old_pred + eps) + eps) + eps)) + real_loss = -torch.sum( + torch.log(1 / (1 + real_old_pred / (real_new_pred + eps) + eps) + eps) + ) + fake_loss = -torch.sum( + torch.log(1 / (1 + fake_new_pred / (fake_old_pred + eps) + eps) + eps) + ) adv_loss = real_loss + fake_loss self.optimize(self.gen_adv_opt, adv_loss) diff --git a/instructor/oracle_data/dpgan_instructor.py b/instructor/oracle_data/dpgan_instructor.py index 3c39903c..8ef2fc1a 100644 --- a/instructor/oracle_data/dpgan_instructor.py +++ b/instructor/oracle_data/dpgan_instructor.py @@ -21,10 +21,22 @@ def __init__(self, opt): super(DPGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = DPGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = DPGAN_D(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = DPGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = DPGAN_D( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -36,40 +48,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') - self.train_discriminator(cfg.d_step, cfg.d_epoch, 'MLE') + self.log.info("Starting Discriminator Training...") + self.train_discriminator(cfg.d_step, cfg.d_epoch, "MLE") if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - 
self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -81,16 +102,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -100,13 +125,15 @@ def adv_train_generator(self, g_step): """ discount_rate = 1 total_g_loss = 0 - dis_count_list = [discount_rate ** i for i in range(cfg.max_seq_len)] - dis_count_matrix = torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) + dis_count_list = [discount_rate**i for i in range(cfg.max_seq_len)] + dis_count_matrix = ( + torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) + ) if cfg.CUDA: dis_count_matrix = dis_count_matrix.cuda() for step in range(g_step): - inp = self.oracle_data.random_batch()['input'] + inp = self.oracle_data.random_batch()["input"] if cfg.CUDA: inp = inp.cuda() @@ -124,9 +151,11 @@ def adv_train_generator(self, g_step): # ===Test=== self.log.info( - '[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True))) + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
@@ -147,8 +176,10 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): # ===Test=== pos_reward, neg_reward = self.eval_dis(self.dis, pos_val, neg_val) - self.log.info('[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f,' % ( - phase, step, pos_reward.item(), neg_reward.item())) + self.log.info( + "[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f," + % (phase, step, pos_reward.item(), neg_reward.item()) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) @@ -162,8 +193,8 @@ def train_dis_epoch(self, model, pos_samples, neg_samples, optimizer): num_samples = pos_samples.size(0) num_batch = num_samples // cfg.batch_size for i in range(num_batch): - pos_sample = pos_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] - neg_sample = neg_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] + pos_sample = pos_samples[i * cfg.batch_size : (i + 1) * cfg.batch_size] + neg_sample = neg_samples[i * cfg.batch_size : (i + 1) * cfg.batch_size] _, pos_reward = model.getReward(pos_sample) _, neg_reward = model.getReward(neg_sample) diff --git a/instructor/oracle_data/evogan_instructor.py b/instructor/oracle_data/evogan_instructor.py index a7d22a73..fa07c569 100644 --- a/instructor/oracle_data/evogan_instructor.py +++ b/instructor/oracle_data/evogan_instructor.py @@ -30,13 +30,39 @@ def __init__(self, opt): super(EvoGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = EvoGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.parents = [EvoGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA).state_dict() - for _ in range(cfg.n_parent)] # list of Generator state_dict - self.dis = EvoGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = EvoGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.parents = [ + EvoGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ).state_dict() + for _ in range(cfg.n_parent) + ] # list of Generator state_dict + self.dis = EvoGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -44,37 +70,59 @@ def __init__(self, opt): self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) - self.parent_mle_opts = [copy.deepcopy(self.gen_opt.state_dict()) - for _ in range(cfg.n_parent)] - self.parent_adv_opts = [copy.deepcopy(self.gen_adv_opt.state_dict()) - for _ in range(cfg.n_parent)] # list of optimizer state dict + self.parent_mle_opts = [ + copy.deepcopy(self.gen_opt.state_dict()) for _ in range(cfg.n_parent) + ] + self.parent_adv_opts = [ + copy.deepcopy(self.gen_adv_opt.state_dict()) for _ in range(cfg.n_parent) + ] # list of optimizer state dict # Criterion - self.G_criterion = [GANLoss(loss_mode, 'G', cfg.d_type, CUDA=cfg.CUDA) for loss_mode in cfg.mu_type.split()] - self.D_criterion = GANLoss(cfg.loss_type, 'D', cfg.d_type, CUDA=cfg.CUDA) + 
self.G_criterion = [ + GANLoss(loss_mode, "G", cfg.d_type, CUDA=cfg.CUDA) + for loss_mode in cfg.mu_type.split() + ] + self.D_criterion = GANLoss(cfg.loss_type, "D", cfg.d_type, CUDA=cfg.CUDA) def init_model(self): if cfg.oracle_pretrain: if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() - self.oracle.load_state_dict(torch.load(cfg.oracle_state_dict_path, map_location='cuda:%d' % cfg.device)) + self.oracle.load_state_dict( + torch.load( + cfg.oracle_state_dict_path, map_location="cuda:%d" % cfg.device + ) + ) if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: for i in range(cfg.n_parent): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) - self.parents[i] = torch.load(cfg.pretrained_gen_path + '%d' % 0, map_location='cpu') + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) + self.parents[i] = torch.load( + cfg.pretrained_gen_path + "%d" % 0, map_location="cpu" + ) if cfg.CUDA: self.oracle = self.oracle.cuda() self.gen = self.gen.cuda() if cfg.multi_gpu: - self.dis = torch.nn.parallel.DataParallel(self.dis, device_ids=cfg.devices) + self.dis = torch.nn.parallel.DataParallel( + self.dis, device_ids=cfg.devices + ) self.dis = self.dis.cuda() def load_gen(self, parent, parent_opt, mle=False): @@ -89,44 +137,72 @@ def load_gen(self, parent, parent_opt, mle=False): def _run(self): # ===PRE-TRAINING (GENERATOR)=== if not cfg.gen_pretrain: - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_mle_opts)): - self.log.info('Starting Generator-{} MLE Training...'.format(i)) + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_mle_opts) + ): + self.log.info("Starting Generator-{} MLE Training...".format(i)) self.load_gen(parent, parent_opt, mle=True) # load state dict self.pretrain_generator(cfg.MLE_train_epoch) - self.parents[i] = copy.deepcopy(self.gen.state_dict()) # save state dict + self.parents[i] = copy.deepcopy( + self.gen.state_dict() + ) # save state dict if cfg.if_save and not cfg.if_test: - torch.save(self.gen.state_dict(), cfg.pretrained_gen_path + '%d' % i) - self.log.info('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen.state_dict(), cfg.pretrained_gen_path + "%d" % i + ) + self.log.info( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for adv_epoch in progress: if cfg.temperature == 1: score, fit_score, select_mu = self.evolve_generator(cfg.ADV_g_step) else: # evolve with temperature - score, fit_score, select_mu = self.evolve_generator_with_temp(adv_epoch, cfg.ADV_g_step) + score, fit_score, select_mu = self.evolve_generator_with_temp( + adv_epoch, cfg.ADV_g_step + ) d_loss = self.evolve_discriminator(cfg.ADV_d_step) best_id = int(np.argmax(score)) - progress.set_description('mu: %s, d_loss = %.4f, temp = %.4f' % ( - ' '.join(select_mu), d_loss, 
self.parents[best_id]['temperature'].item())) + progress.set_description( + "mu: %s, d_loss = %.4f, temp = %.4f" + % ( + " ".join(select_mu), + d_loss, + self.parents[best_id]["temperature"].item(), + ) + ) # TEST - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): best_id = int(np.argmax(score)) self.load_gen(self.parents[best_id], self.parent_adv_opts[best_id]) # self.log.info('[ADV] epoch %d: temp = %.4f' % (adv_epoch, self.gen.temperature.item())) # self.log.info(fit_score[best_id]) - self.log.info('[ADV] epoch %d: temp = %.4f, d_loss = %.4f, %s' % ( - adv_epoch, self.gen.temperature.item(), d_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV] epoch %d: temp = %.4f, d_loss = %.4f, %s" + % ( + adv_epoch, + self.gen.temperature.item(), + d_loss, + self.cal_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() @@ -140,17 +216,21 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def evolve_generator(self, evo_g_step): @@ -167,12 +247,16 @@ def evolve_generator(self, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): # Variation self.load_gen(parent, parent_opt) # load state dict to self.gen @@ -206,7 +290,9 @@ def evolve_generator(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation[id_replace] = criterionG.loss_mode count += 1 @@ -230,16 +316,20 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + 
self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): all_temp = self.get_evo_temp(cur_adv_step) # get evo temp - temp_score = float('-inf') + temp_score = float("-inf") temp_fit = None temp_child = None temp_child_opt = None @@ -255,8 +345,10 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # Evaluation self.prepare_eval_fake_data() # evaluation fake data - _, _, t_score = self.evaluation('Ra') # for temp evolutionary - loss_Fq, loss_Fd, loss_score = self.evaluation(cfg.eval_type) # for loss evolutionary + _, _, t_score = self.evaluation("Ra") # for temp evolutionary + loss_Fq, loss_Fd, loss_score = self.evaluation( + cfg.eval_type + ) # for loss evolutionary if t_score > temp_score: temp_score = loss_score @@ -308,13 +400,17 @@ def evolve_generator_population(self, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) # evaluate all parents - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): self.load_gen(parent, parent_opt) self.prepare_eval_fake_data() Fq, Fd, score = self.evaluation(cfg.eval_type) @@ -328,7 +424,9 @@ def evolve_generator_population(self, evo_g_step): # randomly choose a parent, variation target_idx = random.randint(0, len(self.parents) - 1) for j, criterionG in enumerate(self.G_criterion): - self.load_gen(self.parents[target_idx], self.parent_adv_opts[target_idx]) # load generator + self.load_gen( + self.parents[target_idx], self.parent_adv_opts[target_idx] + ) # load generator # Variation self.variation(evo_g_step, criterionG) @@ -344,7 +442,9 @@ def evolve_generator_population(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation.append(criterionG.loss_mode) @@ -356,8 +456,12 @@ def evolve_generator_population(self, evo_g_step): def evolve_discriminator(self, evo_d_step): total_loss = 0 for step in range(evo_d_step): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() - gen_samples = self.best_fake_samples[step * cfg.batch_size:(step + 1) * cfg.batch_size] + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() + gen_samples = self.best_fake_samples[ + step * cfg.batch_size : (step + 1) * cfg.batch_size + ] if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -388,7 +492,9 @@ def variation(self, g_step, criterionG): # mixture variation: double loss rand_w = torch.rand(1).cuda() cri_1, cri_2 = criterionG - g_loss = rand_w * 
cri_1(self.d_out_real, d_out_fake) + (1 - rand_w) * cri_2(self.d_out_real, d_out_fake) + g_loss = rand_w * cri_1(self.d_out_real, d_out_fake) + ( + 1 - rand_w + ) * cri_2(self.d_out_real, d_out_fake) # all loss # rand_w = F.softmax(torch.rand(len(criterionG)).cuda(), dim=0) @@ -407,7 +513,9 @@ def variation(self, g_step, criterionG): def evaluation(self, eval_type): """Evaluation all children, update child score. Note that the eval data should be the same""" - eval_samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size) + eval_samples = self.gen.sample( + cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size + ) gen_data = GenDataIter(eval_samples) # Fd @@ -416,18 +524,24 @@ def evaluation(self, eval_type): else: Fd = 0 - if eval_type == 'standard': + if eval_type == "standard": Fq = self.eval_d_out_fake.mean().cpu().item() - elif eval_type == 'rsgan': - g_loss, d_loss = get_losses(self.eval_d_out_real, self.eval_d_out_fake, 'rsgan') + elif eval_type == "rsgan": + g_loss, d_loss = get_losses( + self.eval_d_out_real, self.eval_d_out_fake, "rsgan" + ) Fq = d_loss.item() - elif eval_type == 'nll': + elif eval_type == "nll": if cfg.lambda_fq != 0: - Fq = -NLL.cal_nll(self.oracle, gen_data.loader, self.mle_criterion) # NLL_Oracle + Fq = -NLL.cal_nll( + self.oracle, gen_data.loader, self.mle_criterion + ) # NLL_Oracle else: Fq = 0 - elif eval_type == 'Ra': - g_loss = torch.sigmoid(self.eval_d_out_fake - torch.mean(self.eval_d_out_real)).sum() + elif eval_type == "Ra": + g_loss = torch.sigmoid( + self.eval_d_out_fake - torch.mean(self.eval_d_out_real) + ).sum() Fq = g_loss.item() else: raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type) @@ -438,22 +552,31 @@ def evaluation(self, eval_type): def prepare_eval_real_data(self): with torch.no_grad(): self.eval_real_samples = torch.cat( - [F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() - for _ in range(cfg.eval_b_num)], dim=0) + [ + F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() + for _ in range(cfg.eval_b_num) + ], + dim=0, + ) if cfg.CUDA: self.eval_real_samples = self.eval_real_samples.cuda() - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": self.eval_d_out_real = self.dis(self.eval_real_samples) def prepare_eval_fake_data(self): with torch.no_grad(): - self.eval_fake_samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size, - cfg.eval_b_num * cfg.batch_size, one_hot=True) + self.eval_fake_samples = self.gen.sample( + cfg.eval_b_num * cfg.batch_size, + cfg.eval_b_num * cfg.batch_size, + one_hot=True, + ) if cfg.CUDA: self.eval_fake_samples = self.eval_fake_samples.cuda() - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": self.eval_d_out_fake = self.dis(self.eval_fake_samples) @staticmethod @@ -463,14 +586,30 @@ def get_evo_temp(cur_step): all_temp = list() # all_temp.append(get_fixed_temperature(1.0, 0, 0, 'no')) # temp=1.0 - all_temp.append(get_fixed_temperature(cfg.temperature, cur_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) # current step all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step + cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) # current step + all_temp.append( + get_fixed_temperature( + cfg.temperature, + cur_step + 
cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) if cur_step > cfg.evo_temp_step: all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step - cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step - cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) return torch.Tensor(all_temp) # three temp diff --git a/instructor/oracle_data/fixem_instructor.py b/instructor/oracle_data/fixem_instructor.py index 6f9907ff..b56ba76a 100644 --- a/instructor/oracle_data/fixem_instructor.py +++ b/instructor/oracle_data/fixem_instructor.py @@ -11,7 +11,9 @@ import config as cfg from instructor.oracle_data.instructor import BasicInstructor -from instructor.real_data.fixem_instructor import FixemGANInstructor as RealDataFixemGANInstructor +from instructor.real_data.fixem_instructor import ( + FixemGANInstructor as RealDataFixemGANInstructor, +) from utils.gan_loss import GANLoss from utils.text_process import text_file_iterator from utils.data_loader import DataSupplier, GANDataset @@ -27,14 +29,21 @@ # check target real/fake to be right (Uniform or const) # random data portion generator - data supplier sample from randomint + class FixemGANInstructor(BasicInstructor, RealDataFixemGANInstructor): def __init__(self, opt): - self.oracle = Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len,cfg.padding_idx, gpu=cfg.CUDA) + self.oracle = Oracle( + 32, 32, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA + ) if cfg.oracle_pretrain: if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() self.oracle.load_state_dict( - torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.oracle_state_dict_path, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.CUDA: self.oracle = self.oracle.cuda() diff --git a/instructor/oracle_data/instructor.py b/instructor/oracle_data/instructor.py index 655d5e93..63af96e7 100644 --- a/instructor/oracle_data/instructor.py +++ b/instructor/oracle_data/instructor.py @@ -30,17 +30,27 @@ class BasicInstructor: def __init__(self, opt): - self.log = create_logger(__name__, silent=False, to_disk=True, - log_file=cfg.log_filename if cfg.if_test - else [cfg.log_filename, cfg.save_root + 'log.txt']) + self.log = create_logger( + __name__, + silent=False, + to_disk=True, + log_file=cfg.log_filename + if cfg.if_test + else [cfg.log_filename, cfg.save_root + "log.txt"], + ) self.sig = Signal(cfg.signal_file) self.opt = opt # oracle, generator, discriminator - self.oracle = Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.oracle_list = [Oracle(32, 32, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] + self.oracle = Oracle( + 32, 32, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA + ) + self.oracle_list = [ + Oracle( + 32, 32, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA + ) + for _ in range(cfg.k_label) + ] self.dis = None self.clas = None @@ -48,12 +58,18 @@ def __init__(self, opt): self.show_config() self.check_oracle() # Create Oracle models if not exist # DataLoader - self.oracle_samples = torch.load(cfg.oracle_samples_path.format(cfg.samples_num)) - self.oracle_samples_list = [torch.load(cfg.multi_oracle_samples_path.format(i, cfg.samples_num)) - for i in range(cfg.k_label)] + self.oracle_samples = torch.load( + cfg.oracle_samples_path.format(cfg.samples_num) + ) + 
self.oracle_samples_list = [ + torch.load(cfg.multi_oracle_samples_path.format(i, cfg.samples_num)) + for i in range(cfg.k_label) + ] self.oracle_data = GenDataIter(self.oracle_samples) - self.oracle_data_list = [GenDataIter(self.oracle_samples_list[i]) for i in range(cfg.k_label)] + self.oracle_data_list = [ + GenDataIter(self.oracle_samples_list[i]) for i in range(cfg.k_label) + ] # Criterion self.mle_criterion = nn.NLLLoss() @@ -61,21 +77,30 @@ def __init__(self, opt): # Metrics # nll_oracle, less-better, changes in range -0.1 - 0.6, moderate weight - self.nll_oracle = NLL('NLL_oracle', weight=-3, if_use=cfg.use_nll_oracle, gpu=cfg.CUDA) + self.nll_oracle = NLL( + "NLL_oracle", weight=-3, if_use=cfg.use_nll_oracle, gpu=cfg.CUDA + ) # nll-gen, less-better, changes in range 1.5 - 3 will have smaller wight (not in use) - self.nll_gen = NLL('NLL_gen', weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) + self.nll_gen = NLL("NLL_gen", weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA) # nll-div, more-better, changes in range 0.5 - 1.5 will have smaller wight (not in use) - self.nll_div = NLL('NLL_div', weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) + self.nll_div = NLL("NLL_div", weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA) # self-bleu, less-better, changes in range 0.7 - 0.9, will have relatively high weight - self.self_bleu = BLEU('Self-BLEU', weight=-3, gram=3, if_use=cfg.use_self_bleu) + self.self_bleu = BLEU("Self-BLEU", weight=-3, gram=3, if_use=cfg.use_self_bleu) # IOC, less-better, changes in range 0.8 - 2.0, smaller weight self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.oracle_data) # dummy, add constant value to overall score self.dummy = Dummy(weight=1, value=5, if_use=True) - self.all_metrics = [self.nll_oracle, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.dummy] + self.all_metrics = [ + self.nll_oracle, + self.nll_gen, + self.nll_div, + self.self_bleu, + self.ioc, + self.dummy, + ] def _run(self): - print('Nothing to run in Basic Instructor!') + print("Nothing to run in Basic Instructor!") pass def _test(self): @@ -86,15 +111,30 @@ def init_model(self): if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() self.oracle.load_state_dict( - torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.oracle_state_dict_path, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.oracle = self.oracle.cuda() @@ -104,7 +144,7 @@ def init_model(self): def train_gen_epoch(self, model, data_loader, criterion, optimizer): total_loss = 0 for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if 
cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -120,7 +160,7 @@ def train_dis_epoch(self, model, data_loader, criterion, optimizer): total_acc = 0 total_num = 0 for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -143,7 +183,7 @@ def eval_dis(model, data_loader, criterion): total_num = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -173,24 +213,30 @@ def optimize(opt, loss, model=None, retain_graph=False): def show_config(self): """Show parser parameters settings""" - self.log.info(100 * '=') - self.log.info('> training arguments:') + self.log.info(100 * "=") + self.log.info("> training arguments:") for arg in vars(self.opt): - self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg))) - self.log.info(100 * '=') + self.log.info(">>> {0}: {1}".format(arg, getattr(self.opt, arg))) + self.log.info(100 * "=") def sample_for_metrics(self): eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples) - gen_tokens_s = tensor_to_tokens(self.gen.sample(cfg.small_sample_num, 4 * cfg.batch_size)) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(cfg.small_sample_num, 4 * cfg.batch_size) + ) return gen_data, gen_tokens, gen_tokens_s def sample_for_metrics_with_label(self, label_i): - eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size, label_i=label_i) + eval_samples = self.gen.sample( + cfg.samples_num, 4 * cfg.batch_size, label_i=label_i + ) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples) - gen_tokens_s = tensor_to_tokens(self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i)) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i) + ) return gen_data, gen_tokens, gen_tokens_s def cal_metrics(self, fmt_str=False): @@ -210,15 +256,23 @@ def cal_metrics(self, fmt_str=False): self.ioc.reset(test_text=gen_tokens) metrics = {metric.name: metric.get_score() for metric in self.all_metrics} - metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + metrics.update( + { + "Overal_score": sum( + metric.weight * metric.get_score() for metric in self.all_metrics + ) + } + ) wandb.log(metrics) if fmt_str: - return "\n" + "\n".join([f"{name} = {score}" for name, score in metrics.items()]) + return "\n" + "\n".join( + [f"{name} = {score}" for name, score in metrics.items()] + ) return [metric.get_score() for metric in self.all_metrics] def cal_metrics_with_label(self, label_i, fmt_str=False): - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" with torch.no_grad(): # Prepare data for evaluation gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics_with_label() @@ -231,34 +285,56 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.ioc.reset(test_text=gen_tokens) self.nll_oracle.reset(test_text=gen_tokens) - metrics = {f"label {label_i}_{metric.name}": metric.get_score() for metric in self.all_metrics} - metrics.update({f"label {label_i} Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + metrics = { + f"label {label_i}_{metric.name}": 
metric.get_score() + for metric in self.all_metrics + } + metrics.update( + { + f"label {label_i} Overal_score": sum( + metric.weight * metric.get_score() for metric in self.all_metrics + ) + } + ) wandb.log(metrics) if fmt_str: - return "\n" + "\n".join([f"{name} = {score}" for name, score in metrics.items()]) + return "\n" + "\n".join( + [f"{name} = {score}" for name, score in metrics.items()] + ) return metrics def comb_metrics(self, fmt_str=False): - all_scores = [self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label)] + all_scores = [ + self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label) + ] if fmt_str: - return ', '.join([ - f'{name} = {[scores[name] for scores in all_scores]}' - for name in all_scores[0] - ]) + return ", ".join( + [ + f"{name} = {[scores[name] for scores in all_scores]}" + for name in all_scores[0] + ] + ) return [scores.values() for scores in all_scores] def _save(self, phase, epoch): """Save model state dict and generator's samples""" - if phase != 'ADV': - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(phase, epoch) + if phase != "ADV": + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_{}_{:05d}.txt".format( + phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size) write_tensor(save_sample_path, samples) def update_temperature(self, i, N): - self.gen.temperature.data = torch.Tensor([get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)]) + self.gen.temperature.data = torch.Tensor( + [get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)] + ) if cfg.CUDA: self.gen.temperature.data = self.gen.temperature.data.cuda() @@ -268,17 +344,28 @@ def check_oracle(self): create_multi_oracle(cfg.k_label) # General text generation Oracle model - if not os.path.exists(cfg.oracle_samples_path.format(cfg.samples_num)) or not cfg.oracle_pretrain: + if ( + not os.path.exists(cfg.oracle_samples_path.format(cfg.samples_num)) + or not cfg.oracle_pretrain + ): create_oracle() # Category text generation Oracle models for i in range(cfg.k_label): - if not os.path.exists(cfg.multi_oracle_samples_path.format(i, cfg.samples_num)): + if not os.path.exists( + cfg.multi_oracle_samples_path.format(i, cfg.samples_num) + ): create_multi_oracle(cfg.k_label) break # Load Oracle state dict - self.oracle.load_state_dict(torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + self.oracle.load_state_dict( + torch.load( + cfg.oracle_state_dict_path, map_location="cuda:{}".format(cfg.device) + ) + ) for i in range(cfg.k_label): oracle_path = cfg.multi_oracle_state_dict_path.format(i) - self.oracle_list[i].load_state_dict(torch.load(oracle_path, map_location='cuda:{}'.format(cfg.device))) + self.oracle_list[i].load_state_dict( + torch.load(oracle_path, map_location="cuda:{}".format(cfg.device)) + ) diff --git a/instructor/oracle_data/jsdgan_instructor.py b/instructor/oracle_data/jsdgan_instructor.py index a207eb3a..ca4324d9 100644 --- a/instructor/oracle_data/jsdgan_instructor.py +++ b/instructor/oracle_data/jsdgan_instructor.py @@ -21,8 +21,17 @@ def __init__(self, opt): super(JSDGANInstructor, self).__init__(opt) # generator - self.gen = JSDGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, 
cfg.padding_idx, gpu=cfg.CUDA) + self.gen = JSDGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -33,11 +42,21 @@ def init_model(self): if not os.path.exists(cfg.oracle_state_dict_path): create_oracle() self.oracle.load_state_dict( - torch.load(cfg.oracle_state_dict_path, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.oracle_state_dict_path, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.oracle = self.oracle.cuda() @@ -45,23 +64,29 @@ def init_model(self): def _run(self): # ===PRE-TRAINING=== - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") for adv_epoch in range(cfg.ADV_train_epoch): g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV] epoch %d: g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True))) + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): + self.log.info( + "[ADV] epoch %d: g_loss = %.4f, %s" + % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -73,16 +98,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -94,7 +123,7 @@ def adv_train_generator(self, g_step): total_loss = 0 for step in range(g_step): for i, data in enumerate(self.oracle_data.loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() diff --git a/instructor/oracle_data/leakgan_instructor.py b/instructor/oracle_data/leakgan_instructor.py index 0ec40663..91d0fcf3 100644 --- a/instructor/oracle_data/leakgan_instructor.py +++ b/instructor/oracle_data/leakgan_instructor.py 
@@ -24,9 +24,19 @@ def __init__(self, opt): super(LeakGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = LeakGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, cfg.goal_size, cfg.step_size, cfg.CUDA) - self.dis = LeakGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = LeakGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + cfg.goal_size, + cfg.step_size, + cfg.CUDA, + ) + self.dis = LeakGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # optimizer @@ -39,48 +49,63 @@ def __init__(self, opt): def _run(self): for inter_num in range(cfg.inter_epoch): - self.log.info('>>> Interleaved Round %d...' % inter_num) + self.log.info(">>> Interleaved Round %d..." % inter_num) self.sig.update() # update signal if self.sig.pre_sig: # ===DISCRIMINATOR PRE-TRAINING=== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format( + cfg.pretrained_dis_path + ) + ) # ===GENERATOR MLE TRAINING=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + ) + ) else: - self.log.info('>>> Stop by pre_signal! Skip to adversarial training...') + self.log.info(">>> Stop by pre_signal! Skip to adversarial training...") break # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (str(self.cal_metrics(fmt_str=True)))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (str(self.cal_metrics(fmt_str=True)))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -98,7 +123,7 @@ def pretrain_generator(self, epochs): # ===Train=== for i, data in enumerate(self.oracle_data.loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -111,13 +136,20 @@ def pretrain_generator(self, epochs): # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: - self.log.info('[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s' % ( - epoch, pre_mana_loss, pre_work_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s" + % ( + epoch, + pre_mana_loss, + pre_work_loss, + self.cal_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step, current_k=0): @@ -131,13 +163,15 @@ def adv_train_generator(self, g_step, current_k=0): adv_work_loss = 0 for step in range(g_step): with torch.no_grad(): - gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis, - train=True) # !!! train=True, the only place + gen_samples = self.gen.sample( + cfg.batch_size, cfg.batch_size, self.dis, train=True + ) # !!! train=True, the only place inp, target = GenDataIter.prepare(gen_samples, gpu=cfg.CUDA) # ===Train=== - rewards = rollout_func.get_reward_leakgan(target, cfg.rollout_num, self.dis, - current_k).cpu() # reward with MC search + rewards = rollout_func.get_reward_leakgan( + target, cfg.rollout_num, self.dis, current_k + ).cpu() # reward with MC search mana_loss, work_loss = self.gen.adversarial_loss(target, rewards, self.dis) # update parameters @@ -145,10 +179,16 @@ def adv_train_generator(self, g_step, current_k=0): adv_mana_loss += mana_loss.data.item() adv_work_loss += work_loss.data.item() # ===Test=== - self.log.info('[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' % ( - adv_mana_loss / g_step, adv_work_loss / g_step, self.cal_metrics(fmt_str=True))) - - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + self.log.info( + "[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s" + % ( + adv_mana_loss / g_step, + adv_work_loss / g_step, + self.cal_metrics(fmt_str=True), + ) + ) + + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
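The "# reward with MC search" comment in the hunk above refers to Monte Carlo rollouts: each partially generated sequence is completed several times and scored by the discriminator, and the scores are averaged into a per-prefix reward. The sketch below only illustrates that idea; get_reward_leakgan in this repository has a different signature, and complete_fn / score_fn are hypothetical stand-ins for the generator's rollout completion and the discriminator's scoring.

import torch

def mc_rollout_reward(prefixes, rollout_num, complete_fn, score_fn):
    # Average discriminator scores over several stochastic completions of each prefix.
    rewards = torch.zeros(prefixes.size(0))
    for _ in range(rollout_num):
        finished = complete_fn(prefixes)   # sample the remaining tokens for every prefix
        rewards += score_fn(finished)      # e.g. probability of being classified as real, shape (batch,)
    return rewards / rollout_num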
@@ -161,23 +201,32 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for step in range(d_step): # prepare loader for training - pos_samples = self.oracle.sample(cfg.samples_num, cfg.batch_size) # re-sample the Oracle Data + pos_samples = self.oracle.sample( + cfg.samples_num, cfg.batch_size + ) # re-sample the Oracle Data neg_samples = self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis) dis_data = DisDataIter(pos_samples, neg_samples) for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - _, eval_acc = self.eval_dis(self.dis, dis_eval_data.loader, self.dis_criterion) - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f,' % ( - phase, step, d_loss, train_acc, eval_acc)) + _, eval_acc = self.eval_dis( + self.dis, dis_eval_data.loader, self.dis_criterion + ) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f," + % (phase, step, d_loss, train_acc, eval_acc) + ) def cal_metrics(self, fmt_str=False): # Prepare data for evaluation - gen_data = GenDataIter(self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis)) + gen_data = GenDataIter( + self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis) + ) # Reset metrics self.nll_oracle.reset(self.oracle, gen_data.loader) @@ -185,12 +234,22 @@ def cal_metrics(self, fmt_str=False): self.nll_div.reset(self.gen, gen_data.loader, leak_dis=self.dis) if fmt_str: - return ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) + return ", ".join( + [ + "%s = %s" % (metric.name, metric.get_score()) + for metric in self.all_metrics + ] + ) else: return [metric.get_score() for metric in self.all_metrics] def _save(self, phase, epoch): - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(phase, epoch) + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_{}_{:05d}.txt".format( + phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis) write_tensor(save_sample_path, samples) diff --git a/instructor/oracle_data/maligan_instructor.py b/instructor/oracle_data/maligan_instructor.py index 8d7478f7..a7eef286 100644 --- a/instructor/oracle_data/maligan_instructor.py +++ b/instructor/oracle_data/maligan_instructor.py @@ -24,9 +24,17 @@ def __init__(self, opt): super(MaliGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = MaliGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = MaliGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = MaliGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = MaliGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # Optimizer @@ -38,40 +46,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not 
cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -83,16 +100,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -101,7 +122,9 @@ def adv_train_generator(self, g_step): """ total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = self.get_mali_reward(target) @@ -110,9 +133,12 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss, self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
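For context, get_mali_reward above follows MaliGAN's importance-weighted reward, in which generated samples are weighted by D(x) / (1 - D(x)). The snippet below paraphrases that formulation with an assumed batch normalisation and baseline; it is a sketch of the idea, not the repository's exact implementation.

import torch

def mali_reward(d_probs, baseline=0.5):
    # d_probs: discriminator probabilities that each generated sample is real, shape (batch,)
    weights = d_probs / (1.0 - d_probs)   # importance weights r(x) = D(x) / (1 - D(x))
    weights = weights / weights.sum()     # normalise over the batch
    return weights - baseline             # subtract a baseline to reduce variance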
@@ -131,13 +157,18 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - _, eval_acc = self.eval_dis(self.dis, dis_eval_data.loader, self.dis_criterion) - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f,' % ( - phase, step, d_loss, train_acc, eval_acc)) + _, eval_acc = self.eval_dis( + self.dis, dis_eval_data.loader, self.dis_criterion + ) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f," + % (phase, step, d_loss, train_acc, eval_acc) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) diff --git a/instructor/oracle_data/relgan_instructor.py b/instructor/oracle_data/relgan_instructor.py index 932c51ee..1b5bf5df 100644 --- a/instructor/oracle_data/relgan_instructor.py +++ b/instructor/oracle_data/relgan_instructor.py @@ -23,10 +23,25 @@ def __init__(self, opt): super(RelGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx, - gpu=cfg.CUDA) + self.gen = RelGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = RelGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -38,39 +53,50 @@ def __init__(self, opt): def _run(self): # ===PRE-TRAINING (GENERATOR)=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for adv_epoch in progress: self.sig.update() if self.sig.adv_sig: g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator d_loss = self.adv_train_discriminator(cfg.ADV_d_step) # Discriminator - self.update_temperature(adv_epoch, cfg.ADV_train_epoch) # update temperature + self.update_temperature( + adv_epoch, cfg.ADV_train_epoch + ) # update temperature progress.set_description( - 'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature)) + "g_loss: %.4f, d_loss: %.4f, temperature: %.4f" + % (g_loss, d_loss, self.gen.temperature) + ) # TEST - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s' % ( - adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True))) + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): + self.log.info( + "[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s" + % (adv_epoch, 
g_loss, d_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) progress.close() break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() @@ -84,23 +110,29 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): total_loss = 0 for step in range(g_step): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True) if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -118,7 +150,9 @@ def adv_train_generator(self, g_step): def adv_train_discriminator(self, d_step): total_loss = 0 for step in range(d_step): - real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.oracle_data.random_batch()["target"], cfg.vocab_size + ).float() gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True) if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -134,7 +168,9 @@ def adv_train_discriminator(self, d_step): return total_loss / d_step if d_step != 0 else 0 def update_temperature(self, i, N): - self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt) + self.gen.temperature = get_fixed_temperature( + cfg.temperature, i, N, cfg.temp_adpt + ) @staticmethod def optimize(opt, loss, model=None, retain_graph=False): diff --git a/instructor/oracle_data/sentigan_instructor.py b/instructor/oracle_data/sentigan_instructor.py index 42be59a3..1c72b697 100644 --- a/instructor/oracle_data/sentigan_instructor.py +++ b/instructor/oracle_data/sentigan_instructor.py @@ -28,16 +28,42 @@ def __init__(self, opt): super(SentiGANInstructor, self).__init__(opt) # generator, discriminator - self.oracle_list = [Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] - - self.gen_list = [SentiGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] - self.dis = SentiGAN_D(cfg.k_label, cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.oracle_list = [ + Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) 
+ for _ in range(cfg.k_label) + ] + + self.gen_list = [ + SentiGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + for _ in range(cfg.k_label) + ] + self.dis = SentiGAN_D( + cfg.k_label, + cfg.dis_embed_dim, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer - self.gen_opt_list = [optim.Adam(gen.parameters(), lr=cfg.gen_lr) for gen in self.gen_list] + self.gen_opt_list = [ + optim.Adam(gen.parameters(), lr=cfg.gen_lr) for gen in self.gen_list + ] self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) def init_model(self): @@ -46,17 +72,32 @@ def init_model(self): oracle_path = cfg.multi_oracle_state_dict_path.format(i) if not os.path.exists(oracle_path): create_multi_oracle(cfg.k_label) - self.oracle_list[i].load_state_dict(torch.load(oracle_path, map_location='cuda:{}'.format(cfg.device))) + self.oracle_list[i].load_state_dict( + torch.load(oracle_path, map_location="cuda:{}".format(cfg.device)) + ) if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: for i in range(cfg.k_label): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) self.gen_list[i].load_state_dict( - torch.load(cfg.pretrained_gen_path + '%d' % i, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.pretrained_gen_path + "%d" % i, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.CUDA: for i in range(cfg.k_label): @@ -67,41 +108,57 @@ def init_model(self): def _run(self): # ===PRE-TRAIN GENERATOR=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: for i in range(cfg.k_label): - torch.save(self.gen_list[i].state_dict(), cfg.pretrained_gen_path + '%d' % i) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen_list[i].state_dict(), + cfg.pretrained_gen_path + "%d" % i, + ) + print( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s', self.comb_metrics(fmt_str=True)) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s", self.comb_metrics(fmt_str=True)) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH 
%d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -114,18 +171,24 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: for i in range(cfg.k_label): - pre_loss = self.train_gen_epoch(self.gen_list[i], self.oracle_data_list[i].loader, - self.mle_criterion, self.gen_opt_list[i]) + pre_loss = self.train_gen_epoch( + self.gen_list[i], + self.oracle_data_list[i].loader, + self.mle_criterion, + self.gen_opt_list[i], + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: if i == cfg.k_label - 1: - self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.comb_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.comb_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -137,7 +200,10 @@ def adv_train_generator(self, g_step): rollout_func = rollout.ROLLOUT(self.gen_list[i], cfg.CUDA) total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen_list[i].sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen_list[i].sample(cfg.batch_size, cfg.batch_size), + gpu=cfg.CUDA, + ) # ===Train=== rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis) @@ -146,9 +212,9 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: %s', self.comb_metrics(fmt_str=True)) + self.log.info("[ADV-GEN]: %s", self.comb_metrics(fmt_str=True)) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
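SentiGAN trains a single discriminator over k_label + 1 classes, one "generated" class plus one class per sentiment label, which is what the CatClasDataIter call in the next hunk sets up. The sketch below shows one way such targets can be laid out; the class ordering actually used by CatClasDataIter is an assumption here, not verified against the repository.

import torch

def build_class_targets(num_fake, real_counts):
    # class 0 = generated samples, classes 1..k = real samples of each label
    targets = [torch.zeros(num_fake, dtype=torch.long)]
    targets += [
        torch.full((count,), label + 1, dtype=torch.long)
        for label, count in enumerate(real_counts)
    ]
    return torch.cat(targets)

For example, build_class_targets(64, [32, 32]) returns 128 labels taking values in {0, 1, 2}.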
@@ -162,32 +228,43 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): fake_samples = [] for i in range(cfg.k_label): real_samples.append(self.oracle_samples_list[i]) - fake_samples.append(self.gen_list[i].sample(cfg.samples_num // cfg.k_label, 8 * cfg.batch_size)) + fake_samples.append( + self.gen_list[i].sample( + cfg.samples_num // cfg.k_label, 8 * cfg.batch_size + ) + ) dis_samples_list = [torch.cat(fake_samples, dim=0)] + real_samples dis_data = CatClasDataIter(dis_samples_list) for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f" + % (phase, step, d_loss, train_acc) + ) - if cfg.if_save and not cfg.if_test and phase == 'MLE': + if cfg.if_save and not cfg.if_test and phase == "MLE": torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def cal_metrics_with_label(self, label_i): - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" # Prepare data for evaluation - eval_samples = self.gen_list[label_i].sample(cfg.samples_num, 8 * cfg.batch_size) + eval_samples = self.gen_list[label_i].sample( + cfg.samples_num, 8 * cfg.batch_size + ) gen_data = GenDataIter(eval_samples) # Reset metrics self.nll_oracle.reset(self.oracle_list[label_i], gen_data.loader) - self.nll_gen.reset(self.gen_list[label_i], self.oracle_data_list[label_i].loader) + self.nll_gen.reset( + self.gen_list[label_i], self.oracle_data_list[label_i].loader + ) self.nll_div.reset(self.gen_list[label_i], gen_data.loader) return [metric.get_score() for metric in self.all_metrics] @@ -195,8 +272,13 @@ def cal_metrics_with_label(self, label_i): def _save(self, phase, epoch): """Save model state dict and generator's samples""" for i in range(cfg.k_label): - torch.save(self.gen_list[i].state_dict(), - cfg.save_model_root + 'gen{}_{}_{:05d}.pt'.format(i, phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_d{}_{}_{:05d}.txt'.format(i, phase, epoch) + torch.save( + self.gen_list[i].state_dict(), + cfg.save_model_root + "gen{}_{}_{:05d}.pt".format(i, phase, epoch), + ) + save_sample_path = ( + cfg.save_samples_root + + "samples_d{}_{}_{:05d}.txt".format(i, phase, epoch) + ) samples = self.gen_list[i].sample(cfg.batch_size, cfg.batch_size) write_tensor(save_sample_path, samples) diff --git a/instructor/oracle_data/seqgan_instructor.py b/instructor/oracle_data/seqgan_instructor.py index b7f64676..1a5026a5 100644 --- a/instructor/oracle_data/seqgan_instructor.py +++ b/instructor/oracle_data/seqgan_instructor.py @@ -23,9 +23,17 @@ def __init__(self, opt): super(SeqGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = SeqGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = SeqGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = SeqGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = SeqGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # Optimizer @@ -37,40 +45,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if 
not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -82,16 +99,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -102,7 +123,9 @@ def adv_train_generator(self, g_step): rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA) total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis) @@ -111,9 +134,12 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss, self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
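SeqGAN's adversarial_loss above couples the rollout rewards with a REINFORCE-style policy gradient. A minimal sketch of that objective is given below; the tensor shapes and the plain sum reduction are assumptions rather than the repository's exact implementation.

import torch

def policy_gradient_loss(log_probs, targets, rewards):
    # log_probs: (batch, seq_len, vocab_size) log-softmax outputs of the generator
    # targets:   (batch, seq_len) token ids that were actually sampled
    # rewards:   (batch, seq_len) per-token rewards from Monte Carlo rollouts
    token_logp = log_probs.gather(2, targets.unsqueeze(2)).squeeze(2)
    return -(token_logp * rewards).sum()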
@@ -132,13 +158,18 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - _, eval_acc = self.eval_dis(self.dis, dis_eval_data.loader, self.dis_criterion) - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f,' % ( - phase, step, d_loss, train_acc, eval_acc)) + _, eval_acc = self.eval_dis( + self.dis, dis_eval_data.loader, self.dis_criterion + ) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f, eval_acc = %.4f," + % (phase, step, d_loss, train_acc, eval_acc) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) diff --git a/instructor/real_data/catgan_instructor.py b/instructor/real_data/catgan_instructor.py index e6bd3d9a..0afa4a67 100644 --- a/instructor/real_data/catgan_instructor.py +++ b/instructor/real_data/catgan_instructor.py @@ -27,21 +27,54 @@ class CatGANInstructor(BasicInstructor): - def __init__(self, opt): super(CatGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = CatGAN_G(cfg.k_label, cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, - cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.parents = [CatGAN_G(cfg.k_label, cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, - cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, - gpu=cfg.CUDA).state_dict() - for _ in range(cfg.n_parent)] # list of Generator state_dict - self.dis = CatGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) - self.clas = CatGAN_C(cfg.k_label, cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.extend_vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = CatGAN_G( + cfg.k_label, + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.parents = [ + CatGAN_G( + cfg.k_label, + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ).state_dict() + for _ in range(cfg.n_parent) + ] # list of Generator state_dict + self.dis = CatGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.clas = CatGAN_C( + cfg.k_label, + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.extend_vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -50,14 +83,19 @@ def __init__(self, opt): self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) self.clas_opt = optim.Adam(self.clas.parameters(), lr=cfg.clas_lr) - self.parent_mle_opts = [copy.deepcopy(self.gen_opt.state_dict()) - for _ in range(cfg.n_parent)] - self.parent_adv_opts = [copy.deepcopy(self.gen_adv_opt.state_dict()) - for _ in range(cfg.n_parent)] # list of optimizer state dict + self.parent_mle_opts = [ + copy.deepcopy(self.gen_opt.state_dict()) for _ in range(cfg.n_parent) + ] + self.parent_adv_opts = [ + copy.deepcopy(self.gen_adv_opt.state_dict()) for _ in range(cfg.n_parent) + ] # list of optimizer state dict # Criterion - 
self.G_criterion = [GANLoss(loss_mode, 'G', cfg.d_type, CUDA=cfg.CUDA) for loss_mode in cfg.mu_type.split()] - self.D_criterion = GANLoss(cfg.loss_type, 'D', cfg.d_type, CUDA=cfg.CUDA) + self.G_criterion = [ + GANLoss(loss_mode, "G", cfg.d_type, CUDA=cfg.CUDA) + for loss_mode in cfg.mu_type.split() + ] + self.D_criterion = GANLoss(cfg.loss_type, "D", cfg.d_type, CUDA=cfg.CUDA) # DataLoader self.all_train_data = CatGenDataIter(self.train_samples_list) @@ -68,13 +106,21 @@ def __init__(self, opt): def init_model(self): if cfg.gen_pretrain: for i in range(cfg.n_parent): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) - self.parents[i] = torch.load(cfg.pretrained_gen_path + '%d' % 0, map_location='cpu') + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) + self.parents[i] = torch.load( + cfg.pretrained_gen_path + "%d" % 0, map_location="cpu" + ) if cfg.CUDA: self.gen = self.gen.cuda() if cfg.multi_gpu: - self.dis = torch.nn.parallel.DataParallel(self.dis, device_ids=cfg.devices) + self.dis = torch.nn.parallel.DataParallel( + self.dis, device_ids=cfg.devices + ) self.dis = self.dis.cuda() self.clas = self.clas.cuda() @@ -90,19 +136,29 @@ def load_gen(self, parent, parent_opt, mle=False): def _run(self): # ===Pre-train Classifier with real data=== if cfg.use_clas_acc: - self.log.info('Start training Classifier...') + self.log.info("Start training Classifier...") self.train_classifier(cfg.PRE_clas_epoch) # ===Pre-train Generator=== if not cfg.gen_pretrain: - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_mle_opts)): - self.log.info('Starting Generator-{} MLE Training...'.format(i)) + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_mle_opts) + ): + self.log.info("Starting Generator-{} MLE Training...".format(i)) self.load_gen(parent, parent_opt, mle=True) # load state dict self.pretrain_generator(cfg.MLE_train_epoch) - self.parents[i] = copy.deepcopy(self.gen.state_dict()) # save state dict + self.parents[i] = copy.deepcopy( + self.gen.state_dict() + ) # save state dict if cfg.if_save and not cfg.if_test: - torch.save(self.gen.state_dict(), cfg.pretrained_gen_path + '%d' % i) - self.log.info('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen.state_dict(), cfg.pretrained_gen_path + "%d" % i + ) + self.log.info( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # ===Adv-train=== progress = tqdm(range(cfg.ADV_train_epoch)) @@ -110,27 +166,45 @@ def _run(self): if cfg.temperature == 1: score, fit_score, select_mu = self.evolve_generator(cfg.ADV_g_step) else: # evolve with temperature - score, fit_score, select_mu = self.evolve_generator_with_temp(adv_epoch, cfg.ADV_g_step) + score, fit_score, select_mu = self.evolve_generator_with_temp( + adv_epoch, cfg.ADV_g_step + ) d_loss = self.evolve_discriminator(cfg.ADV_d_step) best_id = int(np.argmax(score)) - progress.set_description('mu: %s, d_loss = %.4f, temp = %.4f' % ( - ' '.join(select_mu), d_loss, self.parents[best_id]['temperature'].item())) + progress.set_description( + "mu: %s, d_loss = %.4f, temp = %.4f" + % ( + " ".join(select_mu), + d_loss, + self.parents[best_id]["temperature"].item(), + ) + ) # ===Test=== - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): best_id = 
int(np.argmax(score)) self.load_gen(self.parents[best_id], self.parent_adv_opts[best_id]) - self.log.info('[ADV] epoch %d: temp = %.4f, d_loss: %.4f, %s' % ( - adv_epoch, self.gen.temperature.item(), d_loss, self.comb_metrics(fmt_str=True))) + self.log.info( + "[ADV] epoch %d: temp = %.4f, d_loss: %.4f, %s" + % ( + adv_epoch, + self.gen.temperature.item(), + d_loss, + self.comb_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: for label_i in range(cfg.k_label): - self._save('ADV', adv_epoch, label_i) + self._save("ADV", adv_epoch, label_i) def _test(self): - self.log.debug('>>> Begin test...') + self.log.debug(">>> Begin test...") self._run() pass @@ -141,17 +215,20 @@ def pretrain_generator(self, epochs): """ for epoch in range(epochs): # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.all_train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.all_train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.comb_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.comb_metrics(fmt_str=True)) + ) if not cfg.if_test and cfg.if_save: for label_i in range(cfg.k_label): - self._save('MLE', epoch, label_i) + self._save("MLE", epoch, label_i) def evolve_generator(self, evo_g_step): # evaluation real data @@ -167,13 +244,21 @@ def evolve_generator(self, evo_g_step): # all child share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.train_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.train_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = [self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): # Variation self.load_gen(parent, parent_opt) # load state dict to self.gen @@ -198,7 +283,9 @@ def evolve_generator(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation[id_replace] = criterionG.loss_mode count += 1 @@ -222,17 +309,25 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = [F.one_hot(self.train_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for i in range(cfg.k_label)] + real_samples = [ + F.one_hot( + self.train_data_list[i].random_batch()["target"], cfg.vocab_size + ).float() + for i in range(cfg.k_label) + ] if cfg.CUDA: real_samples = [real_samples[i].cuda() for i in range(cfg.k_label)] - self.d_out_real = 
[self.dis(real_samples[i]) for i in range(cfg.k_label)] # d_out_real for each label + self.d_out_real = [ + self.dis(real_samples[i]) for i in range(cfg.k_label) + ] # d_out_real for each label - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): all_temp = self.get_evo_temp(cur_adv_step) - temp_score = float('-inf') + temp_score = float("-inf") temp_fit = None temp_child = None temp_child_opt = None @@ -248,8 +343,10 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # Evaluation self.prepare_eval_fake_data() # evaluation fake data - _, _, t_score = self.evaluation('Ra') # for temp evolutionary - loss_Fq, loss_Fd, loss_score = self.evaluation(cfg.eval_type) # for loss evolutionary + _, _, t_score = self.evaluation("Ra") # for temp evolutionary + loss_Fq, loss_Fd, loss_score = self.evaluation( + cfg.eval_type + ) # for loss evolutionary if t_score > temp_score: temp_score = loss_score @@ -287,15 +384,21 @@ def evolve_discriminator(self, evo_d_step): global dc_loss, dd_loss, d_loss total_loss = [] - all_gen_samples_list = list(map(self.merge, *self.best_fake_samples)) # merge each label of data - self.all_gen_samples_list = self.shuffle_eval_samples(all_gen_samples_list) # shuffle data + all_gen_samples_list = list( + map(self.merge, *self.best_fake_samples) + ) # merge each label of data + self.all_gen_samples_list = self.shuffle_eval_samples( + all_gen_samples_list + ) # shuffle data for step in range(evo_d_step): - dis_real_samples, dis_gen_samples = self.prepare_train_data('D', step) + dis_real_samples, dis_gen_samples = self.prepare_train_data("D", step) d_loss = 0 all_d_out_real = [] all_d_out_fake = [] - for (real_samples, fake_samples) in zip(dis_real_samples, dis_gen_samples): # for each label samples + for (real_samples, fake_samples) in zip( + dis_real_samples, dis_gen_samples + ): # for each label samples d_out_real = self.dis(real_samples) d_out_fake = self.dis(fake_samples) d_loss += self.D_criterion(d_out_real, d_out_fake) @@ -320,14 +423,16 @@ def variation(self, g_step, criterionG): """Optimize one child (Generator)""" total_loss = [] for step in range(g_step): - dis_real_samples, dis_gen_samples = self.prepare_train_data('G') + dis_real_samples, dis_gen_samples = self.prepare_train_data("G") # ===Train=== g_loss = 0 all_d_out_real = [] all_d_out_fake = [] # for i, (real_samples, fake_samples) in enumerate(zip(dis_real_samples, dis_gen_samples)): - for i, (d_out_real, fake_samples) in enumerate(zip(self.d_out_real, dis_gen_samples)): # share real + for i, (d_out_real, fake_samples) in enumerate( + zip(self.d_out_real, dis_gen_samples) + ): # share real # d_out_real = self.dis(real_samples) d_out_fake = self.dis(fake_samples) g_loss += criterionG(d_out_real, d_out_fake) @@ -350,30 +455,40 @@ def variation(self, g_step, criterionG): def evaluation(self, eval_type): """Evaluation all children, update child score. 
Note that the eval data should be the same""" - eval_samples = [self.gen.sample(cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size, label_i=i) for i - in range(cfg.k_label)] + eval_samples = [ + self.gen.sample( + cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size, label_i=i + ) + for i in range(cfg.k_label) + ] # Fd if cfg.lambda_fd != 0: nll_div = [] for label_i in range(cfg.k_label): gen_data = GenDataIter(eval_samples[label_i]) - nll_div.append(NLL.cal_nll_with_label(self.gen, gen_data.loader, label_i, self.mle_criterion)) + nll_div.append( + NLL.cal_nll_with_label( + self.gen, gen_data.loader, label_i, self.mle_criterion + ) + ) Fd = sum(nll_div) else: Fd = 0 # Fq - if 'bleu' in eval_type: + if "bleu" in eval_type: bleu_score = [] for i in range(cfg.k_label): bleu_score.append(self.bleu[i].get_score(given_gram=int(eval_type[-1]))) Fq = sum(bleu_score) - elif 'Ra' in eval_type: + elif "Ra" in eval_type: g_loss = 0 for i in range(cfg.k_label): - g_loss += torch.sigmoid(self.eval_d_out_fake[i] - torch.mean(self.eval_d_out_real[i])).sum() + g_loss += torch.sigmoid( + self.eval_d_out_fake[i] - torch.mean(self.eval_d_out_real[i]) + ).sum() Fq = g_loss.item() else: raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type) @@ -384,7 +499,7 @@ def evaluation(self, eval_type): def train_gen_epoch(self, model, data_loader, criterion, optimizer): total_loss = 0 for i, data in enumerate(data_loader): - inp, target, label = data['input'], data['target'], data['label'] + inp, target, label = data["input"], data["target"], data["label"] if cfg.CUDA: inp, target, label = inp.cuda(), target.cuda(), label.cuda() @@ -397,8 +512,13 @@ def train_gen_epoch(self, model, data_loader, criterion, optimizer): def _save(self, phase, epoch, label_i=None): assert type(label_i) == int - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_c{}_{}_{:05d}.txt'.format(label_i, phase, epoch) + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_c{}_{}_{:05d}.txt".format( + label_i, phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size, label_i=label_i) write_tokens(save_sample_path, tensor_to_tokens(samples, self.idx2word_dict)) @@ -409,50 +529,86 @@ def merge(*args): def shuffle_eval_samples(self, all_eval_samples): temp = [] for i in range(cfg.k_label): - temp.append(all_eval_samples[i][torch.randperm(all_eval_samples[i].size(0))]) + temp.append( + all_eval_samples[i][torch.randperm(all_eval_samples[i].size(0))] + ) return temp def prepare_train_data(self, which, step=None): """Prepare train data for both Generator and Discriminator, each samples_list contains k_label batches of data""" - assert which == 'D' or which == 'G', 'only support for D and G!!' + assert which == "D" or which == "G", "only support for D and G!!" 
real_samples_list = [ - F.one_hot(self.train_data_list[i].random_batch()['target'][:cfg.batch_size], - cfg.vocab_size).float().cuda() - for i in range(cfg.k_label)] - if which == 'D': - assert step is not None, 'missing step' - gen_samples_list = [self.all_gen_samples_list[i][step * cfg.batch_size:(step + 1) * cfg.batch_size] - for i in range(cfg.k_label)] # get a batch from each label + F.one_hot( + self.train_data_list[i].random_batch()["target"][: cfg.batch_size], + cfg.vocab_size, + ) + .float() + .cuda() + for i in range(cfg.k_label) + ] + if which == "D": + assert step is not None, "missing step" + gen_samples_list = [ + self.all_gen_samples_list[i][ + step * cfg.batch_size : (step + 1) * cfg.batch_size + ] + for i in range(cfg.k_label) + ] # get a batch from each label else: # 'G' gen_samples_list = [ self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True, label_i=i) - for i in range(cfg.k_label)] + for i in range(cfg.k_label) + ] return real_samples_list, gen_samples_list def prepare_eval_real_data(self): """Prepare evaluation real data, contains k_label batches of data""" with torch.no_grad(): - self.eval_real_samples = [torch.cat( - [F.one_hot(self.train_data_list[i].random_batch()['target'], cfg.vocab_size).float() - for _ in range(cfg.eval_b_num)], dim=0) for i in range(cfg.k_label)] + self.eval_real_samples = [ + torch.cat( + [ + F.one_hot( + self.train_data_list[i].random_batch()["target"], + cfg.vocab_size, + ).float() + for _ in range(cfg.eval_b_num) + ], + dim=0, + ) + for i in range(cfg.k_label) + ] if cfg.CUDA: - self.eval_real_samples = [self.eval_real_samples[i].cuda() for i in range(cfg.k_label)] + self.eval_real_samples = [ + self.eval_real_samples[i].cuda() for i in range(cfg.k_label) + ] - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': - self.eval_d_out_real = [self.dis(self.eval_real_samples[i]) for i in range(cfg.k_label)] + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": + self.eval_d_out_real = [ + self.dis(self.eval_real_samples[i]) for i in range(cfg.k_label) + ] def prepare_eval_fake_data(self): """Prepare evaluation fake data, contains k_label batches of data""" with torch.no_grad(): - self.eval_fake_samples = [self.gen.sample(cfg.eval_b_num * cfg.batch_size, - cfg.eval_b_num * cfg.batch_size, one_hot=True, label_i=i) - for i in range(cfg.k_label)] + self.eval_fake_samples = [ + self.gen.sample( + cfg.eval_b_num * cfg.batch_size, + cfg.eval_b_num * cfg.batch_size, + one_hot=True, + label_i=i, + ) + for i in range(cfg.k_label) + ] if cfg.CUDA: - self.eval_fake_samples = [self.eval_fake_samples[i].cuda() for i in range(cfg.k_label)] + self.eval_fake_samples = [ + self.eval_fake_samples[i].cuda() for i in range(cfg.k_label) + ] - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': - self.eval_d_out_fake = [self.dis(self.eval_fake_samples[i]) for i in range(cfg.k_label)] + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": + self.eval_d_out_fake = [ + self.dis(self.eval_fake_samples[i]) for i in range(cfg.k_label) + ] @staticmethod def get_evo_temp(cur_step): @@ -461,14 +617,30 @@ def get_evo_temp(cur_step): all_temp = list() # all_temp.append(get_fixed_temperature(1.0, 0, 0, 'no')) # temp=1.0 - all_temp.append(get_fixed_temperature(cfg.temperature, cur_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) # current step all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step + cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step, + 
cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) # current step + all_temp.append( + get_fixed_temperature( + cfg.temperature, + cur_step + cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) if cur_step > cfg.evo_temp_step: all_temp.append( - get_fixed_temperature(cfg.temperature, cur_step - cfg.evo_temp_step, cfg.ADV_train_epoch, - random.choice(mu_temp_type))) + get_fixed_temperature( + cfg.temperature, + cur_step - cfg.evo_temp_step, + cfg.ADV_train_epoch, + random.choice(mu_temp_type), + ) + ) return torch.Tensor(all_temp) diff --git a/instructor/real_data/cot_instructor.py b/instructor/real_data/cot_instructor.py index e92fb9eb..01ba17e7 100644 --- a/instructor/real_data/cot_instructor.py +++ b/instructor/real_data/cot_instructor.py @@ -25,10 +25,22 @@ def __init__(self, opt): super(CoTInstructor, self).__init__(opt) # generator, discriminator - self.gen = CoT_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = Cot_D(cfg.gen_embed_dim * 2, cfg.gen_hidden_dim * 2, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) # embed_dim and hidden_dim is larger + self.gen = CoT_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = Cot_D( + cfg.gen_embed_dim * 2, + cfg.gen_hidden_dim * 2, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) # embed_dim and hidden_dim is larger self.init_model() # Optimizer @@ -40,30 +52,32 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for epoch in progress: g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator d_loss = self.train_mediator(epoch, cfg.ADV_d_step) # Discriminator - progress.set_description('g_loss: %.4f, d_loss: %.4f' % (g_loss, d_loss)) + progress.set_description("g_loss: %.4f, d_loss: %.4f" % (g_loss, d_loss)) if epoch % cfg.adv_log_step == 0 or epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV]: epoch = %d, %s' % (epoch, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV]: epoch = %d, %s" % (epoch, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', epoch) + self._save("ADV", epoch) torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -75,16 +89,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : 
pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -93,7 +111,9 @@ def adv_train_generator(self, g_step): """ g_loss = [] for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = self.dis(inp, self.dis.init_hidden(cfg.batch_size)) @@ -110,9 +130,13 @@ def train_mediator(self, cur_epoch, d_step): d_loss = [] for step in range(d_step): # prepare loader for training - real = list(self.train_data.loader)[cur_epoch % len(self.train_data.loader)] # traverse all real data - real_inp, real_tar = real['input'], real['target'] - fake_inp, fake_tar = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + real = list(self.train_data.loader)[ + cur_epoch % len(self.train_data.loader) + ] # traverse all real data + real_inp, real_tar = real["input"], real["target"] + fake_inp, fake_tar = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) if cfg.CUDA: real_inp, real_tar = real_inp.cuda(), real_tar.cuda() diff --git a/instructor/real_data/dgsan_instructor.py b/instructor/real_data/dgsan_instructor.py index 7018b452..9c250d1b 100644 --- a/instructor/real_data/dgsan_instructor.py +++ b/instructor/real_data/dgsan_instructor.py @@ -25,10 +25,22 @@ def __init__(self, opt): super(DGSANInstructor, self).__init__(opt) # generator, discriminator - self.gen = DGSAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.old_gen = DGSAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = DGSAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.old_gen = DGSAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -37,8 +49,14 @@ def __init__(self, opt): def init_model(self): if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.gen = self.gen.cuda() @@ -47,14 +65,14 @@ def init_model(self): def _run(self): # ===PRE-TRAINING=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting 
Adversarial Training...") self.old_gen.load_state_dict(copy.deepcopy(self.gen.state_dict())) progress = tqdm(range(cfg.ADV_train_epoch)) @@ -62,16 +80,21 @@ def _run(self): g_loss = self.adv_train_generator() self.old_gen.load_state_dict(copy.deepcopy(self.gen.state_dict())) - progress.set_description('g_loss: %.4f' % g_loss) + progress.set_description("g_loss: %.4f" % g_loss) - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): self.log.info( - '[ADV]: epoch: %d, g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True))) + "[ADV]: epoch: %d, g_loss = %.4f, %s" + % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -83,26 +106,35 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self): g_loss = [] gen_data = GenDataIter(self.old_gen.sample(cfg.samples_num, cfg.batch_size)) for (real, fake) in zip(self.train_data.loader, gen_data.loader): - real_inp, real_tar = real['input'], real['target'] - fake_inp, fake_tar = fake['input'], fake['target'] + real_inp, real_tar = real["input"], real["target"] + fake_inp, fake_tar = fake["input"], fake["target"] if cfg.CUDA: - real_inp, real_tar, fake_inp, fake_tar = real_inp.cuda(), real_tar.cuda(), fake_inp.cuda(), fake_tar.cuda() + real_inp, real_tar, fake_inp, fake_tar = ( + real_inp.cuda(), + real_tar.cuda(), + fake_inp.cuda(), + fake_tar.cuda(), + ) # ===Train=== real_new_pred = self.cal_pred(self.gen, real_inp, real_tar) @@ -111,8 +143,12 @@ def adv_train_generator(self): fake_old_pred = self.cal_pred(self.old_gen, fake_inp, fake_tar) eps = 0 - real_loss = -torch.sum(torch.log(1 / (1 + real_old_pred / (real_new_pred + eps) + eps) + eps)) - fake_loss = -torch.sum(torch.log(1 / (1 + fake_new_pred / (fake_old_pred + eps) + eps) + eps)) + real_loss = -torch.sum( + torch.log(1 / (1 + real_old_pred / (real_new_pred + eps) + eps) + eps) + ) + fake_loss = -torch.sum( + torch.log(1 / (1 + fake_new_pred / (fake_old_pred + eps) + eps) + eps) + ) adv_loss = real_loss + fake_loss self.optimize(self.gen_adv_opt, adv_loss) diff --git a/instructor/real_data/dpgan_instructor.py b/instructor/real_data/dpgan_instructor.py index 071a382c..d4ad3293 100644 --- a/instructor/real_data/dpgan_instructor.py +++ b/instructor/real_data/dpgan_instructor.py @@ -21,10 +21,22 @@ def __init__(self, opt): super(DPGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = DPGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, 
cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = DPGAN_D(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = DPGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = DPGAN_D( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -36,40 +48,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') - self.train_discriminator(cfg.d_step, cfg.d_epoch, 'MLE') + self.log.info("Starting Discriminator Training...") + self.train_discriminator(cfg.d_step, cfg.d_epoch, "MLE") if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -81,16 +102,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -100,13 +125,15 @@ def adv_train_generator(self, g_step): """ discount_rate = 1 total_g_loss = 0 - dis_count_list = [discount_rate ** i for i in range(cfg.max_seq_len)] - dis_count_matrix = torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) + dis_count_list = [discount_rate**i for i in range(cfg.max_seq_len)] + dis_count_matrix = ( + torch.Tensor(dis_count_list).unsqueeze(0).repeat(cfg.batch_size, 1) + ) if cfg.CUDA: dis_count_matrix = dis_count_matrix.cuda() for step in range(g_step): - inp = self.train_data.random_batch()['input'] + inp = self.train_data.random_batch()["input"] if cfg.CUDA: inp = inp.cuda() @@ -124,9 +151,11 @@ def adv_train_generator(self, g_step): # ===Test=== self.log.info( - '[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True))) + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss / (g_step * cfg.batch_size), self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
@@ -140,11 +169,15 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): pos_reward, neg_reward = 0, 0 for epoch in range(d_epoch): # ===Train=== - pos_reward, neg_reward = self.train_dis_epoch(self.dis, pos_samples, neg_samples, self.dis_opt) + pos_reward, neg_reward = self.train_dis_epoch( + self.dis, pos_samples, neg_samples, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f,' % ( - phase, step, pos_reward, neg_reward)) + self.log.info( + "[%s-DIS] d_step %d: pos_reward = %.4f, neg_reward = %.4f," + % (phase, step, pos_reward, neg_reward) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) @@ -159,8 +192,8 @@ def train_dis_epoch(self, model, pos_samples, neg_samples, optimizer): num_samples = pos_samples.size(0) num_batch = num_samples // cfg.batch_size for i in range(num_batch): - pos_sample = pos_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] - neg_sample = neg_samples[i * cfg.batch_size: (i + 1) * cfg.batch_size] + pos_sample = pos_samples[i * cfg.batch_size : (i + 1) * cfg.batch_size] + neg_sample = neg_samples[i * cfg.batch_size : (i + 1) * cfg.batch_size] _, pos_reward = model.getReward(pos_sample) _, neg_reward = model.getReward(neg_sample) diff --git a/instructor/real_data/evogan_instructor.py b/instructor/real_data/evogan_instructor.py index eef84bef..faf2cd13 100644 --- a/instructor/real_data/evogan_instructor.py +++ b/instructor/real_data/evogan_instructor.py @@ -30,13 +30,39 @@ def __init__(self, opt): super(EvoGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = EvoGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.parents = [EvoGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA).state_dict() - for _ in range(cfg.n_parent)] # list of Generator state_dict - self.dis = EvoGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen = EvoGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.parents = [ + EvoGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ).state_dict() + for _ in range(cfg.n_parent) + ] # list of Generator state_dict + self.dis = EvoGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() @@ -44,30 +70,48 @@ def __init__(self, opt): self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr) self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr) self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) - self.parent_mle_opts = [copy.deepcopy(self.gen_opt.state_dict()) - for _ in range(cfg.n_parent)] - self.parent_adv_opts = [copy.deepcopy(self.gen_adv_opt.state_dict()) - for _ in range(cfg.n_parent)] # list of optimizer state dict + self.parent_mle_opts = [ + copy.deepcopy(self.gen_opt.state_dict()) for _ in range(cfg.n_parent) + ] + self.parent_adv_opts = [ + copy.deepcopy(self.gen_adv_opt.state_dict()) for _ in range(cfg.n_parent) + ] # list of optimizer state dict # Criterion - self.G_criterion = 
[GANLoss(loss_mode, 'G', cfg.d_type, CUDA=cfg.CUDA) for loss_mode in cfg.mu_type.split()] - self.D_criterion = GANLoss(cfg.loss_type, 'D', cfg.d_type, CUDA=cfg.CUDA) + self.G_criterion = [ + GANLoss(loss_mode, "G", cfg.d_type, CUDA=cfg.CUDA) + for loss_mode in cfg.mu_type.split() + ] + self.D_criterion = GANLoss(cfg.loss_type, "D", cfg.d_type, CUDA=cfg.CUDA) def init_model(self): if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: for i in range(cfg.n_parent): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) - self.parents[i] = torch.load(cfg.pretrained_gen_path + '%d' % 0, map_location='cpu') + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) + self.parents[i] = torch.load( + cfg.pretrained_gen_path + "%d" % 0, map_location="cpu" + ) if cfg.CUDA: self.gen = self.gen.cuda() if cfg.multi_gpu: - self.dis = torch.nn.parallel.DataParallel(self.dis, device_ids=cfg.devices) + self.dis = torch.nn.parallel.DataParallel( + self.dis, device_ids=cfg.devices + ) self.dis = self.dis.cuda() def load_gen(self, parent, parent_opt, mle=False): @@ -82,42 +126,70 @@ def load_gen(self, parent, parent_opt, mle=False): def _run(self): # ===PRE-TRAINING (GENERATOR)=== if not cfg.gen_pretrain: - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_mle_opts)): - self.log.info('Starting Generator-{} MLE Training...'.format(i)) + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_mle_opts) + ): + self.log.info("Starting Generator-{} MLE Training...".format(i)) self.load_gen(parent, parent_opt, mle=True) # load state dict self.pretrain_generator(cfg.MLE_train_epoch) - self.parents[i] = copy.deepcopy(self.gen.state_dict()) # save state dict + self.parents[i] = copy.deepcopy( + self.gen.state_dict() + ) # save state dict if cfg.if_save and not cfg.if_test: - torch.save(self.gen.state_dict(), cfg.pretrained_gen_path + '%d' % i) - self.log.info('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen.state_dict(), cfg.pretrained_gen_path + "%d" % i + ) + self.log.info( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = trange(cfg.ADV_train_epoch) for adv_epoch in progress: if cfg.temperature == 1: score, fit_score, select_mu = self.evolve_generator(cfg.ADV_g_step) else: # evolve with temperature - score, fit_score, select_mu = self.evolve_generator_with_temp(adv_epoch, cfg.ADV_g_step) + score, fit_score, select_mu = self.evolve_generator_with_temp( + adv_epoch, cfg.ADV_g_step + ) d_loss = self.evolve_discriminator(cfg.ADV_d_step) best_id = int(np.argmax(score)) - progress.set_description('mu: %s, d_loss = %.4f, temp = %.4f' % ( - ' '.join(select_mu), d_loss, self.parents[best_id]['temperature'].item())) + progress.set_description( + "mu: %s, d_loss = %.4f, temp = %.4f" + % ( + " ".join(select_mu), + d_loss, + self.parents[best_id]["temperature"].item(), + ) + ) # TEST - if 
adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): best_id = int(np.argmax(score)) self.load_gen(self.parents[best_id], self.parent_adv_opts[best_id]) - self.log.info('[ADV] epoch %d: temp = %.4f, d_loss = %.4f, %s' % ( - adv_epoch, self.gen.temperature.item(), d_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV] epoch %d: temp = %.4f, d_loss = %.4f, %s" + % ( + adv_epoch, + self.gen.temperature.item(), + d_loss, + self.cal_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -130,17 +202,21 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def evolve_generator(self, evo_g_step): @@ -157,12 +233,16 @@ def evolve_generator(self, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.train_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.train_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in enumerate(self.G_criterion): # Variation self.load_gen(parent, parent_opt) # load state dict to self.gen @@ -188,7 +268,9 @@ def evolve_generator(self, evo_g_step): best_score[id_replace] = score best_fit[id_replace] = [Fq, Fd, score] best_child[id_replace] = copy.deepcopy(self.gen.state_dict()) - best_child_opt[id_replace] = copy.deepcopy(self.gen_adv_opt.state_dict()) + best_child_opt[id_replace] = copy.deepcopy( + self.gen_adv_opt.state_dict() + ) best_fake_samples[id_replace] = self.eval_fake_samples selected_mutation[id_replace] = criterionG.loss_mode count += 1 @@ -212,16 +294,20 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # all children share the same real data output from Discriminator with torch.no_grad(): - real_samples = F.one_hot(self.train_data.random_batch()['target'], cfg.vocab_size).float() + real_samples = F.one_hot( + self.train_data.random_batch()["target"], cfg.vocab_size + ).float() if cfg.CUDA: real_samples = real_samples.cuda() self.d_out_real = self.dis(real_samples) - for i, (parent, parent_opt) in enumerate(zip(self.parents, self.parent_adv_opts)): + for i, (parent, parent_opt) in enumerate( + zip(self.parents, self.parent_adv_opts) + ): for j, criterionG in 
enumerate(self.G_criterion): all_temp = self.get_evo_temp(cur_adv_step) # get evo temp - temp_score = float('-inf') + temp_score = float("-inf") temp_fit = None temp_child = None temp_child_opt = None @@ -237,8 +323,10 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): # Evaluation self.prepare_eval_fake_data() # evaluation fake data - _, _, t_score = self.evaluation('Ra') # for temp evolutionary - loss_Fq, loss_Fd, loss_score = self.evaluation(cfg.eval_type) # for loss evolutionary + _, _, t_score = self.evaluation("Ra") # for temp evolutionary + loss_Fq, loss_Fd, loss_score = self.evaluation( + cfg.eval_type + ) # for loss evolutionary if t_score > temp_score: temp_score = loss_score @@ -275,8 +363,12 @@ def evolve_generator_with_temp(self, cur_adv_step, evo_g_step): def evolve_discriminator(self, evo_d_step): total_loss = 0 for step in range(evo_d_step): - real_samples = F.one_hot(self.train_data.random_batch()['target'], cfg.vocab_size).float() - gen_samples = self.best_fake_samples[step * cfg.batch_size:(step + 1) * cfg.batch_size] + real_samples = F.one_hot( + self.train_data.random_batch()["target"], cfg.vocab_size + ).float() + gen_samples = self.best_fake_samples[ + step * cfg.batch_size : (step + 1) * cfg.batch_size + ] if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -312,7 +404,9 @@ def variation(self, g_step, criterionG): def evaluation(self, eval_type): """Evaluation all children, update child score. Note that the eval data should be the same""" - eval_samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size) + eval_samples = self.gen.sample( + cfg.eval_b_num * cfg.batch_size, cfg.max_bn * cfg.batch_size + ) gen_data = GenDataIter(eval_samples) # Fd @@ -322,20 +416,26 @@ def evaluation(self, eval_type): Fd = 0 # Fq - if eval_type == 'standard': + if eval_type == "standard": Fq = self.eval_d_out_fake.mean().cpu().item() - elif eval_type == 'rsgan': - g_loss, d_loss = get_losses(self.eval_d_out_real, self.eval_d_out_fake, 'rsgan') + elif eval_type == "rsgan": + g_loss, d_loss = get_losses( + self.eval_d_out_real, self.eval_d_out_fake, "rsgan" + ) Fq = d_loss.item() - elif 'bleu' in eval_type: - self.bleu.reset(test_text=tensor_to_tokens(eval_samples, self.idx2word_dict)) + elif "bleu" in eval_type: + self.bleu.reset( + test_text=tensor_to_tokens(eval_samples, self.idx2word_dict) + ) if cfg.lambda_fq != 0: Fq = self.bleu.get_score(given_gram=int(eval_type[-1])) else: Fq = 0 - elif 'Ra' in eval_type: - g_loss = torch.sigmoid(self.eval_d_out_fake - torch.mean(self.eval_d_out_real)).sum() + elif "Ra" in eval_type: + g_loss = torch.sigmoid( + self.eval_d_out_fake - torch.mean(self.eval_d_out_real) + ).sum() Fq = g_loss.item() else: raise NotImplementedError("Evaluation '%s' is not implemented" % eval_type) @@ -346,22 +446,31 @@ def evaluation(self, eval_type): def prepare_eval_real_data(self): with torch.no_grad(): self.eval_real_samples = torch.cat( - [F.one_hot(self.train_data.random_batch()['target'], cfg.vocab_size).float() - for _ in range(cfg.eval_b_num)], dim=0) + [ + F.one_hot( + self.train_data.random_batch()["target"], cfg.vocab_size + ).float() + for _ in range(cfg.eval_b_num) + ], + dim=0, + ) if cfg.CUDA: self.eval_real_samples = self.eval_real_samples.cuda() - if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra': + if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra": self.eval_d_out_real = self.dis(self.eval_real_samples) def prepare_eval_fake_data(self): with torch.no_grad(): - 
self.eval_fake_samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size,
-                                                     cfg.eval_b_num * cfg.batch_size, one_hot=True)
+            self.eval_fake_samples = self.gen.sample(
+                cfg.eval_b_num * cfg.batch_size,
+                cfg.eval_b_num * cfg.batch_size,
+                one_hot=True,
+            )
             if cfg.CUDA:
                 self.eval_fake_samples = self.eval_fake_samples.cuda()

-            if cfg.eval_type == 'rsgan' or cfg.eval_type == 'Ra':
+            if cfg.eval_type == "rsgan" or cfg.eval_type == "Ra":
                 self.eval_d_out_fake = self.dis(self.eval_fake_samples)

     @staticmethod
@@ -371,14 +480,30 @@ def get_evo_temp(cur_step):
         all_temp = list()

         # all_temp.append(get_fixed_temperature(1.0, 0, 0, 'no'))  # temp=1.0
-        all_temp.append(get_fixed_temperature(cfg.temperature, cur_step, cfg.ADV_train_epoch,
-                                              random.choice(mu_temp_type)))  # current step
         all_temp.append(
-            get_fixed_temperature(cfg.temperature, cur_step + cfg.evo_temp_step, cfg.ADV_train_epoch,
-                                  random.choice(mu_temp_type)))
+            get_fixed_temperature(
+                cfg.temperature,
+                cur_step,
+                cfg.ADV_train_epoch,
+                random.choice(mu_temp_type),
+            )
+        )  # current step
+        all_temp.append(
+            get_fixed_temperature(
+                cfg.temperature,
+                cur_step + cfg.evo_temp_step,
+                cfg.ADV_train_epoch,
+                random.choice(mu_temp_type),
+            )
+        )
         if cur_step > cfg.evo_temp_step:
             all_temp.append(
-                get_fixed_temperature(cfg.temperature, cur_step - cfg.evo_temp_step, cfg.ADV_train_epoch,
-                                      random.choice(mu_temp_type)))
+                get_fixed_temperature(
+                    cfg.temperature,
+                    cur_step - cfg.evo_temp_step,
+                    cfg.ADV_train_epoch,
+                    random.choice(mu_temp_type),
+                )
+            )

         return torch.Tensor(all_temp)  # three temp

diff --git a/instructor/real_data/fixem_instructor.py b/instructor/real_data/fixem_instructor.py
index d37253c9..60a0b681 100644
--- a/instructor/real_data/fixem_instructor.py
+++ b/instructor/real_data/fixem_instructor.py
@@ -31,35 +31,59 @@ def __init__(self, opt):

         w2v = load_embedding(cfg.pretrain_embedding_path)

-        if cfg.run_model == 'fixemgan':
-            labels, train_data = zip(*[(0, line) for line in text_file_iterator(cfg.train_data)])
+        if cfg.run_model == "fixemgan":
+            labels, train_data = zip(
+                *[(0, line) for line in text_file_iterator(cfg.train_data)]
+            )

-        if cfg.run_model == 'cat_fixemgan':
+        if cfg.run_model == "cat_fixemgan":
             labels, train_data = zip(
                 *chain(
-                    *[[(i, line) for line in text_file_iterator(cfg.cat_train_data.format(i))]
-                    for i in range(cfg.k_label)]
+                    *[
+                        [
+                            (i, line)
+                            for line in text_file_iterator(cfg.cat_train_data.format(i))
+                        ]
+                        for i in range(cfg.k_label)
+                    ]
                 )
             )

-        self.train_data_supplier = DataSupplier(train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch)
+        self.train_data_supplier = DataSupplier(
+            train_data, labels, w2v, cfg.batch_size, cfg.batches_per_epoch
+        )

         self.dis = Discriminator(cfg.discriminator_complexity)
-        self.log.info(f"discriminator total tranable parameters: {number_of_parameters(self.dis.parameters())}")
-        self.gen = Generator(cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size)
-        self.log.info(f"generator total tranable parameters: {number_of_parameters(self.gen.parameters())}")
+        self.log.info(
+            f"discriminator total trainable parameters: {number_of_parameters(self.dis.parameters())}"
+        )
+        self.gen = Generator(
+            cfg.generator_complexity, cfg.noise_size, w2v, cfg.w2v_embedding_size
+        )
+        self.log.info(
+            f"generator total trainable parameters: {number_of_parameters(self.gen.parameters())}"
+        )

         if cfg.CUDA:
             self.dis = self.dis.cuda()
             self.gen = self.gen.cuda()

-        self.G_criterion = GANLoss(cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA)
-        self.D_criterion = 
GANLoss(cfg.loss_type, which_net=None, which_D=None, target_real_label=0.8, target_fake_label=0.2, CUDA=cfg.CUDA)
+        self.G_criterion = GANLoss(
+            cfg.loss_type, which_net=None, which_D=None, CUDA=cfg.CUDA
+        )
+        self.D_criterion = GANLoss(
+            cfg.loss_type,
+            which_net=None,
+            which_D=None,
+            target_real_label=0.8,
+            target_fake_label=0.2,
+            CUDA=cfg.CUDA,
+        )

     def build_embedding(self):
         self.log.info(f"Didn't find embeddings in {cfg.pretrain_embedding_path}")
         self.log.info("Will train new one, it may take a while...")
-        sources = list(Path(cfg.texts_pile).glob('*.txt'))
+        sources = list(Path(cfg.texts_pile).glob("*.txt"))
         EmbeddingsTrainer(sources, cfg.pretrain_embedding_path).make_embeddings()

     def generator_train_one_batch(self):
@@ -70,7 +94,9 @@ def generator_train_one_batch(self):
         fakes = self.gen(*noise)
         real_fake_predicts, label_predicts = self.dis(fakes)

-        loss = self.G_criterion.G_loss_fixem(real_fake_predicts, label_predicts, noise[1], fakes)
+        loss = self.G_criterion.G_loss_fixem(
+            real_fake_predicts, label_predicts, noise[1], fakes
+        )
         loss.backward()
         self.gen.optimizer.step()

@@ -94,18 +120,21 @@ def discriminator_train_one_batch(self, real_vector, labels):
         # optimizer step
         self.dis.optimizer.zero_grad()
         real_fake_predicts, label_predicts = self.dis(text_input_vectors)
-        loss = self.D_criterion.D_loss_fixem(real_fake_predicts, label_predicts[:this_batch_size], labels)
+        loss = self.D_criterion.D_loss_fixem(
+            real_fake_predicts, label_predicts[:this_batch_size], labels
+        )
         loss.backward()
         self.dis.optimizer.step()

         real_fake_predicts = real_fake_predicts.clone().detach()
-        real_fake_predicts = real_fake_predicts.chunk(2) #splitting to realand fake parks
+        real_fake_predicts = real_fake_predicts.chunk(
+            2
+        )  # splitting into real and fake parts
         discriminator_acc = float(
-            torch.cat((
-                real_fake_predicts[0] > 0.5,
-                real_fake_predicts[1] < 0.5
-            )).mean(dtype=float)
+            torch.cat((real_fake_predicts[0] > 0.5, real_fake_predicts[1] < 0.5)).mean(
+                dtype=float
+            )
         )
         return discriminator_acc

@@ -114,20 +143,27 @@ def _run(self):
         for labels, text_vector in tqdm(self.train_data_supplier, leave=False):
             if cfg.CUDA:
                 labels, text_vector = labels.cuda(), text_vector.cuda()
-            discriminator_acc = self.discriminator_train_one_batch(text_vector, labels)
+            discriminator_acc = self.discriminator_train_one_batch(
+                text_vector, labels
+            )
             generator_acc = 1 - 2 * (discriminator_acc - 0.5)

             # run the generator until its accuracy gets high enough
             while self.one_more_batch_for_generator(generator_acc):
                 generator_acc = self.generator_train_one_batch()

-        if cfg.run_model == 'fixemgan':
-            print('calculating_metrics')
+        if cfg.run_model == "fixemgan":
+            print("calculating_metrics")
             scores = self.cal_metrics(fmt_str=True)
-        if cfg.run_model == 'cat_fixemgan':
-            scores = '\n\n'.join([self.cal_metrics_with_label(label_i=label_i, fmt_str=True) for label_i in range(cfg.k_label)])
-        self.log.info(f'epoch: {i}')
-        self.log.info(f'{scores}')
+        if cfg.run_model == "cat_fixemgan":
+            scores = "\n\n".join(
+                [
+                    self.cal_metrics_with_label(label_i=label_i, fmt_str=True)
+                    for label_i in range(cfg.k_label)
+                ]
+            )
+        self.log.info(f"epoch: {i}")
+        self.log.info(f"{scores}")

     def one_more_batch_for_generator(
         self, generator_acc, leave_in_generator_min=0.1, leave_in_generator_max=0.9
@@ -146,8 +182,17 @@ def sample_for_metrics(self):
         return GenDataIter(gen_tokens), gen_tokens, gen_tokens_s

     def sample_for_metrics_with_label(self, label_i):
-        gen_tokens = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, 
label_i=label_i)
+        gen_tokens = self.gen.sample(
+            cfg.samples_num, 8 * cfg.batch_size, label_i=label_i
+        )
         gen_tokens = [sample.split() for sample in gen_tokens]
-        gen_tokens_s = self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i)
+        gen_tokens_s = self.gen.sample(
+            cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i
+        )
         gen_tokens_s = [sample.split() for sample in gen_tokens_s]
-        return GenDataIter(gen_tokens), gen_tokens, gen_tokens_s, CatClasDataIter([gen_tokens], label_i)
+        return (
+            GenDataIter(gen_tokens),
+            gen_tokens,
+            gen_tokens_s,
+            CatClasDataIter([gen_tokens], label_i),
+        )
diff --git a/instructor/real_data/instructor.py b/instructor/real_data/instructor.py
index 6a6d9c75..25d641dd 100644
--- a/instructor/real_data/instructor.py
+++ b/instructor/real_data/instructor.py
@@ -28,9 +28,14 @@ class BasicInstructor:

     def __init__(self, opt):
-        self.log = create_logger(__name__, silent=False, to_disk=True,
-                                 log_file=cfg.log_filename if cfg.if_test
-                                 else [cfg.log_filename, cfg.save_root + 'log.txt'])
+        self.log = create_logger(
+            __name__,
+            silent=False,
+            to_disk=True,
+            log_file=cfg.log_filename
+            if cfg.if_test
+            else [cfg.log_filename, cfg.save_root + "log.txt"],
+        )
         self.sig = Signal(cfg.signal_file)
         self.opt = opt
         self.show_config()
@@ -53,14 +58,24 @@ def __init__(self, opt):
             pass

         try:
-            self.train_data_list = [GenDataIter(cfg.cat_train_data.format(i)) for i in range(cfg.k_label)]
-            self.test_data_list = [GenDataIter(cfg.cat_test_data.format(i), if_test_data=True) for i in
-                                   range(cfg.k_label)]
-            self.clas_data_list = [GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True) for i in
-                                   range(cfg.k_label)]
-
-            self.train_samples_list = [self.train_data_list[i].target for i in range(cfg.k_label)]
-            self.clas_samples_list = [self.clas_data_list[i].target for i in range(cfg.k_label)]
+            self.train_data_list = [
+                GenDataIter(cfg.cat_train_data.format(i)) for i in range(cfg.k_label)
+            ]
+            self.test_data_list = [
+                GenDataIter(cfg.cat_test_data.format(i), if_test_data=True)
+                for i in range(cfg.k_label)
+            ]
+            self.clas_data_list = [
+                GenDataIter(cfg.cat_test_data.format(str(i)), if_test_data=True)
+                for i in range(cfg.k_label)
+            ]
+
+            self.train_samples_list = [
+                self.train_data_list[i].target for i in range(cfg.k_label)
+            ]
+            self.clas_samples_list = [
+                self.clas_data_list[i].target for i in range(cfg.k_label)
+            ]
         except:
             pass

@@ -74,27 +89,40 @@ def __init__(self, opt):

         # Metrics
         # bleu, more-better, changes in range 0.4 - 0.6, will have relatively high weight
-        self.bleu = BLEU('BLEU', weight=3, gram=3, if_use=cfg.use_bleu)
+        self.bleu = BLEU("BLEU", weight=3, gram=3, if_use=cfg.use_bleu)
         # nll-gen, less-better, changes in range 1.5 - 3, will have smaller weight (not in use)
-        self.nll_gen = NLL('NLL_gen', weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA)
+        self.nll_gen = NLL("NLL_gen", weight=0, if_use=cfg.use_nll_gen, gpu=cfg.CUDA)
         # nll-div, more-better, changes in range 0.5 - 1.5, will have smaller weight (not in use)
-        self.nll_div = NLL('NLL_div', weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA)
+        self.nll_div = NLL("NLL_div", weight=0, if_use=cfg.use_nll_div, gpu=cfg.CUDA)
         # self-bleu, less-better, changes in range 0.7 - 0.9, will have relatively high weight
-        self.self_bleu = BLEU('Self-BLEU', weight=-3, gram=3, if_use=cfg.use_self_bleu)
+        self.self_bleu = BLEU("Self-BLEU", weight=-3, gram=3, if_use=cfg.use_self_bleu)
         # class-acc, more-better, changes in range 0.7 - 1.0, moderate weight
         self.clas_acc = ACC(weight=1, 
         # IOC, less-better, changes in range 0.8 - 2.0, smaller weight
         self.ioc = IOC(weight=-0.3, if_use=cfg.use_ioc, real_text=self.test_data)
         # nll_oracle, less-better, changes in range -0.1 - 0.6, moderate weight
-        self.nll_oracle = GPTNLL(weight=-3, if_use=cfg.use_nll_oracle, real_text=self.test_data)
+        self.nll_oracle = GPTNLL(
+            weight=-3, if_use=cfg.use_nll_oracle, real_text=self.test_data
+        )
         # perplexity, less-better, changes in range 3 - 4, moderate weight (not in use)
-        self.ppl = PPL(self.train_data, self.test_data, weight=0, n_gram=5, if_use=cfg.use_ppl)
+        self.ppl = PPL(
+            self.train_data, self.test_data, weight=0, n_gram=5, if_use=cfg.use_ppl
+        )
         # dummy, add constant value to overall score
         self.dummy = Dummy(weight=1, value=5, if_use=True)
-        self.all_metrics = [self.bleu, self.nll_gen, self.nll_div, self.self_bleu, self.ioc, self.nll_oracle, self.ppl, self.dummy]
+        self.all_metrics = [
+            self.bleu,
+            self.nll_gen,
+            self.nll_div,
+            self.self_bleu,
+            self.ioc,
+            self.nll_oracle,
+            self.ppl,
+            self.dummy,
+        ]

     def _run(self):
-        print('Nothing to run in Basic Instructor!')
+        print("Nothing to run in Basic Instructor!")
         pass

     def _test(self):
@@ -103,11 +131,22 @@ def init_model(self):
         if cfg.dis_pretrain:
             self.log.info(
-                'Load pre-trained discriminator: {}'.format(cfg.pretrained_dis_path))
-            self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device)))
+                "Load pre-trained discriminator: {}".format(cfg.pretrained_dis_path)
+            )
+            self.dis.load_state_dict(
+                torch.load(
+                    cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device)
+                )
+            )
         if cfg.gen_pretrain:
-            self.log.info('Load MLE pre-trained generator: {}'.format(cfg.pretrained_gen_path))
-            self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device)))
+            self.log.info(
+                "Load MLE pre-trained generator: {}".format(cfg.pretrained_gen_path)
+            )
+            self.gen.load_state_dict(
+                torch.load(
+                    cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device)
+                )
+            )

         if cfg.CUDA:
             self.gen = self.gen.cuda()
@@ -116,7 +155,7 @@ def train_gen_epoch(self, model, data_loader, criterion, optimizer):
         total_loss = 0
         for i, data in enumerate(data_loader):
-            inp, target = data['input'], data['target']
+            inp, target = data["input"], data["target"]
             if cfg.CUDA:
                 inp, target = inp.cuda(), target.cuda()
@@ -132,7 +171,7 @@ def train_dis_epoch(self, model, data_loader, criterion, optimizer):
         total_acc = 0
         total_num = 0
         for i, data in enumerate(data_loader):
-            inp, target = data['input'], data['target']
+            inp, target = data["input"], data["target"]
             if cfg.CUDA:
                 inp, target = inp.cuda(), target.cuda()
@@ -168,15 +207,28 @@ def train_classifier(self, epochs):
         max_acc = 0
         best_clas = None
         for epoch in range(epochs):
-            c_loss, c_acc = self.train_dis_epoch(self.clas, clas_data.loader, self.clas_criterion,
-                                                 self.clas_opt)
-            _, eval_acc = self.eval_dis(self.clas, eval_clas_data.loader, self.clas_criterion)
+            c_loss, c_acc = self.train_dis_epoch(
+                self.clas, clas_data.loader, self.clas_criterion, self.clas_opt
+            )
+            _, eval_acc = self.eval_dis(
+                self.clas, eval_clas_data.loader, self.clas_criterion
+            )
             if eval_acc > max_acc:
-                best_clas = copy.deepcopy(self.clas.state_dict())  # save the best classifier
+                best_clas = copy.deepcopy(
+                    self.clas.state_dict()
+                )  # save the best classifier
                 max_acc = eval_acc
-            self.log.info('[PRE-CLAS] epoch %d: c_loss = %.4f, c_acc = %.4f, eval_acc = %.4f, max_eval_acc = %.4f',
- epoch, c_loss, c_acc, eval_acc, max_acc) - self.clas.load_state_dict(copy.deepcopy(best_clas)) # Reload the best classifier + self.log.info( + "[PRE-CLAS] epoch %d: c_loss = %.4f, c_acc = %.4f, eval_acc = %.4f, max_eval_acc = %.4f", + epoch, + c_loss, + c_acc, + eval_acc, + max_acc, + ) + self.clas.load_state_dict( + copy.deepcopy(best_clas) + ) # Reload the best classifier @staticmethod def eval_dis(model, data_loader, criterion): @@ -185,7 +237,7 @@ def eval_dis(model, data_loader, criterion): total_num = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -214,24 +266,32 @@ def optimize(opt, loss, model=None, retain_graph=False): opt.step() def show_config(self): - self.log.info(100 * '=') - self.log.info('> training arguments:') + self.log.info(100 * "=") + self.log.info("> training arguments:") for arg in vars(self.opt): - self.log.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg))) - self.log.info(100 * '=') + self.log.info(">>> {0}: {1}".format(arg, getattr(self.opt, arg))) + self.log.info(100 * "=") def sample_for_metrics(self): eval_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size), self.idx2word_dict) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size), + self.idx2word_dict, + ) return gen_data, gen_tokens, gen_tokens_s def sample_for_metrics_with_label(self, label_i): - eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i) + eval_samples = self.gen.sample( + cfg.samples_num, 8 * cfg.batch_size, label_i=label_i + ) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i), self.idx2word_dict) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(cfg.small_sample_num, 8 * cfg.batch_size, label_i=label_i), + self.idx2word_dict, + ) clas_data = CatClasDataIter([eval_samples], label_i) return gen_data, gen_tokens, gen_tokens_s, clas_data @@ -243,7 +303,7 @@ def cal_metrics(self, fmt_str=False): with torch.no_grad(): # Prepare data for evaluation gen_data, gen_tokens, gen_tokens_s = self.sample_for_metrics() - print('sampled') + print("sampled") # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) self.nll_gen.reset(self.gen, self.train_data.loader) @@ -253,24 +313,39 @@ def cal_metrics(self, fmt_str=False): self.ioc.reset(test_text=gen_tokens) self.nll_oracle.reset(test_text=gen_tokens) - print('all reset') + print("all reset") metrics = {metric.name: metric.get_score() for metric in self.all_metrics} - print('get_score called') - metrics.update({"Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + print("get_score called") + metrics.update( + { + "Overal_score": sum( + metric.weight * metric.get_score() for metric in self.all_metrics + ) + } + ) wandb.log(metrics) if fmt_str: - return "\n" + "\n".join([f"{name} = {score}" for name, score in metrics.items()]) + return "\n" + "\n".join( + [f"{name} = {score}" for name, score in metrics.items()] + ) return [metric.get_score() for metric in self.all_metrics] def 
cal_metrics_with_label(self, label_i, fmt_str=False): - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" with torch.no_grad(): # Prepare data for evaluation - gen_data, gen_tokens, gen_tokens_s, clas_data = self.sample_for_metrics_with_label(label_i) + ( + gen_data, + gen_tokens, + gen_tokens_s, + clas_data, + ) = self.sample_for_metrics_with_label(label_i) # Reset metrics - self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) + self.bleu.reset( + test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens + ) self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i) self.nll_div.reset(self.gen, gen_data.loader, label_i) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) @@ -279,33 +354,55 @@ def cal_metrics_with_label(self, label_i, fmt_str=False): self.ioc.reset(test_text=gen_tokens) self.nll_oracle.reset(test_text=gen_tokens) - metrics = {f"label {label_i}_{metric.name}": metric.get_score() for metric in self.all_metrics} - metrics.update({f"label {label_i} Overal_score": sum(metric.weight * metric.get_score() for metric in self.all_metrics)}) + metrics = { + f"label {label_i}_{metric.name}": metric.get_score() + for metric in self.all_metrics + } + metrics.update( + { + f"label {label_i} Overal_score": sum( + metric.weight * metric.get_score() for metric in self.all_metrics + ) + } + ) wandb.log(metrics) if fmt_str: - return "\n" + "\n".join([f"{name} = {score}" for name, score in metrics.items()]) + return "\n" + "\n".join( + [f"{name} = {score}" for name, score in metrics.items()] + ) return [metric.get_score() for metric in self.all_metrics] def comb_metrics(self, fmt_str=False): - all_scores = [self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label)] + all_scores = [ + self.cal_metrics_with_label(label_i) for label_i in range(cfg.k_label) + ] if fmt_str: - return ', '.join([ - f'{name} = {[scores[name] for scores in all_scores]}' - for name in all_scores[0] - ]) + return ", ".join( + [ + f"{name} = {[scores[name] for scores in all_scores]}" + for name in all_scores[0] + ] + ) return [scores.values() for scores in all_scores] def _save(self, phase, epoch): """Save model state dict and generator's samples""" - if phase != 'ADV': - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(phase, epoch) + if phase != "ADV": + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_{}_{:05d}.txt".format( + phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size) write_tokens(save_sample_path, tensor_to_tokens(samples, self.idx2word_dict)) def update_temperature(self, i, N): - self.gen.temperature.data = torch.Tensor([get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)]) + self.gen.temperature.data = torch.Tensor( + [get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)] + ) if cfg.CUDA: self.gen.temperature.data = self.gen.temperature.data.cuda() diff --git a/instructor/real_data/jsdgan_instructor.py b/instructor/real_data/jsdgan_instructor.py index be9f782f..8144d6a3 100644 --- a/instructor/real_data/jsdgan_instructor.py +++ b/instructor/real_data/jsdgan_instructor.py @@ -20,8 +20,17 @@ def __init__(self, opt): super(JSDGANInstructor, self).__init__(opt) # generator - self.gen = JSDGAN_G(cfg.mem_slots, 
cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = JSDGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -29,8 +38,14 @@ def __init__(self, opt): def init_model(self): if cfg.gen_pretrain: - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path)) - self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device))) + self.log.info( + "Load MLE pretrained generator gen: {}".format(cfg.pretrained_gen_path) + ) + self.gen.load_state_dict( + torch.load( + cfg.pretrained_gen_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.CUDA: self.gen = self.gen.cuda() @@ -38,23 +53,29 @@ def init_model(self): def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") for adv_epoch in range(cfg.ADV_train_epoch): g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV] epoch %d: g_loss = %.4f, %s' % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True))) + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): + self.log.info( + "[ADV] epoch %d: g_loss = %.4f, %s" + % (adv_epoch, g_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -66,16 +87,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -87,7 +112,7 @@ def adv_train_generator(self, g_step): total_loss = 0 for step in range(g_step): for i, data in enumerate(self.train_data.loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() diff --git a/instructor/real_data/leakgan_instructor.py b/instructor/real_data/leakgan_instructor.py index fc02e41a..cecf59eb 100644 --- a/instructor/real_data/leakgan_instructor.py +++ b/instructor/real_data/leakgan_instructor.py @@ -24,9 +24,19 @@ def __init__(self, opt): super(LeakGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = 
LeakGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, cfg.goal_size, cfg.step_size, cfg.CUDA) - self.dis = LeakGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = LeakGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + cfg.goal_size, + cfg.step_size, + cfg.CUDA, + ) + self.dis = LeakGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # optimizer @@ -39,48 +49,63 @@ def __init__(self, opt): def _run(self): for inter_num in range(cfg.inter_epoch): - self.log.info('>>> Interleaved Round %d...' % inter_num) + self.log.info(">>> Interleaved Round %d..." % inter_num) self.sig.update() # update signal if self.sig.pre_sig: # ===DISCRIMINATOR PRE-TRAINING=== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format( + cfg.pretrained_dis_path + ) + ) # ===GENERATOR MLE TRAINING=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + ) + ) else: - self.log.info('>>> Stop by pre_signal! Skip to adversarial training...') + self.log.info(">>> Stop by pre_signal! Skip to adversarial training...") break # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (str(self.cal_metrics(fmt_str=True)))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (str(self.cal_metrics(fmt_str=True)))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -98,7 +123,7 @@ def pretrain_generator(self, epochs): # ===Train=== for i, data in enumerate(self.train_data.loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if cfg.CUDA: inp, target = inp.cuda(), target.cuda() @@ -111,13 +136,20 @@ def pretrain_generator(self, epochs): # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: - self.log.info('[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s' % ( - epoch, pre_mana_loss, pre_work_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s" + % ( + epoch, + pre_mana_loss, + pre_work_loss, + self.cal_metrics(fmt_str=True), + ) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step, current_k=0): @@ -131,13 +163,15 @@ def adv_train_generator(self, g_step, current_k=0): adv_work_loss = 0 for step in range(g_step): with torch.no_grad(): - gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis, - train=True) # !!! train=True, the only place + gen_samples = self.gen.sample( + cfg.batch_size, cfg.batch_size, self.dis, train=True + ) # !!! train=True, the only place inp, target = GenDataIter.prepare(gen_samples, gpu=cfg.CUDA) # ===Train=== - rewards = rollout_func.get_reward_leakgan(target, cfg.rollout_num, self.dis, - current_k).cpu() # reward with MC search + rewards = rollout_func.get_reward_leakgan( + target, cfg.rollout_num, self.dis, current_k + ).cpu() # reward with MC search mana_loss, work_loss = self.gen.adversarial_loss(target, rewards, self.dis) # update parameters @@ -145,10 +179,16 @@ def adv_train_generator(self, g_step, current_k=0): adv_mana_loss += mana_loss.data.item() adv_work_loss += work_loss.data.item() # ===Test=== - self.log.info('[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' % ( - adv_mana_loss / g_step, adv_work_loss / g_step, self.cal_metrics(fmt_str=True))) - - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + self.log.info( + "[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s" + % ( + adv_mana_loss / g_step, + adv_work_loss / g_step, + self.cal_metrics(fmt_str=True), + ) + ) + + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
@@ -162,12 +202,15 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f," + % (phase, step, d_loss, train_acc) + ) def cal_metrics(self, fmt_str=False): with torch.no_grad(): @@ -175,7 +218,9 @@ def cal_metrics(self, fmt_str=False): eval_samples = self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen.sample(200, cfg.batch_size, self.dis), self.idx2word_dict) + gen_tokens_s = tensor_to_tokens( + self.gen.sample(200, cfg.batch_size, self.dis), self.idx2word_dict + ) # Reset metrics self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens) @@ -185,12 +230,22 @@ def cal_metrics(self, fmt_str=False): self.ppl.reset(gen_tokens) if fmt_str: - return ', '.join(['%s = %s' % (metric.name, metric.get_score()) for metric in self.all_metrics]) + return ", ".join( + [ + "%s = %s" % (metric.name, metric.get_score()) + for metric in self.all_metrics + ] + ) else: return [metric.get_score() for metric in self.all_metrics] def _save(self, phase, epoch): - torch.save(self.gen.state_dict(), cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(phase, epoch) + torch.save( + self.gen.state_dict(), + cfg.save_model_root + "gen_{}_{:05d}.pt".format(phase, epoch), + ) + save_sample_path = cfg.save_samples_root + "samples_{}_{:05d}.txt".format( + phase, epoch + ) samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis) write_tokens(save_sample_path, tensor_to_tokens(samples, self.idx2word_dict)) diff --git a/instructor/real_data/maligan_instructor.py b/instructor/real_data/maligan_instructor.py index a27356f8..a0b0eeb2 100644 --- a/instructor/real_data/maligan_instructor.py +++ b/instructor/real_data/maligan_instructor.py @@ -25,9 +25,17 @@ def __init__(self, opt): super(MaliGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = MaliGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = MaliGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = MaliGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = MaliGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # Optimizer @@ -39,40 +47,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + 
self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -84,16 +101,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -102,7 +123,9 @@ def adv_train_generator(self, g_step): """ total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = self.get_mali_reward(target) @@ -111,9 +134,12 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss, self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
@@ -129,12 +155,15 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f," + % (phase, step, d_loss, train_acc) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) diff --git a/instructor/real_data/relgan_instructor.py b/instructor/real_data/relgan_instructor.py index 391b961b..22f338d7 100644 --- a/instructor/real_data/relgan_instructor.py +++ b/instructor/real_data/relgan_instructor.py @@ -24,10 +24,25 @@ def __init__(self, opt): super(RelGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim, - cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) - self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx, - gpu=cfg.CUDA) + self.gen = RelGAN_G( + cfg.mem_slots, + cfg.num_heads, + cfg.head_size, + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = RelGAN_D( + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer @@ -38,39 +53,50 @@ def __init__(self, opt): def _run(self): # ===PRE-TRAINING (GENERATOR)=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pretrain_generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pretrain_generator: {}".format(cfg.pretrained_gen_path)) # # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') + self.log.info("Starting Adversarial Training...") progress = tqdm(range(cfg.ADV_train_epoch)) for adv_epoch in progress: self.sig.update() if self.sig.adv_sig: g_loss = self.adv_train_generator(cfg.ADV_g_step) # Generator d_loss = self.adv_train_discriminator(cfg.ADV_d_step) # Discriminator - self.update_temperature(adv_epoch, cfg.ADV_train_epoch) # update temperature + self.update_temperature( + adv_epoch, cfg.ADV_train_epoch + ) # update temperature progress.set_description( - 'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature)) + "g_loss: %.4f, d_loss: %.4f, temperature: %.4f" + % (g_loss, d_loss, self.gen.temperature) + ) # TEST - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: - self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s' % ( - adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True))) + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): + self.log.info( + "[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s" + % (adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! 
Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) progress.close() break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -83,23 +109,27 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: # ===Train=== - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: - self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): total_loss = 0 for step in range(g_step): - real_samples = self.train_data.random_batch()['target'] + real_samples = self.train_data.random_batch()["target"] gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True) if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -118,7 +148,7 @@ def adv_train_generator(self, g_step): def adv_train_discriminator(self, d_step): total_loss = 0 for step in range(d_step): - real_samples = self.train_data.random_batch()['target'] + real_samples = self.train_data.random_batch()["target"] gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True) if cfg.CUDA: real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda() @@ -135,7 +165,9 @@ def adv_train_discriminator(self, d_step): return total_loss / d_step if d_step != 0 else 0 def update_temperature(self, i, N): - self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt) + self.gen.temperature = get_fixed_temperature( + cfg.temperature, i, N, cfg.temp_adpt + ) @staticmethod def optimize(opt, loss, model=None, retain_graph=False): diff --git a/instructor/real_data/sentigan_instructor.py b/instructor/real_data/sentigan_instructor.py index 952df9da..052b8bc0 100644 --- a/instructor/real_data/sentigan_instructor.py +++ b/instructor/real_data/sentigan_instructor.py @@ -25,15 +25,39 @@ def __init__(self, opt): super(SentiGANInstructor, self).__init__(opt) # generator, discriminator - self.gen_list = [SentiGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) for _ in range(cfg.k_label)] - self.dis = SentiGAN_D(cfg.k_label, cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) - self.clas = SentiGAN_C(cfg.k_label, cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.extend_vocab_size, - cfg.padding_idx, gpu=cfg.CUDA) + self.gen_list = [ + SentiGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + for _ in range(cfg.k_label) + ] + self.dis = SentiGAN_D( + cfg.k_label, + cfg.dis_embed_dim, + cfg.vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.clas = SentiGAN_C( + cfg.k_label, + cfg.dis_embed_dim, + cfg.max_seq_len, + cfg.num_rep, + cfg.extend_vocab_size, + cfg.padding_idx, + gpu=cfg.CUDA, + ) self.init_model() # Optimizer - 
self.gen_opt_list = [optim.Adam(gen.parameters(), lr=cfg.gen_lr) for gen in self.gen_list] + self.gen_opt_list = [ + optim.Adam(gen.parameters(), lr=cfg.gen_lr) for gen in self.gen_list + ] self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr) self.clas_opt = optim.Adam(self.clas.parameters(), lr=cfg.clas_lr) @@ -43,16 +67,35 @@ def __init__(self, opt): def init_model(self): if cfg.dis_pretrain: self.log.info( - 'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path)) - self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path, map_location='cuda:{}'.format(cfg.device))) + "Load pretrained discriminator: {}".format(cfg.pretrained_dis_path) + ) + self.dis.load_state_dict( + torch.load( + cfg.pretrained_dis_path, map_location="cuda:{}".format(cfg.device) + ) + ) if cfg.gen_pretrain: for i in range(cfg.k_label): - self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + self.log.info( + "Load MLE pretrained generator gen: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) self.gen_list[i].load_state_dict( - torch.load(cfg.pretrained_gen_path + '%d' % i, map_location='cuda:{}'.format(cfg.device))) + torch.load( + cfg.pretrained_gen_path + "%d" % i, + map_location="cuda:{}".format(cfg.device), + ) + ) if cfg.clas_pretrain: - self.log.info('Load pretrained classifier: {}'.format(cfg.pretrained_clas_path)) - self.clas.load_state_dict(torch.load(cfg.pretrained_clas_path, map_location='cuda:%d' % cfg.device)) + self.log.info( + "Load pretrained classifier: {}".format(cfg.pretrained_clas_path) + ) + self.clas.load_state_dict( + torch.load( + cfg.pretrained_clas_path, map_location="cuda:%d" % cfg.device + ) + ) if cfg.CUDA: for i in range(cfg.k_label): @@ -63,46 +106,62 @@ def init_model(self): def _run(self): # ===Pre-train Classifier with real data=== if cfg.use_clas_acc: - self.log.info('Start training Classifier...') + self.log.info("Start training Classifier...") self.train_classifier(cfg.PRE_clas_epoch) # ===PRE-TRAIN GENERATOR=== if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: for i in range(cfg.k_label): - torch.save(self.gen_list[i].state_dict(), cfg.pretrained_gen_path + '%d' % i) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path + '%d' % i)) + torch.save( + self.gen_list[i].state_dict(), + cfg.pretrained_gen_path + "%d" % i, + ) + print( + "Save pre-trained generator: {}".format( + cfg.pretrained_gen_path + "%d" % i + ) + ) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s', self.comb_metrics(fmt_str=True)) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s", self.comb_metrics(fmt_str=True)) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) 
self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." + ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -115,18 +174,24 @@ def pretrain_generator(self, epochs): self.sig.update() if self.sig.pre_sig: for i in range(cfg.k_label): - pre_loss = self.train_gen_epoch(self.gen_list[i], self.train_data_list[i].loader, - self.mle_criterion, self.gen_opt_list[i]) + pre_loss = self.train_gen_epoch( + self.gen_list[i], + self.train_data_list[i].loader, + self.mle_criterion, + self.gen_opt_list[i], + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: if i == cfg.k_label - 1: - self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % ( - epoch, pre_loss, self.comb_metrics(fmt_str=True))) + self.log.info( + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.comb_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -138,18 +203,23 @@ def adv_train_generator(self, g_step): rollout_func = rollout.ROLLOUT(self.gen_list[i], cfg.CUDA) total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen_list[i].sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen_list[i].sample(cfg.batch_size, cfg.batch_size), + gpu=cfg.CUDA, + ) # ===Train=== - rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis, current_k=i) + rewards = rollout_func.get_reward( + target, cfg.rollout_num, self.dis, current_k=i + ) adv_loss = self.gen_list[i].batchPGLoss(inp, target, rewards) self.optimize(self.gen_opt_list[i], adv_loss) total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: %s', self.comb_metrics(fmt_str=True)) + self.log.info("[ADV-GEN]: %s", self.comb_metrics(fmt_str=True)) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
@@ -163,37 +233,52 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): fake_samples = [] for i in range(cfg.k_label): real_samples.append(self.train_samples_list[i]) - fake_samples.append(self.gen_list[i].sample(cfg.samples_num // cfg.k_label, 8 * cfg.batch_size)) + fake_samples.append( + self.gen_list[i].sample( + cfg.samples_num // cfg.k_label, 8 * cfg.batch_size + ) + ) dis_samples_list = [torch.cat(fake_samples, dim=0)] + real_samples dis_data = CatClasDataIter(dis_samples_list) for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f" + % (phase, step, d_loss, train_acc) + ) - if cfg.if_save and not cfg.if_test and phase == 'MLE': + if cfg.if_save and not cfg.if_test and phase == "MLE": torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) def cal_metrics_with_label(self, label_i): - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" with torch.no_grad(): # Prepare data for evaluation - eval_samples = self.gen_list[label_i].sample(cfg.samples_num, 8 * cfg.batch_size) + eval_samples = self.gen_list[label_i].sample( + cfg.samples_num, 8 * cfg.batch_size + ) gen_data = GenDataIter(eval_samples) gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict) - gen_tokens_s = tensor_to_tokens(self.gen_list[label_i].sample(200, 200), self.idx2word_dict) + gen_tokens_s = tensor_to_tokens( + self.gen_list[label_i].sample(200, 200), self.idx2word_dict + ) clas_data = CatClasDataIter([eval_samples], label_i) # Reset metrics - self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens) - self.nll_gen.reset(self.gen_list[label_i], self.train_data_list[label_i].loader) + self.bleu.reset( + test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens + ) + self.nll_gen.reset( + self.gen_list[label_i], self.train_data_list[label_i].loader + ) self.nll_div.reset(self.gen_list[label_i], gen_data.loader) self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens) self.clas_acc.reset(self.clas, clas_data.loader) @@ -204,9 +289,16 @@ def cal_metrics_with_label(self, label_i): def _save(self, phase, epoch): """Save model state dict and generator's samples""" for i in range(cfg.k_label): - if phase != 'ADV': - torch.save(self.gen_list[i].state_dict(), - cfg.save_model_root + 'gen{}_{}_{:05d}.pt'.format(i, phase, epoch)) - save_sample_path = cfg.save_samples_root + 'samples_d{}_{}_{:05d}.txt'.format(i, phase, epoch) + if phase != "ADV": + torch.save( + self.gen_list[i].state_dict(), + cfg.save_model_root + "gen{}_{}_{:05d}.pt".format(i, phase, epoch), + ) + save_sample_path = ( + cfg.save_samples_root + + "samples_d{}_{}_{:05d}.txt".format(i, phase, epoch) + ) samples = self.gen_list[i].sample(cfg.batch_size, cfg.batch_size) - write_tokens(save_sample_path, tensor_to_tokens(samples, self.idx2word_dict)) + write_tokens( + save_sample_path, tensor_to_tokens(samples, self.idx2word_dict) + ) diff --git a/instructor/real_data/seqgan_instructor.py b/instructor/real_data/seqgan_instructor.py index dccb4a2f..2241a4d0 100644 --- a/instructor/real_data/seqgan_instructor.py +++ b/instructor/real_data/seqgan_instructor.py @@ -23,9 
+23,17 @@ def __init__(self, opt): super(SeqGANInstructor, self).__init__(opt) # generator, discriminator - self.gen = SeqGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.max_seq_len, - cfg.padding_idx, gpu=cfg.CUDA) - self.dis = SeqGAN_D(cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA) + self.gen = SeqGAN_G( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) + self.dis = SeqGAN_D( + cfg.dis_embed_dim, cfg.vocab_size, cfg.padding_idx, gpu=cfg.CUDA + ) self.init_model() # Optimizer @@ -37,40 +45,49 @@ def _run(self): # ===PRE-TRAINING=== # TRAIN GENERATOR if not cfg.gen_pretrain: - self.log.info('Starting Generator MLE Training...') + self.log.info("Starting Generator MLE Training...") self.pretrain_generator(cfg.MLE_train_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.gen.state_dict(), cfg.pretrained_gen_path) - print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path)) + print("Save pre-trained generator: {}".format(cfg.pretrained_gen_path)) # ===TRAIN DISCRIMINATOR==== if not cfg.dis_pretrain: - self.log.info('Starting Discriminator Training...') + self.log.info("Starting Discriminator Training...") self.train_discriminator(cfg.d_step, cfg.d_epoch) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) - print('Save pre-trained discriminator: {}'.format(cfg.pretrained_dis_path)) + print( + "Save pre-trained discriminator: {}".format(cfg.pretrained_dis_path) + ) # ===ADVERSARIAL TRAINING=== - self.log.info('Starting Adversarial Training...') - self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True))) + self.log.info("Starting Adversarial Training...") + self.log.info("Initial generator: %s" % (self.cal_metrics(fmt_str=True))) for adv_epoch in range(cfg.ADV_train_epoch): - self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch) + self.log.info("-----\nADV EPOCH %d\n-----" % adv_epoch) self.sig.update() if self.sig.adv_sig: self.adv_train_generator(cfg.ADV_g_step) # Generator - self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch, 'ADV') # Discriminator - - if adv_epoch % cfg.adv_log_step == 0 or adv_epoch == cfg.ADV_train_epoch - 1: + self.train_discriminator( + cfg.ADV_d_step, cfg.ADV_d_epoch, "ADV" + ) # Discriminator + + if ( + adv_epoch % cfg.adv_log_step == 0 + or adv_epoch == cfg.ADV_train_epoch - 1 + ): if cfg.if_save and not cfg.if_test: - self._save('ADV', adv_epoch) + self._save("ADV", adv_epoch) else: - self.log.info('>>> Stop by adv_signal! Finishing adversarial training...') + self.log.info( + ">>> Stop by adv_signal! Finishing adversarial training..." 
+ ) break def _test(self): - print('>>> Begin test...') + print(">>> Begin test...") self._run() pass @@ -82,16 +99,20 @@ def pretrain_generator(self, epochs): for epoch in range(epochs): self.sig.update() if self.sig.pre_sig: - pre_loss = self.train_gen_epoch(self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt) + pre_loss = self.train_gen_epoch( + self.gen, self.train_data.loader, self.mle_criterion, self.gen_opt + ) # ===Test=== if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1: self.log.info( - '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True))) + "[MLE-GEN] epoch %d : pre_loss = %.4f, %s" + % (epoch, pre_loss, self.cal_metrics(fmt_str=True)) + ) if cfg.if_save and not cfg.if_test: - self._save('MLE', epoch) + self._save("MLE", epoch) else: - self.log.info('>>> Stop by pre signal, skip to adversarial training...') + self.log.info(">>> Stop by pre signal, skip to adversarial training...") break def adv_train_generator(self, g_step): @@ -102,7 +123,9 @@ def adv_train_generator(self, g_step): rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA) total_g_loss = 0 for step in range(g_step): - inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA) + inp, target = GenDataIter.prepare( + self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA + ) # ===Train=== rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis) @@ -111,9 +134,12 @@ def adv_train_generator(self, g_step): total_g_loss += adv_loss.item() # ===Test=== - self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True))) + self.log.info( + "[ADV-GEN]: g_loss = %.4f, %s" + % (total_g_loss, self.cal_metrics(fmt_str=True)) + ) - def train_discriminator(self, d_step, d_epoch, phase='MLE'): + def train_discriminator(self, d_step, d_epoch, phase="MLE"): """ Training the discriminator on real_data_samples (positive) and generated samples from gen (negative). Samples are drawn d_step times, and the discriminator is trained for d_epoch d_epoch. 
@@ -128,12 +154,15 @@ def train_discriminator(self, d_step, d_epoch, phase='MLE'): for epoch in range(d_epoch): # ===Train=== - d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion, - self.dis_opt) + d_loss, train_acc = self.train_dis_epoch( + self.dis, dis_data.loader, self.dis_criterion, self.dis_opt + ) # ===Test=== - self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % ( - phase, step, d_loss, train_acc)) + self.log.info( + "[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f," + % (phase, step, d_loss, train_acc) + ) if cfg.if_save and not cfg.if_test: torch.save(self.dis.state_dict(), cfg.pretrained_dis_path) diff --git a/main.py b/main.py index 5534b287..b882f37f 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ import yaml import argparse + # import torch import numpy as np import wandb @@ -22,118 +23,124 @@ def program_config(parser): # Program - parser.add_argument('--if_test', default=cfg.if_test, type=int) - parser.add_argument('--run_model', default=cfg.run_model, type=str) - parser.add_argument('--k_label', default=cfg.k_label, type=int) - parser.add_argument('--dataset', default=cfg.dataset, type=str) - parser.add_argument('--model_type', default=cfg.model_type, type=str) - parser.add_argument('--loss_type', default=cfg.loss_type, type=str) - parser.add_argument('--mu_type', default=cfg.mu_type, type=str) - parser.add_argument('--eval_type', default=cfg.eval_type, type=str) - parser.add_argument('--d_type', default=cfg.d_type, type=str) - parser.add_argument('--if_real_data', default=cfg.if_real_data, type=int) - parser.add_argument('--cuda', default=cfg.CUDA, type=int) - parser.add_argument('--device', default=cfg.device, type=int) - parser.add_argument('--devices', default=cfg.devices, type=str) - parser.add_argument('--shuffle', default=cfg.data_shuffle, type=int) - parser.add_argument('--gen_init', default=cfg.gen_init, type=str) - parser.add_argument('--dis_init', default=cfg.dis_init, type=str) + parser.add_argument("--if_test", default=cfg.if_test, type=int) + parser.add_argument("--run_model", default=cfg.run_model, type=str) + parser.add_argument("--k_label", default=cfg.k_label, type=int) + parser.add_argument("--dataset", default=cfg.dataset, type=str) + parser.add_argument("--model_type", default=cfg.model_type, type=str) + parser.add_argument("--loss_type", default=cfg.loss_type, type=str) + parser.add_argument("--mu_type", default=cfg.mu_type, type=str) + parser.add_argument("--eval_type", default=cfg.eval_type, type=str) + parser.add_argument("--d_type", default=cfg.d_type, type=str) + parser.add_argument("--if_real_data", default=cfg.if_real_data, type=int) + parser.add_argument("--cuda", default=cfg.CUDA, type=int) + parser.add_argument("--device", default=cfg.device, type=int) + parser.add_argument("--devices", default=cfg.devices, type=str) + parser.add_argument("--shuffle", default=cfg.data_shuffle, type=int) + parser.add_argument("--gen_init", default=cfg.gen_init, type=str) + parser.add_argument("--dis_init", default=cfg.dis_init, type=str) # CatGAN - parser.add_argument('--n_parent', default=cfg.n_parent, type=int) - parser.add_argument('--eval_b_num', default=cfg.eval_b_num, type=int) - parser.add_argument('--lambda_fq', default=cfg.lambda_fq, type=float) - parser.add_argument('--lambda_fd', default=cfg.lambda_fd, type=float) - parser.add_argument('--d_out_mean', default=cfg.d_out_mean, type=int) - parser.add_argument('--freeze_dis', default=cfg.freeze_dis, type=int) - 
-    parser.add_argument('--freeze_clas', default=cfg.freeze_clas, type=int)
-    parser.add_argument('--use_all_real_fake', default=cfg.use_all_real_fake, type=int)
-    parser.add_argument('--use_population', default=cfg.use_population, type=int)
-    parser.add_argument('--batches_per_epoch', default=cfg.batches_per_epoch, type=int)
-    parser.add_argument('--noise_size', default=cfg.noise_size, type=int)
-    parser.add_argument('--max_epochs', default=cfg.max_epochs, type=int)
-    parser.add_argument('--target_len', default=cfg.target_len, type=int)
+    parser.add_argument("--n_parent", default=cfg.n_parent, type=int)
+    parser.add_argument("--eval_b_num", default=cfg.eval_b_num, type=int)
+    parser.add_argument("--lambda_fq", default=cfg.lambda_fq, type=float)
+    parser.add_argument("--lambda_fd", default=cfg.lambda_fd, type=float)
+    parser.add_argument("--d_out_mean", default=cfg.d_out_mean, type=int)
+    parser.add_argument("--freeze_dis", default=cfg.freeze_dis, type=int)
+    parser.add_argument("--freeze_clas", default=cfg.freeze_clas, type=int)
+    parser.add_argument("--use_all_real_fake", default=cfg.use_all_real_fake, type=int)
+    parser.add_argument("--use_population", default=cfg.use_population, type=int)
+    parser.add_argument("--batches_per_epoch", default=cfg.batches_per_epoch, type=int)
+    parser.add_argument("--noise_size", default=cfg.noise_size, type=int)
+    parser.add_argument("--max_epochs", default=cfg.max_epochs, type=int)
+    parser.add_argument("--target_len", default=cfg.target_len, type=int)

     # Basic Train
-    parser.add_argument('--samples_num', default=cfg.samples_num, type=int)
-    parser.add_argument('--vocab_size', default=cfg.vocab_size, type=int)
-    parser.add_argument('--mle_epoch', default=cfg.MLE_train_epoch, type=int)
-    parser.add_argument('--clas_pre_epoch', default=cfg.PRE_clas_epoch, type=int)
-    parser.add_argument('--adv_epoch', default=cfg.ADV_train_epoch, type=int)
-    parser.add_argument('--inter_epoch', default=cfg.inter_epoch, type=int)
-    parser.add_argument('--batch_size', default=cfg.batch_size, type=int)
-    parser.add_argument('--max_seq_len', default=cfg.max_seq_len, type=int)
-    parser.add_argument('--start_letter', default=cfg.start_letter, type=int)
-    parser.add_argument('--padding_idx', default=cfg.padding_idx, type=int)
-    parser.add_argument('--gen_lr', default=cfg.gen_lr, type=float)
-    parser.add_argument('--gen_adv_lr', default=cfg.gen_adv_lr, type=float)
-    parser.add_argument('--dis_lr', default=cfg.dis_lr, type=float)
-    parser.add_argument('--clip_norm', default=cfg.clip_norm, type=float)
-    parser.add_argument('--pre_log_step', default=cfg.pre_log_step, type=int)
-    parser.add_argument('--adv_log_step', default=cfg.adv_log_step, type=int)
-    parser.add_argument('--train_data', default=cfg.train_data, type=str)
-    parser.add_argument('--test_data', default=cfg.test_data, type=str)
-    parser.add_argument('--temp_adpt', default=cfg.temp_adpt, type=str)
-    parser.add_argument('--evo_temp_step', default=cfg.evo_temp_step, type=int)
-    parser.add_argument('--temperature', default=cfg.temperature, type=int)
-    parser.add_argument('--ora_pretrain', default=cfg.oracle_pretrain, type=int)
-    parser.add_argument('--gen_pretrain', default=cfg.gen_pretrain, type=int)
-    parser.add_argument('--dis_pretrain', default=cfg.dis_pretrain, type=int)
+    parser.add_argument("--samples_num", default=cfg.samples_num, type=int)
+    parser.add_argument("--vocab_size", default=cfg.vocab_size, type=int)
+    parser.add_argument("--mle_epoch", default=cfg.MLE_train_epoch, type=int)
+    parser.add_argument("--clas_pre_epoch", default=cfg.PRE_clas_epoch, type=int)
+    parser.add_argument("--adv_epoch", default=cfg.ADV_train_epoch, type=int)
+    parser.add_argument("--inter_epoch", default=cfg.inter_epoch, type=int)
+    parser.add_argument("--batch_size", default=cfg.batch_size, type=int)
+    parser.add_argument("--max_seq_len", default=cfg.max_seq_len, type=int)
+    parser.add_argument("--start_letter", default=cfg.start_letter, type=int)
+    parser.add_argument("--padding_idx", default=cfg.padding_idx, type=int)
+    parser.add_argument("--gen_lr", default=cfg.gen_lr, type=float)
+    parser.add_argument("--gen_adv_lr", default=cfg.gen_adv_lr, type=float)
+    parser.add_argument("--dis_lr", default=cfg.dis_lr, type=float)
+    parser.add_argument("--clip_norm", default=cfg.clip_norm, type=float)
+    parser.add_argument("--pre_log_step", default=cfg.pre_log_step, type=int)
+    parser.add_argument("--adv_log_step", default=cfg.adv_log_step, type=int)
+    parser.add_argument("--train_data", default=cfg.train_data, type=str)
+    parser.add_argument("--test_data", default=cfg.test_data, type=str)
+    parser.add_argument("--temp_adpt", default=cfg.temp_adpt, type=str)
+    parser.add_argument("--evo_temp_step", default=cfg.evo_temp_step, type=int)
+    parser.add_argument("--temperature", default=cfg.temperature, type=int)
+    parser.add_argument("--ora_pretrain", default=cfg.oracle_pretrain, type=int)
+    parser.add_argument("--gen_pretrain", default=cfg.gen_pretrain, type=int)
+    parser.add_argument("--dis_pretrain", default=cfg.dis_pretrain, type=int)

     # Generator
-    parser.add_argument('--adv_g_step', default=cfg.ADV_g_step, type=int)
-    parser.add_argument('--rollout_num', default=cfg.rollout_num, type=int)
-    parser.add_argument('--gen_embed_dim', default=cfg.gen_embed_dim, type=int)
-    parser.add_argument('--gen_hidden_dim', default=cfg.gen_hidden_dim, type=int)
-    parser.add_argument('--goal_size', default=cfg.goal_size, type=int)
-    parser.add_argument('--step_size', default=cfg.step_size, type=int)
-    parser.add_argument('--mem_slots', default=cfg.mem_slots, type=int)
-    parser.add_argument('--num_heads', default=cfg.num_heads, type=int)
-    parser.add_argument('--head_size', default=cfg.head_size, type=int)
-    parser.add_argument('--generator_complexity', default=cfg.generator_complexity, type=int)
+    parser.add_argument("--adv_g_step", default=cfg.ADV_g_step, type=int)
+    parser.add_argument("--rollout_num", default=cfg.rollout_num, type=int)
+    parser.add_argument("--gen_embed_dim", default=cfg.gen_embed_dim, type=int)
+    parser.add_argument("--gen_hidden_dim", default=cfg.gen_hidden_dim, type=int)
+    parser.add_argument("--goal_size", default=cfg.goal_size, type=int)
+    parser.add_argument("--step_size", default=cfg.step_size, type=int)
+    parser.add_argument("--mem_slots", default=cfg.mem_slots, type=int)
+    parser.add_argument("--num_heads", default=cfg.num_heads, type=int)
+    parser.add_argument("--head_size", default=cfg.head_size, type=int)
+    parser.add_argument(
+        "--generator_complexity", default=cfg.generator_complexity, type=int
+    )

     # Discriminator
-    parser.add_argument('--d_step', default=cfg.d_step, type=int)
-    parser.add_argument('--d_epoch', default=cfg.d_epoch, type=int)
-    parser.add_argument('--adv_d_step', default=cfg.ADV_d_step, type=int)
-    parser.add_argument('--adv_d_epoch', default=cfg.ADV_d_epoch, type=int)
-    parser.add_argument('--dis_embed_dim', default=cfg.dis_embed_dim, type=int)
-    parser.add_argument('--dis_hidden_dim', default=cfg.dis_hidden_dim, type=int)
-    parser.add_argument('--num_rep', default=cfg.num_rep, type=int)
-    parser.add_argument('--discriminator_complexity', default=cfg.discriminator_complexity, type=int)
+    parser.add_argument("--d_step", default=cfg.d_step, type=int)
+    parser.add_argument("--d_epoch", default=cfg.d_epoch, type=int)
+    parser.add_argument("--adv_d_step", default=cfg.ADV_d_step, type=int)
+    parser.add_argument("--adv_d_epoch", default=cfg.ADV_d_epoch, type=int)
+    parser.add_argument("--dis_embed_dim", default=cfg.dis_embed_dim, type=int)
+    parser.add_argument("--dis_hidden_dim", default=cfg.dis_hidden_dim, type=int)
+    parser.add_argument("--num_rep", default=cfg.num_rep, type=int)
+    parser.add_argument(
+        "--discriminator_complexity", default=cfg.discriminator_complexity, type=int
+    )

     # W2V embeddings
-    parser.add_argument('--w2v_embedding_size', default=cfg.w2v_embedding_size, type=int)
-    parser.add_argument('--w2v_window', default=cfg.w2v_window, type=int)
-    parser.add_argument('--w2v_min_count', default=cfg.w2v_min_count, type=int)
-    parser.add_argument('--w2v_workers', default=cfg.w2v_workers, type=int)
-    parser.add_argument('--w2v_samples_num', default=cfg.w2v_samples_num, type=int)
+    parser.add_argument(
+        "--w2v_embedding_size", default=cfg.w2v_embedding_size, type=int
+    )
+    parser.add_argument("--w2v_window", default=cfg.w2v_window, type=int)
+    parser.add_argument("--w2v_min_count", default=cfg.w2v_min_count, type=int)
+    parser.add_argument("--w2v_workers", default=cfg.w2v_workers, type=int)
+    parser.add_argument("--w2v_samples_num", default=cfg.w2v_samples_num, type=int)

     # Metrics
-    parser.add_argument('--use_nll_oracle', default=cfg.use_nll_oracle, type=int)
-    parser.add_argument('--use_nll_gen', default=cfg.use_nll_gen, type=int)
-    parser.add_argument('--use_nll_div', default=cfg.use_nll_div, type=int)
-    parser.add_argument('--use_bleu', default=cfg.use_bleu, type=int)
-    parser.add_argument('--use_self_bleu', default=cfg.use_self_bleu, type=int)
-    parser.add_argument('--use_clas_acc', default=cfg.use_clas_acc, type=int)
-    parser.add_argument('--use_ppl', default=cfg.use_ppl, type=int)
+    parser.add_argument("--use_nll_oracle", default=cfg.use_nll_oracle, type=int)
+    parser.add_argument("--use_nll_gen", default=cfg.use_nll_gen, type=int)
+    parser.add_argument("--use_nll_div", default=cfg.use_nll_div, type=int)
+    parser.add_argument("--use_bleu", default=cfg.use_bleu, type=int)
+    parser.add_argument("--use_self_bleu", default=cfg.use_self_bleu, type=int)
+    parser.add_argument("--use_clas_acc", default=cfg.use_clas_acc, type=int)
+    parser.add_argument("--use_ppl", default=cfg.use_ppl, type=int)

     # Log
-    parser.add_argument('--log_file', default=cfg.log_filename, type=str)
-    parser.add_argument('--save_root', default=cfg.save_root, type=str)
-    parser.add_argument('--signal_file', default=cfg.signal_file, type=str)
-    parser.add_argument('--tips', default=cfg.tips, type=str)
+    parser.add_argument("--log_file", default=cfg.log_filename, type=str)
+    parser.add_argument("--save_root", default=cfg.save_root, type=str)
+    parser.add_argument("--signal_file", default=cfg.signal_file, type=str)
+    parser.add_argument("--tips", default=cfg.tips, type=str)

     # Loss coefficients
-    parser.add_argument('--real_fake_coeff', default=1.0, type=float)
-    parser.add_argument('--labels_coeff', default=1.0, type=float)
-    parser.add_argument('--diversity_coeff', default=1.0, type=float)
+    parser.add_argument("--real_fake_coeff", default=1.0, type=float)
+    parser.add_argument("--labels_coeff", default=1.0, type=float)
+    parser.add_argument("--diversity_coeff", default=1.0, type=float)

     return parser


 # MAIN
-if __name__ == '__main__':
-    #seed everything
+if __name__ == "__main__":
+    # seed everything
     # torch.manual_seed(0)
     random.seed(0)
     np.random.seed(0)
@@ -144,8 +151,12 @@ def program_config(parser):
     opt = parser.parse_args()

     if opt.if_real_data:
-        opt.max_seq_len, opt.vocab_size = text_process('dataset/' + opt.dataset + '.txt')
-        cfg.extend_vocab_size = len(load_test_dict(opt.dataset)[0])  # init classifier vocab_size
+        opt.max_seq_len, opt.vocab_size = text_process(
+            "dataset/" + opt.dataset + ".txt"
+        )
+        cfg.extend_vocab_size = len(
+            load_test_dict(opt.dataset)[0]
+        )  # init classifier vocab_size
         cfg.init_param(opt)
         opt.save_root = cfg.save_root
         opt.train_data = cfg.train_data
@@ -181,31 +192,29 @@ def program_config(parser):
         from instructor.oracle_data.fixem_instructor import FixemGANInstructor

     instruction_dict = {
-        'seqgan': SeqGANInstructor,
-        'leakgan': LeakGANInstructor,
-        'maligan': MaliGANInstructor,
-        'jsdgan': JSDGANInstructor,
-        'dpgan': DPGANInstructor,
-        'relgan': RelGANInstructor,
-        'sentigan': SentiGANInstructor,
-        'evogan': EvoGANInstructor,
-        'catgan': CatGANInstructor,
-        'dgsan': DGSANInstructor,
-        'cot': CoTInstructor,
-        'fixemgan': FixemGANInstructor,
-        'cat_fixemgan': FixemGANInstructor
+        "seqgan": SeqGANInstructor,
+        "leakgan": LeakGANInstructor,
+        "maligan": MaliGANInstructor,
+        "jsdgan": JSDGANInstructor,
+        "dpgan": DPGANInstructor,
+        "relgan": RelGANInstructor,
+        "sentigan": SentiGANInstructor,
+        "evogan": EvoGANInstructor,
+        "catgan": CatGANInstructor,
+        "dgsan": DGSANInstructor,
+        "cot": CoTInstructor,
+        "fixemgan": FixemGANInstructor,
+        "cat_fixemgan": FixemGANInstructor,
     }

-    # Example sweep configuration
-    with open('sweep.yml') as sweep_yml:
+    with open("sweep.yml") as sweep_yml:
         sweep_configuration = yaml.safe_load(sweep_yml)
-    print('sweep_configuration', sweep_configuration)
+    print("sweep_configuration", sweep_configuration)

     sweep_id = wandb.sweep(sweep=sweep_configuration, project="TorchGAN-fixem")
     # sweep_id = "7g6po2bd"
-    print('sweep_id', sweep_id)
-
+    print("sweep_id", sweep_id)

     def full_train_run(opt):
         inst = instruction_dict[cfg.run_model](opt)
@@ -214,7 +223,7 @@ def full_train_run(opt):
     def function_for_parameters_sweep():
         run = wandb.init()  # Initialize a new wandb run
         config = run.config  # Get the config dictionary for the current run
-        print('config', config)
+        print("config", config)

         # Update 'opt' with the hyperparameters from 'config'
         for name, value in config.items():
@@ -222,5 +231,4 @@ def function_for_parameters_sweep():
         full_train_run(opt)
         run.finish()  # Make sure to finish the run

-
     wandb.agent(sweep_id=sweep_id, function=function_for_parameters_sweep)
diff --git a/metrics/bleu.py b/metrics/bleu.py
index d0ccc290..a8e0e6fe 100644
--- a/metrics/bleu.py
+++ b/metrics/bleu.py
@@ -24,15 +24,27 @@ class BLEU(Metrics):
     :param is_fast: Fast mode
     :param given_gram: Calculate specific n-gram BLEU score
     """
-    def __init__(self, name=None, weight=1, test_text=None, real_text=None, gram=3, portion=1, if_use=False):
-        assert type(gram) == int or type(gram) == list, 'Gram format error!'
-        super(BLEU, self).__init__('%s-%s' % (name, gram), weight, if_use)
+
+    def __init__(
+        self,
+        name=None,
+        weight=1,
+        test_text=None,
+        real_text=None,
+        gram=3,
+        portion=1,
+        if_use=False,
+    ):
+        assert type(gram) == int or type(gram) == list, "Gram format error!"
+ super(BLEU, self).__init__("%s-%s" % (name, gram), weight, if_use) self.if_use = if_use self.test_text = test_text self.real_text = real_text self.gram = gram if type(gram) == int else gram - self.sample_size = 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 + self.sample_size = ( + 200 # BLEU scores remain nearly unchanged for self.sample_size >= 200 + ) self.portion = portion # how many portions to use in the evaluation, default to use the whole test dataset def _reset(self, test_text=None, real_text=None): @@ -45,11 +57,11 @@ def get_reference(self): # In-place shuffle random.shuffle(reference) len_ref = len(reference) - reference = reference[:int(self.portion * len_ref)] + reference = reference[: int(self.portion * len_ref)] return reference def get_bleu(self, given_gram=None): - if type(self.gram) == int: # for single gram + if type(self.gram) == int: # for single gram return self.get_blue_for_single_gram(self.gram) # for multiple gram all_bleu = [] @@ -60,18 +72,22 @@ def get_bleu(self, given_gram=None): def get_blue_for_single_gram(self, ngram): bleu = list() reference = self.get_reference() - weight = tuple((1. / ngram for _ in range(ngram))) - for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): + weight = tuple((1.0 / ngram for _ in range(ngram))) + for idx, hypothesis in enumerate(self.test_text[: self.sample_size]): bleu.append(self.cal_bleu(reference, hypothesis, weight)) return round(sum(bleu) / len(bleu), 3) @staticmethod def cal_bleu(reference, hypothesis, weight): - return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight, - smoothing_function=SmoothingFunction().method1) + return nltk.translate.bleu_score.sentence_bleu( + reference, + hypothesis, + weight, + smoothing_function=SmoothingFunction().method1, + ) def calculate_metric(self): - if type(self.gram) == int: # for single gram + if type(self.gram) == int: # for single gram return self.get_blue_for_single_gram(self.gram) # for multiple gram reference = self.get_reference() @@ -81,11 +97,13 @@ def calculate_metric(self): return all_bleu def get_bleu_parallel(self, ngram, reference): - weight = tuple((1. 
/ ngram for _ in range(ngram))) + weight = tuple((1.0 / ngram for _ in range(ngram))) pool = Pool(os.cpu_count()) result = list() - for idx, hypothesis in enumerate(self.test_text[:self.sample_size]): - result.append(pool.apply_async(self.cal_bleu, args=(reference, hypothesis, weight))) + for idx, hypothesis in enumerate(self.test_text[: self.sample_size]): + result.append( + pool.apply_async(self.cal_bleu, args=(reference, hypothesis, weight)) + ) score = 0.0 cnt = 0 for i in result: diff --git a/metrics/clas_acc.py b/metrics/clas_acc.py index b04e1b80..bed1f594 100644 --- a/metrics/clas_acc.py +++ b/metrics/clas_acc.py @@ -14,7 +14,7 @@ class ACC(Metrics): def __init__(self, weight, if_use=True, gpu=True): - super(ACC, self).__init__('clas_acc', weight, if_use) + super(ACC, self).__init__("clas_acc", weight, if_use) self.if_use = if_use self.model = None @@ -30,7 +30,7 @@ def calculate_metric(self, model, data_loader): total_num = 0 with torch.no_grad(): for i, data in enumerate(self.data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if self.gpu: inp, target = inp.cuda(), target.cuda() diff --git a/metrics/dummy.py b/metrics/dummy.py index 30a82a68..5e1a07d6 100644 --- a/metrics/dummy.py +++ b/metrics/dummy.py @@ -5,8 +5,9 @@ class Dummy(Metrics): """ Dummy score to make Overal score positive and easy to read """ + def __init__(self, name=None, weight=1, value=5, if_use=True): - super(Dummy, self).__init__('Dummy', weight, if_use) + super(Dummy, self).__init__("Dummy", weight, if_use) self.value = 5 def calculate_metric(self): diff --git a/metrics/gpt_nll.py b/metrics/gpt_nll.py index 64de2404..47761995 100644 --- a/metrics/gpt_nll.py +++ b/metrics/gpt_nll.py @@ -14,7 +14,7 @@ class GPTNLL(Metrics): def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=True): - super(GPTNLL, self).__init__('GPT2 as oracle', weight, if_use) + super(GPTNLL, self).__init__("GPT2 as oracle", weight, if_use) self.if_use = if_use self.test_text = test_text @@ -22,23 +22,31 @@ def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=Tru self.NLLloss = torch.nn.NLLLoss() self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") self.model = GPT2LMHeadModel.from_pretrained("gpt2") - print('Calculating dataset NLL') - self.real_text_nll = self.calcualte_NLL(random.sample(real_text.tokens, 500)) if real_text else None + print("Calculating dataset NLL") + self.real_text_nll = ( + self.calcualte_NLL(random.sample(real_text.tokens, 500)) + if real_text + else None + ) if self.real_text_nll: - print(f'dataset NLL based on GPT2 is {self.real_text_nll}') - print('GPT2 as oracle metric will be calculated relative to this value') + print(f"dataset NLL based on GPT2 is {self.real_text_nll}") + print("GPT2 as oracle metric will be calculated relative to this value") def _reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text is not None else self.test_text - self.real_text_nll = self.calcualte_NLL(real_text.tokens) if real_text is not None else self.real_text_nll + self.real_text_nll = ( + self.calcualte_NLL(real_text.tokens) + if real_text is not None + else self.real_text_nll + ) def calculate_metric(self): """Get gpt2 NLL score difference with dataset NLL.""" return self.calcualte_NLL(self.test_text) - self.real_text_nll def calcualte_NLL(self, messages): - if type(messages[0]) == list: # we received list of tokens - messages = [' '.join(msg) for msg in messages] + if type(messages[0]) == list: # we 
received list of tokens + messages = [" ".join(msg) for msg in messages] all_logits = [] for message in messages: diff --git a/metrics/ioc.py b/metrics/ioc.py index 7e95360c..4cf3f81b 100644 --- a/metrics/ioc.py +++ b/metrics/ioc.py @@ -10,29 +10,30 @@ class IOC(Metrics): def __init__(self, weight, name=None, test_text=None, real_text=None, if_use=True): - super(IOC, self).__init__('Index of Coincidence', weight, if_use) + super(IOC, self).__init__("Index of Coincidence", weight, if_use) self.if_use = if_use self.test_text = test_text self.real_text_ioc = self.calculate_ioc(real_text.tokens) if real_text else None if self.real_text_ioc: - print(f'Dataset Index of coincidence: {self.real_text_ioc}') + print(f"Dataset Index of coincidence: {self.real_text_ioc}") self.reference = None self.is_first = True def _reset(self, test_text=None, real_text=None): self.test_text = test_text if test_text is not None else self.test_text - self.real_text_ioc = self.get_ioc(real_text.tokens) if real_text is not None else self.real_text_ioc + self.real_text_ioc = ( + self.get_ioc(real_text.tokens) + if real_text is not None + else self.real_text_ioc + ) def calculate_metric(self): return self.calculate_ioc(self.test_text) / self.real_text_ioc def calculate_ioc(self, tokenized_text): """Index Of coincidence: probability of 2 random tokens in text to equal.""" - tokenized_text = [ - [str(token) for token in tokens] - for tokens in tokenized_text - ] + tokenized_text = [[str(token) for token in tokens] for tokens in tokenized_text] tokens = list(chain(*tokenized_text)) counts = Counter(tokens) total = sum(ni * (ni - 1) for ni in counts.values()) diff --git a/metrics/nll.py b/metrics/nll.py index c0874ac4..78f19f4d 100644 --- a/metrics/nll.py +++ b/metrics/nll.py @@ -28,10 +28,13 @@ def __init__(self, name, weight, if_use=False, gpu=False): def calculate_metric(self): """note that NLL score need the updated model and data loader each time, use reset() before get_score()""" if self.leak_dis is not None: # For LeakGAN - return self.cal_nll_with_leak_dis(self.model, self.data_loader, self.leak_dis, self.gpu) + return self.cal_nll_with_leak_dis( + self.model, self.data_loader, self.leak_dis, self.gpu + ) if self.label_i is not None: # For category text generation - return self.cal_nll_with_label(self.model, self.data_loader, self.label_i, - self.criterion, self.gpu) + return self.cal_nll_with_label( + self.model, self.data_loader, self.label_i, self.criterion, self.gpu + ) return self.cal_nll(self.model, self.data_loader, self.criterion, self.gpu) def _reset(self, model=None, data_loader=None, label_i=None, leak_dis=None): @@ -46,7 +49,7 @@ def cal_nll(model, data_loader, criterion, gpu=cfg.CUDA): total_loss = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if gpu: inp, target = inp.cuda(), target.cuda() @@ -59,17 +62,17 @@ def cal_nll(model, data_loader, criterion, gpu=cfg.CUDA): @staticmethod def cal_nll_with_label(model, data_loader, label_i, criterion, gpu=cfg.CUDA): """NLL score for category text generation model.""" - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" total_loss = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] label = torch.LongTensor([label_i] * data_loader.batch_size) if gpu: inp, target, label = inp.cuda(), target.cuda(), label.cuda() 
hidden = model.init_hidden(data_loader.batch_size) - if model.name == 'oracle': + if model.name == "oracle": pred = model.forward(inp, hidden) else: pred = model.forward(inp, hidden, label) @@ -83,7 +86,7 @@ def cal_nll_with_leak_dis(model, data_loader, leak_dis, gpu=cfg.CUDA): total_loss = 0 with torch.no_grad(): for i, data in enumerate(data_loader): - inp, target = data['input'], data['target'] + inp, target = data["input"], data["target"] if gpu: inp, target = inp.cuda(), target.cuda() diff --git a/metrics/ppl.py b/metrics/ppl.py index 7cafabdf..0fc986af 100644 --- a/metrics/ppl.py +++ b/metrics/ppl.py @@ -17,7 +17,7 @@ from metrics.basic import Metrics from utils.text_process import write_tokens -kenlm_path = '/home/zhiwei/kenlm' # specify the kenlm path +kenlm_path = "/home/zhiwei/kenlm" # specify the kenlm path class PPL(Metrics): @@ -30,7 +30,7 @@ def __init__(self, train_data, test_data, weight, n_gram=5, if_use=False): @param n_gram: calculate with n-gram @param if_use: if use """ - super(PPL, self).__init__('[PPL-F, PPL-R]', weight, if_use) + super(PPL, self).__init__("[PPL-F, PPL-R]", weight, if_use) self.n_gram = n_gram self.if_use = if_use @@ -43,21 +43,33 @@ def _reset(self, gen_tokens=None): self.gen_tokens = gen_tokens def calculate_metric(self): - save_path = os.path.join("/tmp", ''.join(random.choice( - string.ascii_uppercase + string.digits) for _ in range(6))) + save_path = os.path.join( + "/tmp", + "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(6) + ), + ) output_path = save_path + ".arpa" write_tokens(save_path, self.gen_tokens) # save to file # forward ppl - for_lm = self.train_ngram_lm(kenlm_path=kenlm_path, data_path=cfg.test_data, - output_path=output_path, n_gram=self.n_gram) + for_lm = self.train_ngram_lm( + kenlm_path=kenlm_path, + data_path=cfg.test_data, + output_path=output_path, + n_gram=self.n_gram, + ) for_ppl = self.get_ppl(for_lm, self.gen_tokens) # reverse ppl try: - rev_lm = self.train_ngram_lm(kenlm_path=kenlm_path, data_path=save_path, - output_path=output_path, n_gram=self.n_gram) + rev_lm = self.train_ngram_lm( + kenlm_path=kenlm_path, + data_path=save_path, + output_path=output_path, + n_gram=self.n_gram, + ) rev_ppl = self.get_ppl(rev_lm, self.test_data.tokens) except: @@ -76,14 +88,21 @@ def train_ngram_lm(self, kenlm_path, data_path, output_path, n_gram): # create .arpa and .bin file of n-grams curdir = os.path.abspath(os.path.curdir) - cd_command = "cd " + os.path.join(kenlm_path, 'build') - command_1 = "bin/lmplz -o {} <{} >{} --discount_fallback &".format(str(n_gram), os.path.join(curdir, data_path), - output_path) - command_2 = "bin/build_binary -s {} {} &".format(output_path, output_path + ".bin") + cd_command = "cd " + os.path.join(kenlm_path, "build") + command_1 = "bin/lmplz -o {} <{} >{} --discount_fallback &".format( + str(n_gram), os.path.join(curdir, data_path), output_path + ) + command_2 = "bin/build_binary -s {} {} &".format( + output_path, output_path + ".bin" + ) while True: - subprocess.getstatusoutput(cd_command + " && " + command_1) # call without logging output - subprocess.getstatusoutput(cd_command + " && " + command_2) # call without logging output + subprocess.getstatusoutput( + cd_command + " && " + command_1 + ) # call without logging output + subprocess.getstatusoutput( + cd_command + " && " + command_2 + ) # call without logging output if os.path.exists(output_path + ".bin"): break @@ -99,8 +118,14 @@ def get_ppl(self, lm, tokens): total_nll = 0 total_wc = 0 for words in tokens: - 
nll = np.sum([-math.log(math.pow(10.0, score)) - for score, _, _ in lm.full_scores(' '.join(words), bos=True, eos=False)]) + nll = np.sum( + [ + -math.log(math.pow(10.0, score)) + for score, _, _ in lm.full_scores( + " ".join(words), bos=True, eos=False + ) + ] + ) total_wc += len(words) total_nll += nll ppl = np.exp(total_nll / total_wc) diff --git a/models/Oracle.py b/models/Oracle.py index 4c3cbb8b..2b0a2a6f 100644 --- a/models/Oracle.py +++ b/models/Oracle.py @@ -13,10 +13,13 @@ class Oracle(LSTMGenerator): - - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(Oracle, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'oracle' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(Oracle, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "oracle" # initialise oracle network with N(0,1) # otherwise variance of initialisation is very small => high NLL for loader sampled from the same model diff --git a/models/discriminators/CatGAN_D.py b/models/discriminators/CatGAN_D.py index 2e841b71..1e977c6c 100644 --- a/models/discriminators/CatGAN_D.py +++ b/models/discriminators/CatGAN_D.py @@ -21,9 +21,25 @@ # Discriminator class CatGAN_D(CNNDiscriminator): - def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(CatGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, - gpu, dropout) + def __init__( + self, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(CatGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) self.embed_dim = embed_dim self.max_seq_len = max_seq_len @@ -32,10 +48,14 @@ def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single)) for (n, f) in - zip(dis_num_filters, dis_filter_sizes) - ]) + self.convs = nn.ModuleList( + [ + nn.Conv2d( + 1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single) + ) + for (n, f) in zip(dis_num_filters, dis_filter_sizes) + ] + ) self.highway = nn.Linear(self.feature_dim, self.feature_dim) self.feature2out = nn.Linear(self.feature_dim, 100) # origin @@ -50,15 +70,26 @@ def forward(self, inp): :param inp: batch_size * seq_len * vocab_size :return logits: [batch_size * num_rep] (1-D tensor) """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim - cons = [F.relu(conv(emb)) for conv in self.convs] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] - pools = [F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons] # [batch_size * num_filter * num_rep] + cons = [ + F.relu(conv(emb)) for conv in self.convs + ] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] + pools = [ + F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons + ] # [batch_size * num_filter * num_rep] pred = torch.cat(pools, 1) # batch_size * feature_dim * num_rep - pred = pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) # (batch_size * num_rep) * feature_dim + pred = ( + pred.permute(0, 2, 
1).contiguous().view(-1, self.feature_dim) + ) # (batch_size * num_rep) * feature_dim highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. - torch.sigmoid(highway)) * pred # highway, same dim + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway, same dim pred = self.feature2out(self.dropout(pred)) logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] @@ -68,9 +99,29 @@ def forward(self, inp): # Classifier class CatGAN_C(CNNClassifier): - def __init__(self, k_label, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(CatGAN_C, self).__init__(k_label, embed_dim, max_seq_len, num_rep, vocab_size, clas_filter_sizes, - clas_num_filters, padding_idx, gpu, dropout) + def __init__( + self, + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(CatGAN_C, self).__init__( + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + clas_filter_sizes, + clas_num_filters, + padding_idx, + gpu, + dropout, + ) # Use Glove # self.embeddings.from_pretrained(build_embedding_matrix(cfg.dataset)) diff --git a/models/discriminators/CoT_D.py b/models/discriminators/CoT_D.py index c008a39d..0eab5e3b 100644 --- a/models/discriminators/CoT_D.py +++ b/models/discriminators/CoT_D.py @@ -14,8 +14,12 @@ class Cot_D(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(Cot_D, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(Cot_D, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) def get_pred(self, input, target): pred = self.forward(input, self.init_hidden(input.size(0))) diff --git a/models/discriminators/DPGAN_D.py b/models/discriminators/DPGAN_D.py index ace54904..77c9a527 100644 --- a/models/discriminators/DPGAN_D.py +++ b/models/discriminators/DPGAN_D.py @@ -16,9 +16,13 @@ class DPGAN_D(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(DPGAN_D, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'dpgan_d' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(DPGAN_D, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "dpgan_d" def getReward(self, samples): """ @@ -30,7 +34,9 @@ def getReward(self, samples): hidden = self.init_hidden(batch_size) pred = self.forward(inp, hidden) - word_reward = F.nll_loss(pred, target.view(-1), reduction='none').view(batch_size, -1) + word_reward = F.nll_loss(pred, target.view(-1), reduction="none").view( + batch_size, -1 + ) sentence_reward = torch.mean(word_reward, dim=-1, keepdim=True) return word_reward, sentence_reward diff --git a/models/discriminators/EvoGAN_D.py b/models/discriminators/EvoGAN_D.py index 7d13ed95..6f64f8e6 100644 --- a/models/discriminators/EvoGAN_D.py +++ b/models/discriminators/EvoGAN_D.py @@ -18,9 +18,25 @@ class EvoGAN_D(CNNDiscriminator): - def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(EvoGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, - gpu, dropout) + def 
__init__( + self, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(EvoGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) self.embed_dim = embed_dim self.max_seq_len = max_seq_len @@ -29,10 +45,14 @@ def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single)) for (n, f) in - zip(dis_num_filters, dis_filter_sizes) - ]) + self.convs = nn.ModuleList( + [ + nn.Conv2d( + 1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single) + ) + for (n, f) in zip(dis_num_filters, dis_filter_sizes) + ] + ) self.highway = nn.Linear(self.feature_dim, self.feature_dim) self.feature2out = nn.Linear(self.feature_dim, 100) # origin @@ -47,15 +67,26 @@ def forward(self, inp): :param inp: batch_size * seq_len * vocab_size :return logits: [batch_size * num_rep] (1-D tensor) """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim - cons = [F.relu(conv(emb)) for conv in self.convs] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] - pools = [F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons] # [batch_size * num_filter * num_rep] + cons = [ + F.relu(conv(emb)) for conv in self.convs + ] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] + pools = [ + F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons + ] # [batch_size * num_filter * num_rep] pred = torch.cat(pools, 1) # batch_size * feature_dim * num_rep - pred = pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) # (batch_size * num_rep) * feature_dim + pred = ( + pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) + ) # (batch_size * num_rep) * feature_dim highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. 
- torch.sigmoid(highway)) * pred # highway, same dim + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway, same dim pred = self.feature2out(self.dropout(pred)) logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] diff --git a/models/discriminators/FixemGAN_D.py b/models/discriminators/FixemGAN_D.py index 6eea6ae8..dd0a6f6f 100644 --- a/models/discriminators/FixemGAN_D.py +++ b/models/discriminators/FixemGAN_D.py @@ -1,9 +1,16 @@ import torch.nn as nn import config as cfg -from utils.nn_helpers import get_optimizer, MyConvLayer, MyTransformerEncoderLayer, Flatten, Dummy +from utils.nn_helpers import ( + get_optimizer, + MyConvLayer, + MyTransformerEncoderLayer, + Flatten, + Dummy, +) from models.discriminators.discriminator import CNNDiscriminator + class Discriminator(nn.Module): def __init__(self, complexity): super(Discriminator, self).__init__() @@ -13,7 +20,9 @@ def __init__(self, complexity): self.main = nn.Sequential( # 1 layer - MyConvLayer(cfg.w2v_embedding_size, complexity, alpha=alpha, drop_rate=drop_rate), + MyConvLayer( + cfg.w2v_embedding_size, complexity, alpha=alpha, drop_rate=drop_rate + ), # 2 layer MyConvLayer( complexity, @@ -27,14 +36,12 @@ def __init__(self, complexity): # 4 layer MyConvLayer(complexity, complexity, alpha=alpha, drop_rate=drop_rate), # 5 layer - MyTransformerEncoderLayer( d_model=complexity, n_layers=3, ) if include_transformer else Dummy(), - # 6 layer MyConvLayer(complexity, complexity, alpha=alpha, drop_rate=drop_rate), # MyLSTMLayer(complexity, complexity//2), @@ -47,7 +54,6 @@ def __init__(self, complexity): alpha=alpha, drop_rate=drop_rate, ), - MyConvLayer( complexity, complexity, diff --git a/models/discriminators/LeakGAN_D.py b/models/discriminators/LeakGAN_D.py index 86c46800..f5dc3840 100644 --- a/models/discriminators/LeakGAN_D.py +++ b/models/discriminators/LeakGAN_D.py @@ -15,5 +15,12 @@ class LeakGAN_D(CNNDiscriminator): def __init__(self, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.2): - super(LeakGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, - gpu, dropout) + super(LeakGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) diff --git a/models/discriminators/MaliGAN_D.py b/models/discriminators/MaliGAN_D.py index b132c8d1..37b3cfac 100644 --- a/models/discriminators/MaliGAN_D.py +++ b/models/discriminators/MaliGAN_D.py @@ -15,5 +15,12 @@ class MaliGAN_D(CNNDiscriminator): def __init__(self, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(MaliGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, gpu, - dropout) + super(MaliGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) diff --git a/models/discriminators/RelGAN_D.py b/models/discriminators/RelGAN_D.py index 88435d5a..e3a109bd 100644 --- a/models/discriminators/RelGAN_D.py +++ b/models/discriminators/RelGAN_D.py @@ -18,9 +18,25 @@ class RelGAN_D(CNNDiscriminator): - def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(RelGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, - gpu, dropout) + def __init__( + self, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(RelGAN_D, self).__init__( + embed_dim, 
+ vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) self.embed_dim = embed_dim self.max_seq_len = max_seq_len @@ -29,10 +45,14 @@ def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False) - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single)) for (n, f) in - zip(dis_num_filters, dis_filter_sizes) - ]) + self.convs = nn.ModuleList( + [ + nn.Conv2d( + 1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single) + ) + for (n, f) in zip(dis_num_filters, dis_filter_sizes) + ] + ) self.highway = nn.Linear(self.feature_dim, self.feature_dim) self.feature2out = nn.Linear(self.feature_dim, 100) @@ -47,14 +67,25 @@ def forward(self, inp): :param inp: batch_size * seq_len * vocab_size :return logits: [batch_size * num_rep] (1-D tensor) """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim - cons = [F.relu(conv(emb)) for conv in self.convs] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] - pools = [F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons] # [batch_size * num_filter * num_rep] + cons = [ + F.relu(conv(emb)) for conv in self.convs + ] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] + pools = [ + F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons + ] # [batch_size * num_filter * num_rep] pred = torch.cat(pools, 1) - pred = pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) # (batch_size * num_rep) * feature_dim + pred = ( + pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) + ) # (batch_size * num_rep) * feature_dim highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. 
- torch.sigmoid(highway)) * pred # highway + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway pred = self.feature2out(self.dropout(pred)) logits = self.out2logits(pred).squeeze(1) # [batch_size * num_rep] diff --git a/models/discriminators/SentiGAN_D.py b/models/discriminators/SentiGAN_D.py index f6668ec4..f54ae108 100644 --- a/models/discriminators/SentiGAN_D.py +++ b/models/discriminators/SentiGAN_D.py @@ -19,9 +19,18 @@ class SentiGAN_D(CNNDiscriminator): - def __init__(self, k_label, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.2): - super(SentiGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, gpu, - dropout) + def __init__( + self, k_label, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.2 + ): + super(SentiGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) self.feature2out = nn.Linear(self.feature_dim, k_label + 1) @@ -30,9 +39,29 @@ def __init__(self, k_label, embed_dim, vocab_size, padding_idx, gpu=False, dropo # Classifier class SentiGAN_C(CNNClassifier): - def __init__(self, k_label, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(SentiGAN_C, self).__init__(k_label, embed_dim, max_seq_len, num_rep, vocab_size, clas_filter_sizes, - clas_num_filters, padding_idx, gpu, dropout) + def __init__( + self, + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(SentiGAN_C, self).__init__( + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + clas_filter_sizes, + clas_num_filters, + padding_idx, + gpu, + dropout, + ) # Use Glove # self.embeddings.from_pretrained(build_embedding_matrix(cfg.dataset)) diff --git a/models/discriminators/SeqGAN_D.py b/models/discriminators/SeqGAN_D.py index 70de10d1..25ccad90 100644 --- a/models/discriminators/SeqGAN_D.py +++ b/models/discriminators/SeqGAN_D.py @@ -15,5 +15,12 @@ class SeqGAN_D(CNNDiscriminator): def __init__(self, embed_dim, vocab_size, padding_idx, gpu=False, dropout=0.25): - super(SeqGAN_D, self).__init__(embed_dim, vocab_size, dis_filter_sizes, dis_num_filters, padding_idx, gpu, - dropout) + super(SeqGAN_D, self).__init__( + embed_dim, + vocab_size, + dis_filter_sizes, + dis_num_filters, + padding_idx, + gpu, + dropout, + ) diff --git a/models/discriminators/discriminator.py b/models/discriminators/discriminator.py index 77738b7d..b3331977 100644 --- a/models/discriminators/discriminator.py +++ b/models/discriminators/discriminator.py @@ -18,8 +18,16 @@ class CNNDiscriminator(nn.Module): - def __init__(self, embed_dim, vocab_size, filter_sizes, num_filters, padding_idx, gpu=False, - dropout=0.2): + def __init__( + self, + embed_dim, + vocab_size, + filter_sizes, + num_filters, + padding_idx, + gpu=False, + dropout=0.2, + ): super(CNNDiscriminator, self).__init__() self.embedding_dim = embed_dim self.vocab_size = vocab_size @@ -28,9 +36,12 @@ def __init__(self, embed_dim, vocab_size, filter_sizes, num_filters, padding_idx self.gpu = gpu self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx) - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, embed_dim)) for (n, f) in zip(num_filters, filter_sizes) - ]) + self.convs = nn.ModuleList( + [ + nn.Conv2d(1, n, (f, embed_dim)) + for (n, f) in zip(num_filters, filter_sizes) + ] + ) self.highway = nn.Linear(self.feature_dim, self.feature_dim) self.feature2out = 
nn.Linear(self.feature_dim, 2) self.dropout = nn.Dropout(dropout) @@ -54,12 +65,21 @@ def get_feature(self, inp): :param inp: batch_size * max_seq_len :return: batch_size * feature_dim """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim - convs = [F.relu(conv(emb)).squeeze(3) for conv in self.convs] # [batch_size * num_filter * length] - pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs] # [batch_size * num_filter] + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim + convs = [ + F.relu(conv(emb)).squeeze(3) for conv in self.convs + ] # [batch_size * num_filter * length] + pools = [ + F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs + ] # [batch_size * num_filter] pred = torch.cat(pools, 1) # tensor: batch_size * feature_dim highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. - torch.sigmoid(highway)) * pred # highway + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway return pred @@ -67,18 +87,26 @@ def init_params(self): for param in self.parameters(): if param.requires_grad and len(param.shape) > 0: stddev = 1 / math.sqrt(param.shape[0]) - if cfg.dis_init == 'uniform': + if cfg.dis_init == "uniform": torch.nn.init.uniform_(param, a=-0.05, b=0.05) - elif cfg.dis_init == 'normal': + elif cfg.dis_init == "normal": torch.nn.init.normal_(param, std=stddev) - elif cfg.dis_init == 'truncated_normal': + elif cfg.dis_init == "truncated_normal": truncated_normal_(param, std=stddev) class GRUDiscriminator(nn.Module): - - def __init__(self, embedding_dim, vocab_size, hidden_dim, feature_dim, max_seq_len, padding_idx, - gpu=False, dropout=0.2): + def __init__( + self, + embedding_dim, + vocab_size, + hidden_dim, + feature_dim, + max_seq_len, + padding_idx, + gpu=False, + dropout=0.2, + ): super(GRUDiscriminator, self).__init__() self.hidden_dim = hidden_dim self.embedding_dim = embedding_dim @@ -86,8 +114,12 @@ def __init__(self, embedding_dim, vocab_size, hidden_dim, feature_dim, max_seq_l self.padding_idx = padding_idx self.gpu = gpu - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) - self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=dropout) + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) + self.gru = nn.GRU( + embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=dropout + ) self.gru2hidden = nn.Linear(2 * 2 * hidden_dim, feature_dim) self.feature2out = nn.Linear(feature_dim, 2) self.dropout = nn.Dropout(dropout) @@ -125,7 +157,9 @@ def get_feature(self, inp): emb = emb.permute(1, 0, 2) # seq_len * batch_size * embedding_dim _, hidden = self.gru(emb, hidden) # 4 * batch_size * hidden_dim hidden = hidden.permute(1, 0, 2).contiguous() # batch_size * 4 * hidden_dim - out = self.gru2hidden(hidden.view(-1, 4 * self.hidden_dim)) # batch_size * 4 * hidden_dim + out = self.gru2hidden( + hidden.view(-1, 4 * self.hidden_dim) + ) # batch_size * 4 * hidden_dim feature = torch.tanh(out) # batch_size * feature_dim return feature @@ -134,20 +168,32 @@ def init_params(self): for param in self.parameters(): if param.requires_grad and len(param.shape) > 0: stddev = 1 / math.sqrt(param.shape[0]) - if cfg.dis_init == 'uniform': + if cfg.dis_init == "uniform": torch.nn.init.uniform_(param, a=-0.05, b=0.05) - elif cfg.dis_init == 'normal': + elif cfg.dis_init == "normal": 
torch.nn.init.normal_(param, std=stddev) - elif cfg.dis_init == 'truncated_normal': + elif cfg.dis_init == "truncated_normal": truncated_normal_(param, std=stddev) # Classifier class CNNClassifier(CNNDiscriminator): - def __init__(self, k_label, embed_dim, max_seq_len, num_rep, vocab_size, filter_sizes, num_filters, padding_idx, - gpu=False, dropout=0.25): - super(CNNClassifier, self).__init__(embed_dim, vocab_size, filter_sizes, num_filters, padding_idx, - gpu, dropout) + def __init__( + self, + k_label, + embed_dim, + max_seq_len, + num_rep, + vocab_size, + filter_sizes, + num_filters, + padding_idx, + gpu=False, + dropout=0.25, + ): + super(CNNClassifier, self).__init__( + embed_dim, vocab_size, filter_sizes, num_filters, padding_idx, gpu, dropout + ) self.k_label = k_label self.embed_dim = embed_dim @@ -157,9 +203,12 @@ def __init__(self, k_label, embed_dim, max_seq_len, num_rep, vocab_size, filter_ self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx) - self.convs = nn.ModuleList([ - nn.Conv2d(1, n, (f, embed_dim)) for (n, f) in zip(num_filters, filter_sizes) - ]) # vanilla + self.convs = nn.ModuleList( + [ + nn.Conv2d(1, n, (f, embed_dim)) + for (n, f) in zip(num_filters, filter_sizes) + ] + ) # vanilla # self.convs = nn.ModuleList([ # nn.Conv2d(1, n, (f, self.emb_dim_single), stride=(1, self.emb_dim_single)) for (n, f) in # zip(num_filters, filter_sizes) @@ -179,11 +228,17 @@ def forward(self, inp): :param inp: batch_size * seq_len * vocab_size :return logits: [batch_size * num_rep] (1-D tensor) """ - emb = self.embeddings(inp).unsqueeze(1) # batch_size * 1 * max_seq_len * embed_dim + emb = self.embeddings(inp).unsqueeze( + 1 + ) # batch_size * 1 * max_seq_len * embed_dim # vanilla - convs = [F.relu(conv(emb)).squeeze(3) for conv in self.convs] # [batch_size * num_filter * length] - pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs] # [batch_size * num_filter] + convs = [ + F.relu(conv(emb)).squeeze(3) for conv in self.convs + ] # [batch_size * num_filter * length] + pools = [ + F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs + ] # [batch_size * num_filter] # RelGAN # cons = [F.relu(conv(emb)) for conv in self.convs] # [batch_size * num_filter * (seq_len-k_h+1) * num_rep] # pools = [F.max_pool2d(con, (con.size(2), 1)).squeeze(2) for con in cons] # [batch_size * num_filter * num_rep] @@ -191,10 +246,15 @@ def forward(self, inp): pred = torch.cat(pools, 1) # batch_size * feature_dim # pred = pred.permute(0, 2, 1).contiguous().view(-1, self.feature_dim) # RelGAN highway = self.highway(pred) - pred = torch.sigmoid(highway) * F.relu(highway) + (1. 
- torch.sigmoid(highway)) * pred # highway, same dim + pred = ( + torch.sigmoid(highway) * F.relu(highway) + + (1.0 - torch.sigmoid(highway)) * pred + ) # highway, same dim pred = self.feature2out(self.dropout(pred)) - logits = self.out2logits(self.dropout(pred)).squeeze(1) # vanilla, batch_size * k_label + logits = self.out2logits(self.dropout(pred)).squeeze( + 1 + ) # vanilla, batch_size * k_label # logits = self.out2logits(self.dropout(pred.view(inp.size(0), -1))).squeeze(1) # RelGAN, batch_size * k_label return logits diff --git a/models/generators/CatGAN_G.py b/models/generators/CatGAN_G.py index c4e2f7a6..77743e5f 100644 --- a/models/generators/CatGAN_G.py +++ b/models/generators/CatGAN_G.py @@ -17,35 +17,58 @@ class CatGAN_G(LSTMGenerator): - def __init__(self, k_label, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, - padding_idx, - gpu=False): - super(CatGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'catgan' + def __init__( + self, + k_label, + mem_slots, + num_heads, + head_size, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + gpu=False, + ): + super(CatGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "catgan" self.k_label = k_label - self.temperature = nn.Parameter(torch.Tensor([1.0]), requires_grad=False) # init value is 1.0 + self.temperature = nn.Parameter( + torch.Tensor([1.0]), requires_grad=False + ) # init value is 1.0 # Category matrix # self.cat_mat = nn.Parameter(torch.rand(self.k_label, embedding_dim), requires_grad=True) self.cat_mat = nn.Parameter(torch.eye(k_label), requires_grad=False) - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) - if cfg.model_type == 'LSTM': + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) + if cfg.model_type == "LSTM": # LSTM self.hidden_dim = hidden_dim - self.lstm = nn.LSTM(k_label + embedding_dim, self.hidden_dim, batch_first=True) + self.lstm = nn.LSTM( + k_label + embedding_dim, self.hidden_dim, batch_first=True + ) self.lstm2out = nn.Linear(self.hidden_dim, vocab_size) else: # RMC self.hidden_dim = mem_slots * num_heads * head_size - self.lstm = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=k_label + embedding_dim, - num_heads=num_heads, return_all_outputs=True) + self.lstm = RelationalMemory( + mem_slots=mem_slots, + head_size=head_size, + input_size=k_label + embedding_dim, + num_heads=num_heads, + return_all_outputs=True, + ) self.lstm2out = nn.Linear(self.hidden_dim, vocab_size) self.init_params() def init_hidden(self, batch_size=cfg.batch_size): - if cfg.model_type == 'LSTM': + if cfg.model_type == "LSTM": h = torch.zeros(1, batch_size, self.hidden_dim) c = torch.zeros(1, batch_size, self.hidden_dim) @@ -67,17 +90,25 @@ def forward(self, inp, hidden, label=None, need_hidden=False): :param hidden: memory size :param need_hidden: if return hidden, use for sampling """ - assert type(label) == torch.Tensor, 'missing label' + assert type(label) == torch.Tensor, "missing label" emb = self.embeddings(inp) # batch_size * len * embedding_dim # cat category vector label_onehot = F.one_hot(label, self.k_label).float() # batch_size * k_label - label_onehot_ex = label_onehot.unsqueeze(1).expand(-1, inp.size(1), -1) # batch_size * len * k_label - label_vec = torch.bmm(label_onehot_ex, self.cat_mat.expand(inp.size(0), -1, -1)) # batch_size * len * embed_dim - 
emb = torch.cat((emb, label_vec), dim=-1) # batch_sie * len * (k_label + embed_dim) + label_onehot_ex = label_onehot.unsqueeze(1).expand( + -1, inp.size(1), -1 + ) # batch_size * len * k_label + label_vec = torch.bmm( + label_onehot_ex, self.cat_mat.expand(inp.size(0), -1, -1) + ) # batch_size * len * embed_dim + emb = torch.cat( + (emb, label_vec), dim=-1 + ) # batch_sie * len * (k_label + embed_dim) out, hidden = self.lstm(emb, hidden) # out: batch_size * seq_len * hidden_dim - out = out.contiguous().view(-1, self.hidden_dim) # out: (batch_size * len) * hidden_dim + out = out.contiguous().view( + -1, self.hidden_dim + ) # out: (batch_size * len) * hidden_dim out = self.lstm2out(out) # batch_size * seq_len * vocab_size # out = self.temperature * out # temperature pred = self.softmax(out) @@ -98,14 +129,20 @@ def step(self, inp, hidden, label=None): - hidden: next hidden - next_token: [batch_size], next sentence token """ - assert type(label) == torch.Tensor, 'missing label' + assert type(label) == torch.Tensor, "missing label" emb = self.embeddings(inp).unsqueeze(1) # cat category vector label_onehot = F.one_hot(label, self.k_label).float() # batch_size * k_label - label_onehot_ex = label_onehot.unsqueeze(1).expand(-1, 1, -1) # batch_size * 1 * k_label - label_vec = torch.bmm(label_onehot_ex, self.cat_mat.expand(inp.size(0), -1, -1)) # batch_size * 1 * embed_dim - emb = torch.cat((emb, label_vec), dim=-1) # batch_sie * len * (k_label + embed_dim) + label_onehot_ex = label_onehot.unsqueeze(1).expand( + -1, 1, -1 + ) # batch_size * 1 * k_label + label_vec = torch.bmm( + label_onehot_ex, self.cat_mat.expand(inp.size(0), -1, -1) + ) # batch_size * 1 * embed_dim + emb = torch.cat( + (emb, label_vec), dim=-1 + ) # batch_sie * len * (k_label + embed_dim) out, hidden = self.lstm(emb, hidden) gumbel_t = self.add_gumbel(self.lstm2out(out.squeeze(1))) @@ -115,8 +152,14 @@ def step(self, inp, hidden, label=None): return pred, hidden, next_token - def sample(self, num_samples, batch_size, one_hot=False, label_i=None, - start_letter=cfg.start_letter): + def sample( + self, + num_samples, + batch_size, + one_hot=False, + label_i=None, + start_letter=cfg.start_letter, + ): """ Sample from RelGAN Generator - one_hot: if return pred of RelGAN, used for adversarial training @@ -126,7 +169,7 @@ def sample(self, num_samples, batch_size, one_hot=False, label_i=None, - samples: all samples """ global all_preds - assert type(label_i) == int, 'missing label' + assert type(label_i) == int, "missing label" num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1 samples = torch.zeros(num_batch * batch_size, self.max_seq_len).long() if one_hot: @@ -144,7 +187,7 @@ def sample(self, num_samples, batch_size, one_hot=False, label_i=None, for i in range(self.max_seq_len): pred, hidden, next_token = self.step(inp, hidden, label_t) - samples[b * batch_size:(b + 1) * batch_size, i] = next_token + samples[b * batch_size : (b + 1) * batch_size, i] = next_token if one_hot: all_preds[:, i] = pred inp = next_token diff --git a/models/generators/CoT_G.py b/models/generators/CoT_G.py index 50999e8f..ab294c3c 100644 --- a/models/generators/CoT_G.py +++ b/models/generators/CoT_G.py @@ -14,9 +14,13 @@ class CoT_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(CoT_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'cot' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, 
max_seq_len, padding_idx, gpu=False + ): + super(CoT_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "cot" def get_loss(self, input, rewards): """ @@ -25,7 +29,9 @@ def get_loss(self, input, rewards): @param rewards: rewards form mediator, (batch size * seq_len) * vocab_size @return: """ - log_pred = self.forward(input, self.init_hidden(input.size(0))) # (batch_size * seq_len) * vocab_size + log_pred = self.forward( + input, self.init_hidden(input.size(0)) + ) # (batch_size * seq_len) * vocab_size g_pred = torch.exp(log_pred) loss = -torch.sum(g_pred * (rewards - log_pred)) / rewards.size(0) return loss diff --git a/models/generators/DGSAN_G.py b/models/generators/DGSAN_G.py index 5deb53e8..04491282 100644 --- a/models/generators/DGSAN_G.py +++ b/models/generators/DGSAN_G.py @@ -11,6 +11,10 @@ class DGSAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(DGSAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'dgsan' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(DGSAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "dgsan" diff --git a/models/generators/DPGAN_G.py b/models/generators/DPGAN_G.py index f9d415b7..c32d9dca 100644 --- a/models/generators/DPGAN_G.py +++ b/models/generators/DPGAN_G.py @@ -14,9 +14,13 @@ class DPGAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(DPGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'dpgan_g' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(DPGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "dpgan_g" def sample_teacher_forcing(self, inp): """ @@ -32,7 +36,9 @@ def sample_teacher_forcing(self, inp): pred = self.forward(inp, hidden) samples = torch.argmax(pred, dim=-1).view(batch_size, -1) - log_prob = F.nll_loss(pred, samples.view(-1), reduction='none').view(batch_size, -1) + log_prob = F.nll_loss(pred, samples.view(-1), reduction="none").view( + batch_size, -1 + ) # samples = torch.multinomial(torch.exp(log_prob), 1) return samples, log_prob diff --git a/models/generators/EvoGAN_G.py b/models/generators/EvoGAN_G.py index 8bb2a358..b195a504 100644 --- a/models/generators/EvoGAN_G.py +++ b/models/generators/EvoGAN_G.py @@ -17,15 +17,31 @@ class EvoGAN_G(LSTMGenerator): - def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, - gpu=False): - super(EvoGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'evogan' - - self.temperature = nn.Parameter(torch.Tensor([1.0]), requires_grad=False) # init value is 1.0 - - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) - if cfg.model_type == 'LSTM': + def __init__( + self, + mem_slots, + num_heads, + head_size, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + gpu=False, + ): + super(EvoGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "evogan" + + self.temperature = nn.Parameter( + torch.Tensor([1.0]), 
requires_grad=False + ) # init value is 1.0 + + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) + if cfg.model_type == "LSTM": # LSTM self.hidden_dim = hidden_dim self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, batch_first=True) @@ -33,14 +49,19 @@ def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, v else: # RMC self.hidden_dim = mem_slots * num_heads * head_size - self.lstm = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=embedding_dim, - num_heads=num_heads, return_all_outputs=True) + self.lstm = RelationalMemory( + mem_slots=mem_slots, + head_size=head_size, + input_size=embedding_dim, + num_heads=num_heads, + return_all_outputs=True, + ) self.lstm2out = nn.Linear(self.hidden_dim, vocab_size) self.init_params() def init_hidden(self, batch_size=cfg.batch_size): - if cfg.model_type == 'LSTM': + if cfg.model_type == "LSTM": h = torch.zeros(1, batch_size, self.hidden_dim) c = torch.zeros(1, batch_size, self.hidden_dim) @@ -79,7 +100,9 @@ def step(self, inp, hidden): return pred, hidden, next_token, next_token_onehot, next_o - def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter): + def sample( + self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter + ): """ Sample from RelGAN Generator - one_hot: if return pred of RelGAN, used for adversarial training @@ -103,7 +126,7 @@ def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_ for i in range(self.max_seq_len): pred, hidden, next_token, _, _ = self.step(inp, hidden) - samples[b * batch_size:(b + 1) * batch_size, i] = next_token + samples[b * batch_size : (b + 1) * batch_size, i] = next_token if one_hot: all_preds[:, i] = pred inp = next_token diff --git a/models/generators/FixemGAN_G.py b/models/generators/FixemGAN_G.py index c83dc7c1..2c72f754 100644 --- a/models/generators/FixemGAN_G.py +++ b/models/generators/FixemGAN_G.py @@ -18,10 +18,15 @@ from models.generators.generator import LSTMGenerator - class Generator(LSTMGenerator): def __init__(self, complexity, noise_size, w2v, w2v_embedding_size): - super(Generator, self).__init__(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, cfg.target_len, cfg.padding_idx) + super(Generator, self).__init__( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.target_len, + cfg.padding_idx, + ) alpha = 0.2 added_dim_pe = 0 include_batch_norm = True @@ -34,7 +39,9 @@ def __init__(self, complexity, noise_size, w2v, w2v_embedding_size): self.main = nn.Sequential( # 1 layer Concatenate(1), - nn.Linear(cfg.noise_size + cfg.k_label, cfg.target_len // 2 // 2 * complexity), + nn.Linear( + cfg.noise_size + cfg.k_label, cfg.target_len // 2 // 2 * complexity + ), nn.BatchNorm1d(cfg.target_len // 2 // 2 * complexity), nn.LeakyReLU(alpha), Reshape(complexity, cfg.target_len // 2 // 2), @@ -94,9 +101,10 @@ def __init__(self, complexity, noise_size, w2v, w2v_embedding_size): # 8 layer MyLSTMLayerNorm( complexity, - complexity//2, - ) if include_lstm else Dummy(), - + complexity // 2, + ) + if include_lstm + else Dummy(), # 9 layer MyConvTransposeLayer( complexity, @@ -107,8 +115,10 @@ def __init__(self, complexity, noise_size, w2v, w2v_embedding_size): # 10 layer MyLSTMLayerNorm( complexity, - complexity//2, - ) if include_lstm else Dummy(), + complexity // 2, + ) + if include_lstm + else Dummy(), # 11 layer MyConvTransposeLayer( complexity, @@ -128,12 +138,16 @@ def __init__(self, complexity, noise_size, w2v, 
w2v_embedding_size): self.optimizer = get_optimizer(self.parameters()) def forward(self, noise, target_labels): - target_labels = torch.nn.functional.one_hot(target_labels, num_classes=cfg.k_label) + target_labels = torch.nn.functional.one_hot( + target_labels, num_classes=cfg.k_label + ) return self.main([noise, target_labels]) - def sample(self, num_samples, batch_size, label_i = 'random', start_letter=cfg.start_letter): + def sample( + self, num_samples, batch_size, label_i="random", start_letter=cfg.start_letter + ): noise = create_noise(num_samples, self.noise_size, cfg.k_label) - if label_i != 'random': + if label_i != "random": noise = (noise[0], torch.tensor(label_i).expand_as(noise[1])) if cfg.CUDA: diff --git a/models/generators/JSDGAN_G.py b/models/generators/JSDGAN_G.py index 32c6432c..a84d58c7 100644 --- a/models/generators/JSDGAN_G.py +++ b/models/generators/JSDGAN_G.py @@ -15,10 +15,22 @@ class JSDGAN_G(LSTMGenerator): - def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, - gpu=False): - super(JSDGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'jsdgan' + def __init__( + self, + mem_slots, + num_heads, + head_size, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + gpu=False, + ): + super(JSDGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "jsdgan" # RMC @@ -43,8 +55,12 @@ def JSD_loss(self, inp, target): """ batch_size, seq_len = inp.size() hidden = self.init_hidden(batch_size) - pred = self.forward(inp, hidden).view(batch_size, self.max_seq_len, self.vocab_size) - target_onehot = F.one_hot(target, self.vocab_size).float() # batch_size * seq_len * vocab_size + pred = self.forward(inp, hidden).view( + batch_size, self.max_seq_len, self.vocab_size + ) + target_onehot = F.one_hot( + target, self.vocab_size + ).float() # batch_size * seq_len * vocab_size pred = torch.sum(pred * target_onehot, dim=-1) # batch_size * seq_len # calculate probabilities of sentences @@ -55,19 +71,26 @@ def JSD_loss(self, inp, target): prob_data = prob_data.cuda() # calculate the reward - reward = torch.log(1. - torch.div(prob_data, prob_data + prob_gen)) # batch_size + reward = torch.log( + 1.0 - torch.div(prob_data, prob_data + prob_gen) + ) # batch_size # check if nan if torch.isnan(reward).sum() > 0: - print('Reward is nan!!!') + print("Reward is nan!!!") exit(1) - loss = torch.sum((prob_gen * reward).detach() * torch.sum(pred.double(), dim=-1)) + loss = torch.sum( + (prob_gen * reward).detach() * torch.sum(pred.double(), dim=-1) + ) return loss def min_max_normal(self, prob): - return torch.div(prob - torch.min(prob), torch.clamp(torch.max(prob) - torch.min(prob), min=1e-78)) + return torch.div( + prob - torch.min(prob), + torch.clamp(torch.max(prob) - torch.min(prob), min=1e-78), + ) def sigmoid_normal(self, prob): """push prob either close to 0 or 1""" diff --git a/models/generators/LeakGAN_G.py b/models/generators/LeakGAN_G.py index bdde5bb2..53e69969 100644 --- a/models/generators/LeakGAN_G.py +++ b/models/generators/LeakGAN_G.py @@ -4,7 +4,7 @@ # @FileName : LeakGAN_G.py # @Time : Created at 2019-04-25 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
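The min_max_normal helper reformatted in the JSDGAN_G hunk above rescales the per-sentence probabilities into [0, 1], clamping the denominator so it never divides by zero. A minimal standalone sketch of the same computation (the example values are made up):

import torch

def min_max_normal(prob):
    # rescale to [0, 1]; the clamp keeps the denominator strictly positive
    denom = torch.clamp(torch.max(prob) - torch.min(prob), min=1e-78)
    return (prob - torch.min(prob)) / denom

print(min_max_normal(torch.tensor([0.2, 0.5, 0.9])))  # tensor([0.0000, 0.4286, 1.0000])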
import math import time @@ -21,10 +21,19 @@ class LeakGAN_G(nn.Module): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, goal_size, - step_size, gpu=False): + def __init__( + self, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + goal_size, + step_size, + gpu=False, + ): super(LeakGAN_G, self).__init__() - self.name = 'leakgan' + self.name = "leakgan" self.hidden_dim = hidden_dim self.embedding_dim = embedding_dim @@ -37,7 +46,9 @@ def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_i self.gpu = gpu self.temperature = 1.5 - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) self.worker = nn.LSTM(embedding_dim, hidden_dim) self.manager = nn.LSTM(goal_out_size, hidden_dim) @@ -49,7 +60,17 @@ def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_i self.init_params() - def forward(self, idx, inp, work_hidden, mana_hidden, feature, real_goal, no_log=False, train=False): + def forward( + self, + idx, + inp, + work_hidden, + mana_hidden, + feature, + real_goal, + no_log=False, + train=False, + ): """ Embeds input and sample on token at a time (seq_len = 1) @@ -69,16 +90,25 @@ def forward(self, idx, inp, work_hidden, mana_hidden, feature, real_goal, no_log emb = self.embeddings(inp).unsqueeze(0) # 1 * batch_size * embed_dim # Manager - mana_out, mana_hidden = self.manager(feature, mana_hidden) # mana_out: 1 * batch_size * hidden_dim - mana_out = self.mana2goal(mana_out.permute([1, 0, 2])) # batch_size * 1 * goal_out_size + mana_out, mana_hidden = self.manager( + feature, mana_hidden + ) # mana_out: 1 * batch_size * hidden_dim + mana_out = self.mana2goal( + mana_out.permute([1, 0, 2]) + ) # batch_size * 1 * goal_out_size cur_goal = F.normalize(mana_out, dim=-1) _real_goal = self.goal2goal(real_goal) # batch_size * goal_size - _real_goal = F.normalize(_real_goal, p=2, dim=-1).unsqueeze(-1) # batch_size * goal_size * 1 + _real_goal = F.normalize(_real_goal, p=2, dim=-1).unsqueeze( + -1 + ) # batch_size * goal_size * 1 # Worker - work_out, work_hidden = self.worker(emb, work_hidden) # work_out: 1 * batch_size * hidden_dim - work_out = self.work2goal(work_out).view(-1, self.vocab_size, - self.goal_size) # batch_size * vocab_size * goal_size + work_out, work_hidden = self.worker( + emb, work_hidden + ) # work_out: 1 * batch_size * hidden_dim + work_out = self.work2goal(work_out).view( + -1, self.vocab_size, self.goal_size + ) # batch_size * vocab_size * goal_size # Sample token out = torch.matmul(work_out, _real_goal).squeeze(-1) # batch_size * vocab_size @@ -101,21 +131,31 @@ def forward(self, idx, inp, work_hidden, mana_hidden, feature, real_goal, no_log return out, cur_goal, work_hidden, mana_hidden - def sample(self, num_samples, batch_size, dis, start_letter=cfg.start_letter, train=False): + def sample( + self, num_samples, batch_size, dis, start_letter=cfg.start_letter, train=False + ): """ Samples the network and returns num_samples samples of length max_seq_len. 
:return: samples: batch_size * max_seq_len """ num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1 - samples = torch.zeros(num_batch * batch_size, self.max_seq_len).long() # larger than num_samples + samples = torch.zeros( + num_batch * batch_size, self.max_seq_len + ).long() # larger than num_samples fake_sentences = torch.zeros((batch_size, self.max_seq_len)) for b in range(num_batch): - leak_sample, _, _, _ = self.forward_leakgan(fake_sentences, dis, if_sample=True, no_log=False - , start_letter=start_letter, train=False) + leak_sample, _, _, _ = self.forward_leakgan( + fake_sentences, + dis, + if_sample=True, + no_log=False, + start_letter=start_letter, + train=False, + ) assert leak_sample.shape == (batch_size, self.max_seq_len) - samples[b * batch_size:(b + 1) * batch_size, :] = leak_sample + samples[b * batch_size : (b + 1) * batch_size, :] = leak_sample samples = samples[:num_samples, :] @@ -130,16 +170,22 @@ def pretrain_loss(self, target, dis, start_letter=cfg.start_letter): """ batch_size, seq_len = target.size() - _, feature_array, goal_array, leak_out_array = self.forward_leakgan(target, dis, if_sample=False, no_log=False, - start_letter=start_letter) + _, feature_array, goal_array, leak_out_array = self.forward_leakgan( + target, dis, if_sample=False, no_log=False, start_letter=start_letter + ) # Manager loss - mana_cos_loss = self.manager_cos_loss(batch_size, feature_array, - goal_array) # batch_size * (seq_len / step_size) - manager_loss = -torch.sum(mana_cos_loss) / (batch_size * (seq_len // self.step_size)) + mana_cos_loss = self.manager_cos_loss( + batch_size, feature_array, goal_array + ) # batch_size * (seq_len / step_size) + manager_loss = -torch.sum(mana_cos_loss) / ( + batch_size * (seq_len // self.step_size) + ) # Worker loss - work_nll_loss = self.worker_nll_loss(target, leak_out_array) # batch_size * seq_len + work_nll_loss = self.worker_nll_loss( + target, leak_out_array + ) # batch_size * seq_len work_loss = torch.sum(work_nll_loss) / (batch_size * seq_len) return manager_loss, work_loss @@ -154,18 +200,31 @@ def adversarial_loss(self, target, rewards, dis, start_letter=cfg.start_letter): - rewards: batch_size * seq_len (discriminator rewards for each token) """ batch_size, seq_len = target.size() - _, feature_array, goal_array, leak_out_array = self.forward_leakgan(target, dis, if_sample=False, no_log=False, - start_letter=start_letter, train=True) + _, feature_array, goal_array, leak_out_array = self.forward_leakgan( + target, + dis, + if_sample=False, + no_log=False, + start_letter=start_letter, + train=True, + ) # Manager Loss t0 = time.time() - mana_cos_loss = self.manager_cos_loss(batch_size, feature_array, - goal_array) # batch_size * (seq_len / step_size) - mana_loss = -torch.sum(rewards * mana_cos_loss) / (batch_size * (seq_len // self.step_size)) + mana_cos_loss = self.manager_cos_loss( + batch_size, feature_array, goal_array + ) # batch_size * (seq_len / step_size) + mana_loss = -torch.sum(rewards * mana_cos_loss) / ( + batch_size * (seq_len // self.step_size) + ) # Worker Loss - work_nll_loss = self.worker_nll_loss(target, leak_out_array) # batch_size * seq_len - work_cos_reward = self.worker_cos_reward(feature_array, goal_array) # batch_size * seq_len + work_nll_loss = self.worker_nll_loss( + target, leak_out_array + ) # batch_size * seq_len + work_cos_reward = self.worker_cos_reward( + feature_array, goal_array + ) # batch_size * seq_len work_loss = -torch.sum(work_nll_loss * work_cos_reward) / (batch_size * seq_len) 
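The manager loss reformatted above is the negative cosine similarity between L2-normalized feature deltas and goal vectors, summed and divided by the number of contributing sub-goals. A minimal sketch of that reduction under hypothetical shapes (batch size 4, seq_len // step_size = 5, goal_out_size 16 are placeholders, not values from the patch):

import torch
import torch.nn.functional as F

batch_size, n_subgoals, goal_out_size = 4, 5, 16  # hypothetical sizes
sub_feature = F.normalize(torch.randn(batch_size, n_subgoals, goal_out_size), p=2, dim=-1)
real_goal = F.normalize(torch.randn(batch_size, n_subgoals, goal_out_size), p=2, dim=-1)

mana_cos_loss = F.cosine_similarity(sub_feature, real_goal, dim=-1)  # batch_size * n_subgoals
manager_loss = -torch.sum(mana_cos_loss) / (batch_size * n_subgoals)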
return mana_loss, work_loss @@ -194,17 +253,23 @@ def manager_cos_loss(self, batch_size, feature_array, goal_array): # ===LeakGAN origin=== # get sub_feature and real_goal # batch_size, seq_len = sentences.size() - sub_feature = torch.zeros(batch_size, self.max_seq_len // self.step_size, self.goal_out_size) - real_goal = torch.zeros(batch_size, self.max_seq_len // self.step_size, self.goal_out_size) + sub_feature = torch.zeros( + batch_size, self.max_seq_len // self.step_size, self.goal_out_size + ) + real_goal = torch.zeros( + batch_size, self.max_seq_len // self.step_size, self.goal_out_size + ) for i in range(self.max_seq_len // self.step_size): idx = i * self.step_size - sub_feature[:, i, :] = feature_array[:, idx + self.step_size, :] - feature_array[:, idx, :] + sub_feature[:, i, :] = ( + feature_array[:, idx + self.step_size, :] - feature_array[:, idx, :] + ) if i == 0: real_goal[:, i, :] = self.goal_init[:batch_size, :] else: idx = (i - 1) * self.step_size + 1 - real_goal[:, i, :] = torch.sum(goal_array[:, idx:idx + 4, :], dim=1) + real_goal[:, i, :] = torch.sum(goal_array[:, idx : idx + 4, :], dim=1) # L2 noramlization sub_feature = F.normalize(sub_feature, p=2, dim=-1) @@ -220,7 +285,7 @@ def worker_nll_loss(self, target, leak_out_array): :return loss: batch_size * seq_len """ - loss_fn = nn.NLLLoss(reduction='none') + loss_fn = nn.NLLLoss(reduction="none") loss = loss_fn(leak_out_array.permute([0, 2, 1]), target) return loss @@ -232,26 +297,52 @@ def worker_cos_reward(self, feature_array, goal_array): :return: cos_loss: batch_size * seq_len """ for i in range(int(self.max_seq_len / self.step_size)): - real_feature = feature_array[:, i * self.step_size, :].unsqueeze(1).expand((-1, self.step_size, -1)) - feature_array[:, i * self.step_size:(i + 1) * self.step_size, :] = real_feature + real_feature = ( + feature_array[:, i * self.step_size, :] + .unsqueeze(1) + .expand((-1, self.step_size, -1)) + ) + feature_array[ + :, i * self.step_size : (i + 1) * self.step_size, : + ] = real_feature if i > 0: - sum_goal = torch.sum(goal_array[:, (i - 1) * self.step_size:i * self.step_size, :], dim=1, keepdim=True) + sum_goal = torch.sum( + goal_array[:, (i - 1) * self.step_size : i * self.step_size, :], + dim=1, + keepdim=True, + ) else: sum_goal = goal_array[:, 0, :].unsqueeze(1) - goal_array[:, i * self.step_size:(i + 1) * self.step_size, :] = sum_goal.expand((-1, self.step_size, -1)) - - offset_feature = feature_array[:, 1:, :] # f_{t+1}, batch_size * seq_len * goal_out_size - goal_array = goal_array[:, :self.max_seq_len, :] # batch_size * seq_len * goal_out_size + goal_array[ + :, i * self.step_size : (i + 1) * self.step_size, : + ] = sum_goal.expand((-1, self.step_size, -1)) + + offset_feature = feature_array[ + :, 1:, : + ] # f_{t+1}, batch_size * seq_len * goal_out_size + goal_array = goal_array[ + :, : self.max_seq_len, : + ] # batch_size * seq_len * goal_out_size sub_feature = offset_feature - goal_array # L2 normalization sub_feature = F.normalize(sub_feature, p=2, dim=-1) all_goal = F.normalize(goal_array, p=2, dim=-1) - cos_loss = F.cosine_similarity(sub_feature, all_goal, dim=-1) # batch_size * seq_len + cos_loss = F.cosine_similarity( + sub_feature, all_goal, dim=-1 + ) # batch_size * seq_len return cos_loss - def forward_leakgan(self, sentences, dis, if_sample, no_log=False, start_letter=cfg.start_letter, train=False): + def forward_leakgan( + self, + sentences, + dis, + if_sample, + no_log=False, + start_letter=cfg.start_letter, + train=False, + ): """ Get all feature and goals 
according to given sentences :param sentences: batch_size * max_seq_len, not include start token @@ -298,14 +389,24 @@ def forward_leakgan(self, sentences, dis, if_sample, no_log=False, start_letter= if self.gpu: dis_inp = dis_inp.cuda() leak_inp = leak_inp.cuda() - feature = dis.get_feature(dis_inp).unsqueeze(0) # !!!note: 1 * batch_size * total_num_filters + feature = dis.get_feature(dis_inp).unsqueeze( + 0 + ) # !!!note: 1 * batch_size * total_num_filters feature_array[:, i, :] = feature.squeeze(0) # Get output of one token # cur_goal: batch_size * 1 * goal_out_size - out, cur_goal, work_hidden, mana_hidden = self.forward(i, leak_inp, work_hidden, mana_hidden, feature, - real_goal, no_log=no_log, train=train) + out, cur_goal, work_hidden, mana_hidden = self.forward( + i, + leak_inp, + work_hidden, + mana_hidden, + feature, + real_goal, + no_log=no_log, + train=train, + ) leak_out_array[:, i, :] = out # ===My implement according to paper=== @@ -320,14 +421,16 @@ def forward_leakgan(self, sentences, dis, if_sample, no_log=False, start_letter= # Save goal and update real_goal goal_array[:, i, :] = cur_goal.squeeze(1) if i > 0 and i % self.step_size == 0: - real_goal = torch.sum(goal_array[:, i - 3:i + 1, :], dim=1) + real_goal = torch.sum(goal_array[:, i - 3 : i + 1, :], dim=1) if i / self.step_size == 1: real_goal += self.goal_init[:batch_size, :] # Sample one token if not no_log: out = torch.exp(out) - out = torch.multinomial(out, 1).view(-1) # [batch_size] (sampling from each row) + out = torch.multinomial(out, 1).view( + -1 + ) # [batch_size] (sampling from each row) samples[:, i] = out.data leak_inp = out @@ -339,8 +442,9 @@ def forward_leakgan(self, sentences, dis, if_sample, no_log=False, start_letter= def batchNLLLoss(self, target, dis, start_letter=cfg.start_letter): # loss_fn = nn.NLLLoss() # batch_size, seq_len = target.size() - _, _, _, leak_out_array = self.forward_leakgan(target, dis, if_sample=False, no_log=False, - start_letter=start_letter) + _, _, _, leak_out_array = self.forward_leakgan( + target, dis, if_sample=False, no_log=False, start_letter=start_letter + ) nll_loss = torch.mean(self.worker_nll_loss(target, leak_out_array)) @@ -383,9 +487,9 @@ def init_params(self): for param in self.parameters(): if param.requires_grad and len(param.shape) > 0: stddev = 1 / math.sqrt(param.shape[0]) - if cfg.gen_init == 'uniform': + if cfg.gen_init == "uniform": torch.nn.init.uniform_(param, a=-0.05, b=0.05) - elif cfg.gen_init == 'normal': + elif cfg.gen_init == "normal": torch.nn.init.normal_(param, std=stddev) - elif cfg.gen_init == 'truncated_normal': + elif cfg.gen_init == "truncated_normal": truncated_normal_(param, std=stddev) diff --git a/models/generators/MaliGAN_G.py b/models/generators/MaliGAN_G.py index 35b64591..01b9a987 100644 --- a/models/generators/MaliGAN_G.py +++ b/models/generators/MaliGAN_G.py @@ -14,9 +14,13 @@ class MaliGAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(MaliGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'maligan' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(MaliGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "maligan" def adv_loss(self, inp, target, reward): """ @@ -31,8 +35,12 @@ def adv_loss(self, inp, target, reward): batch_size, seq_len = inp.size() hidden = 
self.init_hidden(batch_size) - out = self.forward(inp, hidden).view(batch_size, self.max_seq_len, self.vocab_size) - target_onehot = F.one_hot(target, self.vocab_size).float() # batch_size * seq_len * vocab_size + out = self.forward(inp, hidden).view( + batch_size, self.max_seq_len, self.vocab_size + ) + target_onehot = F.one_hot( + target, self.vocab_size + ).float() # batch_size * seq_len * vocab_size pred = torch.sum(out * target_onehot, dim=-1) # batch_size * seq_len loss = -torch.sum(pred * reward) diff --git a/models/generators/RelGAN_G.py b/models/generators/RelGAN_G.py index 1044b988..a856a453 100644 --- a/models/generators/RelGAN_G.py +++ b/models/generators/RelGAN_G.py @@ -16,15 +16,29 @@ class RelGAN_G(LSTMGenerator): - def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, - gpu=False): - super(RelGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'relgan' + def __init__( + self, + mem_slots, + num_heads, + head_size, + embedding_dim, + hidden_dim, + vocab_size, + max_seq_len, + padding_idx, + gpu=False, + ): + super(RelGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "relgan" self.temperature = 1.0 # init value is 1.0 - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) - if cfg.model_type == 'LSTM': + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) + if cfg.model_type == "LSTM": # LSTM self.hidden_dim = hidden_dim self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, batch_first=True) @@ -32,15 +46,20 @@ def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, v else: # RMC self.hidden_dim = mem_slots * num_heads * head_size - self.lstm = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=embedding_dim, - num_heads=num_heads, return_all_outputs=True) + self.lstm = RelationalMemory( + mem_slots=mem_slots, + head_size=head_size, + input_size=embedding_dim, + num_heads=num_heads, + return_all_outputs=True, + ) self.lstm2out = nn.Linear(self.hidden_dim, vocab_size) self.init_params() pass def init_hidden(self, batch_size=cfg.batch_size): - if cfg.model_type == 'LSTM': + if cfg.model_type == "LSTM": h = torch.zeros(1, batch_size, self.hidden_dim) c = torch.zeros(1, batch_size, self.hidden_dim) @@ -79,7 +98,9 @@ def step(self, inp, hidden): return pred, hidden, next_token, next_token_onehot, next_o - def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter): + def sample( + self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter + ): """ Sample from RelGAN Generator - one_hot: if return pred of RelGAN, used for adversarial training @@ -103,7 +124,7 @@ def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_ for i in range(self.max_seq_len): pred, hidden, next_token, _, _ = self.step(inp, hidden) - samples[b * batch_size:(b + 1) * batch_size, i] = next_token + samples[b * batch_size : (b + 1) * batch_size, i] = next_token if one_hot: all_preds[:, i] = pred inp = next_token diff --git a/models/generators/SentiGAN_G.py b/models/generators/SentiGAN_G.py index a7c8f360..145c4750 100644 --- a/models/generators/SentiGAN_G.py +++ b/models/generators/SentiGAN_G.py @@ -15,9 +15,13 @@ class SentiGAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(SentiGAN_G, 
self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'sentigan' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(SentiGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "sentigan" def forward(self, inp, hidden, need_hidden=False, use_log=True): """ @@ -31,7 +35,9 @@ def forward(self, inp, hidden, need_hidden=False, use_log=True): emb = emb.unsqueeze(1) # batch_size * 1 * embedding_dim out, hidden = self.lstm(emb, hidden) # out: batch_size * seq_len * hidden_dim - out = out.contiguous().view(-1, self.hidden_dim) # out: (batch_size * len) * hidden_dim + out = out.contiguous().view( + -1, self.hidden_dim + ) # out: (batch_size * len) * hidden_dim out = self.lstm2out(out) # batch_size * seq_len * vocab_size # out = self.temperature * out # temperature if use_log: @@ -57,8 +63,12 @@ def batchPGLoss(self, inp, target, reward): batch_size, seq_len = inp.size() hidden = self.init_hidden(batch_size) - out = self.forward(inp, hidden, use_log=False).view(batch_size, self.max_seq_len, self.vocab_size) - target_onehot = F.one_hot(target, self.vocab_size).float() # batch_size * seq_len * vocab_size + out = self.forward(inp, hidden, use_log=False).view( + batch_size, self.max_seq_len, self.vocab_size + ) + target_onehot = F.one_hot( + target, self.vocab_size + ).float() # batch_size * seq_len * vocab_size pred = torch.sum(out * target_onehot, dim=-1) # batch_size * seq_len loss = -torch.sum(pred * (1 - reward)) diff --git a/models/generators/SeqGAN_G.py b/models/generators/SeqGAN_G.py index 0414d816..9859505b 100644 --- a/models/generators/SeqGAN_G.py +++ b/models/generators/SeqGAN_G.py @@ -14,9 +14,13 @@ class SeqGAN_G(LSTMGenerator): - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): - super(SeqGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu) - self.name = 'seqgan' + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): + super(SeqGAN_G, self).__init__( + embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu + ) + self.name = "seqgan" def batchPGLoss(self, inp, target, reward): """ @@ -31,8 +35,12 @@ def batchPGLoss(self, inp, target, reward): batch_size, seq_len = inp.size() hidden = self.init_hidden(batch_size) - out = self.forward(inp, hidden).view(batch_size, self.max_seq_len, self.vocab_size) - target_onehot = F.one_hot(target, self.vocab_size).float() # batch_size * seq_len * vocab_size + out = self.forward(inp, hidden).view( + batch_size, self.max_seq_len, self.vocab_size + ) + target_onehot = F.one_hot( + target, self.vocab_size + ).float() # batch_size * seq_len * vocab_size pred = torch.sum(out * target_onehot, dim=-1) # batch_size * seq_len loss = -torch.sum(pred * reward) diff --git a/models/generators/generator.py b/models/generators/generator.py index 8db32e26..70401242 100644 --- a/models/generators/generator.py +++ b/models/generators/generator.py @@ -16,10 +16,11 @@ class LSTMGenerator(nn.Module): - - def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False): + def __init__( + self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False + ): super(LSTMGenerator, self).__init__() - self.name = 'vanilla' + self.name = "vanilla" self.hidden_dim = hidden_dim self.embedding_dim = embedding_dim @@ -30,7 +31,9 @@ def 
__init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_i self.temperature = 1.0 - self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) + self.embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=padding_idx + ) self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True) self.lstm2out = nn.Linear(hidden_dim, vocab_size) self.softmax = nn.LogSoftmax(dim=-1) @@ -49,7 +52,9 @@ def forward(self, inp, hidden, need_hidden=False): emb = emb.unsqueeze(1) # batch_size * 1 * embedding_dim out, hidden = self.lstm(emb, hidden) # out: batch_size * seq_len * hidden_dim - out = out.contiguous().view(-1, self.hidden_dim) # out: (batch_size * len) * hidden_dim + out = out.contiguous().view( + -1, self.hidden_dim + ) # out: (batch_size * len) * hidden_dim out = self.lstm2out(out) # (batch_size * seq_len) * vocab_size # out = self.temperature * out # temperature pred = self.softmax(out) @@ -75,9 +80,13 @@ def sample(self, num_samples, batch_size, start_letter=cfg.start_letter): inp = inp.cuda() for i in range(self.max_seq_len): - out, hidden = self.forward(inp, hidden, need_hidden=True) # out: batch_size * vocab_size - next_token = torch.multinomial(torch.exp(out), 1, replacement=True) # batch_size * 1 (sampling from each row) - samples[b * batch_size:(b + 1) * batch_size, i] = next_token.view(-1) + out, hidden = self.forward( + inp, hidden, need_hidden=True + ) # out: batch_size * vocab_size + next_token = torch.multinomial( + torch.exp(out), 1, replacement=True + ) # batch_size * 1 (sampling from each row) + samples[b * batch_size : (b + 1) * batch_size, i] = next_token.view(-1) inp = next_token.view(-1) samples = samples[:num_samples] @@ -87,11 +96,11 @@ def init_params(self): for param in self.parameters(): if param.requires_grad and len(param.shape) > 0: stddev = 1 / math.sqrt(param.shape[0]) - if cfg.gen_init == 'uniform': + if cfg.gen_init == "uniform": torch.nn.init.uniform_(param, a=-0.05, b=0.05) - elif cfg.gen_init == 'normal': + elif cfg.gen_init == "normal": torch.nn.init.normal_(param, std=stddev) - elif cfg.gen_init == 'truncated_normal': + elif cfg.gen_init == "truncated_normal": truncated_normal_(param, std=stddev) def init_hidden(self, batch_size=cfg.batch_size): diff --git a/models/relational_rnn_general.py b/models/relational_rnn_general.py index b4b02b21..c12fbfe5 100644 --- a/models/relational_rnn_general.py +++ b/models/relational_rnn_general.py @@ -38,8 +38,20 @@ class RelationalMemory(nn.Module): ValueError: attention_mlp_layers is < 1. """ - def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, forget_bias=1., input_bias=0., - gate_style='unit', attention_mlp_layers=2, key_size=None, return_all_outputs=False): + def __init__( + self, + mem_slots, + head_size, + input_size, + num_heads=1, + num_blocks=1, + forget_bias=1.0, + input_bias=0.0, + gate_style="unit", + attention_mlp_layers=2, + key_size=None, + return_all_outputs=False, + ): super(RelationalMemory, self).__init__() ########## generic parameters for RMC ########## @@ -54,18 +66,22 @@ def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, self.mem_slots_plus_input = self.mem_slots + 1 if num_blocks < 1: - raise ValueError('num_blocks must be >=1. Got: {}.'.format(num_blocks)) + raise ValueError("num_blocks must be >=1. 
Got: {}.".format(num_blocks)) self.num_blocks = num_blocks - if gate_style not in ['unit', 'memory', None]: + if gate_style not in ["unit", "memory", None]: raise ValueError( - 'gate_style must be one of [\'unit\', \'memory\', None]. got: ' - '{}.'.format(gate_style)) + "gate_style must be one of ['unit', 'memory', None]. got: " + "{}.".format(gate_style) + ) self.gate_style = gate_style if attention_mlp_layers < 1: - raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format( - attention_mlp_layers)) + raise ValueError( + "attention_mlp_layers must be >= 1. Got: {}.".format( + attention_mlp_layers + ) + ) self.attention_mlp_layers = attention_mlp_layers self.key_size = key_size if key_size else self.head_size @@ -81,12 +97,20 @@ def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, # just using one big param is more efficient, rather than this line # self.qkv_projector = [nn.Parameter(torch.randn((self.qkv_size, self.qkv_size))) for _ in range(self.num_heads)] self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size) - self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size]) + self.qkv_layernorm = nn.LayerNorm( + [self.mem_slots_plus_input, self.total_qkv_size] + ) # used for attend_over_memory function - self.attention_mlp = nn.ModuleList([nn.Linear(self.mem_size, self.mem_size)] * self.attention_mlp_layers) - self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size]) - self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size]) + self.attention_mlp = nn.ModuleList( + [nn.Linear(self.mem_size, self.mem_size)] * self.attention_mlp_layers + ) + self.attended_memory_layernorm = nn.LayerNorm( + [self.mem_slots_plus_input, self.mem_size] + ) + self.attended_memory_layernorm2 = nn.LayerNorm( + [self.mem_slots_plus_input, self.mem_size] + ) ########## parameters for initial embedded input projection ########## self.input_size = input_size @@ -135,7 +159,7 @@ def initial_state(self, batch_size, trainable=False): # truncation. take the first 'self.mem_size' components elif self.mem_size < self.mem_slots: - init_state = init_state[:, :, :self.mem_size] + init_state = init_state[:, :, : self.mem_size] return init_state @@ -168,10 +192,12 @@ def multihead_attention(self, memory): qkv_transpose = qkv_reshape.permute(0, 2, 1, 3) # [B, H, N, key_size], [B, H, N, key_size], [B, H, N, value_size] - q, k, v = torch.split(qkv_transpose, [self.key_size, self.key_size, self.value_size], -1) + q, k, v = torch.split( + qkv_transpose, [self.key_size, self.key_size, self.value_size], -1 + ) # scale q with d_k, the dimensionality of the key vectors - q = q * (self.key_size ** -0.5) + q = q * (self.key_size**-0.5) # make it [B, H, N, N] dot_product = torch.matmul(q, k.permute(0, 1, 3, 2)) @@ -182,7 +208,9 @@ def multihead_attention(self, memory): # [B, H, N, V] => [B, N, H, V] => [B, N, H*V] output_transpose = output.permute(0, 2, 1, 3).contiguous() - new_memory = output_transpose.view((output_transpose.shape[0], output_transpose.shape[1], -1)) + new_memory = output_transpose.view( + (output_transpose.shape[0], output_transpose.shape[1], -1) + ) return new_memory @@ -200,9 +228,9 @@ def calculate_gate_size(self): Returns: The per sample, per head parameter size of each gate. 
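The multihead_attention hunk above is standard scaled dot-product attention: queries are scaled by key_size ** -0.5, matched against the keys, softmaxed, and used to mix the values before the heads are folded back together. A compact sketch of that step with hypothetical tensor sizes:

import torch
import torch.nn.functional as F

B, H, N, key_size, value_size = 2, 4, 6, 8, 8  # hypothetical sizes
q = torch.randn(B, H, N, key_size) * (key_size ** -0.5)  # scale q with d_k
k = torch.randn(B, H, N, key_size)
v = torch.randn(B, H, N, value_size)

weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)), dim=-1)  # [B, H, N, N]
output = torch.matmul(weights, v)                                  # [B, H, N, value_size]
new_memory = output.permute(0, 2, 1, 3).contiguous().view(B, N, H * value_size)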
""" - if self.gate_style == 'unit': + if self.gate_style == "unit": return self.mem_size - elif self.gate_style == 'memory': + elif self.gate_style == "memory": return 1 else: # self.gate_style == None return 0 @@ -231,7 +259,8 @@ def create_gates(self, inputs, memory): if len(inputs.shape) == 3: if inputs.shape[1] > 1: raise ValueError( - "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1") + "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1" + ) inputs = inputs.view(inputs.shape[0], -1) # matmul for equation 4 and 5 # there is no output gate, so equation 6 is not implemented @@ -243,7 +272,9 @@ def create_gates(self, inputs, memory): # this completes the equation 4 and 5 gates = gate_memory + gate_inputs - gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2) + gates = torch.split( + gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2 + ) input_gate, forget_gate = gates assert input_gate.shape[2] == forget_gate.shape[2] @@ -310,7 +341,7 @@ def forward_step(self, inputs, memory, treat_input_as_matrix=False): n = inputs_reshape.shape[1] next_memory = next_memory[:, :-n, :] - if self.gate_style == 'unit' or self.gate_style == 'memory': + if self.gate_style == "unit" or self.gate_style == "memory": # these gates are sigmoid-applied ones for equation 7 input_gate, forget_gate = self.create_gates(inputs_reshape, memory) # equation 7 calculation @@ -345,6 +376,7 @@ def forward(self, inputs, memory, treat_input_as_matrix=False): else: return logit.unsqueeze(1), memory + # ########## DEBUG: unit test code ########## # input_size = 32 # seq_length = 20 diff --git a/run/run_catgan.py b/run/run_catgan.py index e86b9619..71e4b78e 100644 --- a/run/run_catgan.py +++ b/run/run_catgan.py @@ -15,26 +15,30 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' -rootdir = '../' -scriptname = 'main.py' +executable = "python" +rootdir = "../" +scriptname = "main.py" # ===Program=== # CatGAN: Catgory text generation model # EvoGAN: General text generation model if_test = int(False) -run_model = ['evogan', 'catgan', 'catgan', 'evogan', 'evogan', 'evogan'] +run_model = ["evogan", "catgan", "catgan", "evogan", "evogan", "evogan"] k_label = 2 CUDA = int(True) ora_pretrain = int(True) @@ -43,19 +47,19 @@ MLE_train_epoch = 150 clas_pre_epoch = 5 ADV_train_epoch = 2000 -tips = '{} experiments' +tips = "{} experiments" # ===Oracle or Real=== if_real_data = [int(True), int(False), int(True), int(False), int(True), int(True)] -dataset = ['amazon_app_book', 'oracle', 'mr15', 'oracle', 'image_coco', 'emnlp_news'] +dataset = ["amazon_app_book", "oracle", "mr15", "oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0, 5000, 0, 0] # ===CatGAN Param=== n_parent = 1 -loss_type = 'ragan' -mu_type = 'ragan rsgan' -eval_type = 'Ra' -temp_adpt = 'exp' +loss_type = "ragan" +mu_type = "ragan rsgan" +eval_type = "Ra" +temp_adpt = "exp" temperature = [1, 100, 100, 1, 100, 100] d_out_mean = int(True) lambda_fq = 1.0 @@ -64,9 +68,9 @@ # === Basic Param === data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'truncated_normal' -dis_init = 'uniform' +model_type = "vanilla" +gen_init = "truncated_normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -100,71 +104,117 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model[job_id], - '--k_label', k_label, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model[job_id], + "--k_label", + k_label, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', ora_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--clas_pre_epoch', clas_pre_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips.format(run_model[job_id]), - + "--ora_pretrain", + ora_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--clas_pre_epoch", + clas_pre_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips.format(run_model[job_id]), # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # CatGAN Param - '--n_parent', n_parent, - '--loss_type', loss_type, - '--mu_type', mu_type, - '--eval_type', eval_type, - '--temp_adpt', temp_adpt, - '--temperature', temperature[job_id], - '--d_out_mean', d_out_mean, - '--lambda_fq', lambda_fq, - '--lambda_fd', lambda_fd, - '--eval_b_num', eval_b_num, - + "--n_parent", + n_parent, + "--loss_type", + loss_type, + "--mu_type", + mu_type, + "--eval_type", + eval_type, + "--temp_adpt", + temp_adpt, + "--temperature", + temperature[job_id], + "--d_out_mean", + d_out_mean, + "--lambda_fq", + lambda_fq, + "--lambda_fd", + lambda_fd, + "--eval_b_num", + eval_b_num, # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--gen_adv_lr', gen_adv_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + 
"--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--gen_adv_lr", + gen_adv_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - '--mem_slots', mem_slots, - '--num_heads', num_heads, - '--head_size', head_size[job_id], - + "--adv_g_step", + ADV_g_step, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, + "--mem_slots", + mem_slots, + "--num_heads", + num_heads, + "--head_size", + head_size[job_id], # Discriminator - '--adv_d_step', ADV_d_step, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - '--num_rep', num_rep, - + "--adv_d_step", + ADV_d_step, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, + "--num_rep", + num_rep, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_clas_acc', use_clas_acc, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_clas_acc", + use_clas_acc, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_cot.py b/run/run_cot.py index 491f984f..3f0661db 100644 --- a/run/run_cot.py +++ b/run/run_cot.py @@ -16,42 +16,46 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'cot' +run_model = "cot" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 0 ADV_train_epoch = 20000 -tips = 'CoT experiments' +tips = "CoT experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'normal' +model_type = "vanilla" +gen_init = "normal" +dis_init = "normal" batch_size = 64 max_seq_len = 20 gen_lr = 1e-2 @@ -76,48 +80,74 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--adv_d_step', ADV_d_step, - + "--adv_d_step", + ADV_d_step, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_dgsan.py b/run/run_dgsan.py index f096d87d..33bc078a 100644 --- a/run/run_dgsan.py +++ b/run/run_dgsan.py @@ -16,41 +16,45 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) 
gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = '/home/zhiwei/.virtualenvs/zhiwei/bin/python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "/home/zhiwei/.virtualenvs/zhiwei/bin/python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'dgsan' +run_model = "dgsan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 0 ADV_train_epoch = 200 -tips = 'DGSAN experiments' +tips = "DGSAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'truncated_normal' +model_type = "vanilla" +gen_init = "truncated_normal" batch_size = 64 max_seq_len = 20 gen_lr = 1e-2 @@ -71,43 +75,67 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_dpgan.py b/run/run_dpgan.py index fe40c1e5..7b7496e1 100644 --- a/run/run_dpgan.py +++ b/run/run_dpgan.py @@ -16,42 +16,46 @@ if len(sys.argv) > 2: 
job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 2 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = '/home/zhiwei/.virtualenvs/zhiwei/bin/python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "/home/zhiwei/.virtualenvs/zhiwei/bin/python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'dpgan' +run_model = "dpgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 120 ADV_train_epoch = 200 -tips = 'DPGAN experiments' +tips = "DPGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -83,55 +87,89 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, - '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, + "--device", + gpu_id, # comment for auto GPU + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', 
ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_fixem.py b/run/run_fixem.py index defa416a..9650f93a 100644 --- a/run/run_fixem.py +++ b/run/run_fixem.py @@ -15,37 +15,71 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' -rootdir = '../' -scriptname = 'main.py' +executable = "python" +rootdir = "../" +scriptname = "main.py" # ===Program=== # FixemGAN: General text generation model if_test = int(False) -run_model = ['fixemgan', 'cat_fixemgan', 'fixemgan', 'cat_fixemgan', 'fixemgan', 'fixemgan', 'fixemgan', 'fixemgan', 'cat_fixemgan'] +run_model = [ + "fixemgan", + "cat_fixemgan", + "fixemgan", + "cat_fixemgan", + "fixemgan", + "fixemgan", + "fixemgan", + "fixemgan", + "cat_fixemgan", +] k_label = 2 CUDA = int(True) noise_size = 1000 max_epochs = 20 batches_per_epoch = 200 -samples_num = 100 # sample for metrics -tips = '{} experiments' +samples_num = 100 # sample for metrics +tips = "{} experiments" # ===Oracle or Real=== -if_real_data = [int(True), int(True), int(True), int(True), int(True), int(True), int(True), int(False), int(False)] -dataset = ['image_coco', 'mr20', 'mr20', 'mr15', 'mr15', 'amazon_app_book', 'emnlp_news', 'oracle', 'oracle'] -w2v_embedding_size = 512 #low on ram #hyperparam +if_real_data = [ + int(True), + int(True), + int(True), + int(True), + int(True), + int(True), + int(True), + int(False), + int(False), +] +dataset = [ + "image_coco", + "mr20", + "mr20", + "mr15", + "mr15", + "amazon_app_book", + "emnlp_news", + "oracle", + "oracle", +] +w2v_embedding_size = 512 # low on ram #hyperparam w2v_window = 5 w2v_min_count = 30 w2v_workers = 30 @@ -54,21 +88,31 @@ # === Basic Param === data_shuffle = int(False) -model_type = 'fixem' -loss_type = 'fixem' -gen_init = 'truncated_normal' -dis_init = 'uniform' +model_type = "fixem" +loss_type = "fixem" +gen_init = "truncated_normal" +dis_init = "uniform" batch_size = 64 -target_len = [16, 20, 20, 16, 16, 40, 48, 20, 20] # architechture requires to be divisible by 4 +target_len = [ + 16, + 20, + 20, + 16, + 16, + 40, + 48, + 20, + 20, +] # architechture requires to be divisible by 4 real_fake_coeff = 1.0 labels_coeff = 1.0 diversity_coeff = 1.0 # ===Generator=== -generator_complexity = 768 
#hyperparam +generator_complexity = 768 # hyperparam # ===Discriminator=== -discriminator_complexity = 512 #hyperparam +discriminator_complexity = 512 # hyperparam # ===Metrics=== use_nll_oracle = int(True) @@ -81,50 +125,75 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model[job_id], - '--k_label', k_label, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model[job_id], + "--k_label", + k_label, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--tips', tips.format(run_model[job_id]), - + "--tips", + tips.format(run_model[job_id]), # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size, - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size, # W2V embeddings - '--w2v_embedding_size', w2v_embedding_size, - '--w2v_window', w2v_window, - '--w2v_min_count', w2v_min_count, - '--w2v_workers', w2v_workers, - '--w2v_samples_num', w2v_samples_num, - + "--w2v_embedding_size", + w2v_embedding_size, + "--w2v_window", + w2v_window, + "--w2v_min_count", + w2v_min_count, + "--w2v_workers", + w2v_workers, + "--w2v_samples_num", + w2v_samples_num, # FixemGAN Param - '--loss_type', loss_type, - '--max_epochs', max_epochs, - '--batches_per_epoch', batches_per_epoch, - '--noise_size', noise_size, - '--target_len', target_len[job_id], - '--batch_size', batch_size, - '--real_fake_coeff', real_fake_coeff, - '--labels_coeff', labels_coeff, - '--diversity_coeff', diversity_coeff, - + "--loss_type", + loss_type, + "--max_epochs", + max_epochs, + "--batches_per_epoch", + batches_per_epoch, + "--noise_size", + noise_size, + "--target_len", + target_len[job_id], + "--batch_size", + batch_size, + "--real_fake_coeff", + real_fake_coeff, + "--labels_coeff", + labels_coeff, + "--diversity_coeff", + diversity_coeff, # Generator - '--generator_complexity', generator_complexity, - + "--generator_complexity", + generator_complexity, # Discriminator - '--discriminator_complexity', discriminator_complexity, - + "--discriminator_complexity", + discriminator_complexity, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_clas_acc', use_clas_acc, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_clas_acc", + use_clas_acc, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_jsdgan.py b/run/run_jsdgan.py index 9b1ee7d5..60c88964 100644 --- a/run/run_jsdgan.py +++ b/run/run_jsdgan.py @@ -16,40 +16,44 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'jsdgan' +run_model = "jsdgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) MLE_train_epoch = 0 # no pre-training ADV_train_epoch = 500 -tips = 'JSDGAN experiments' +tips = "JSDGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' +model_type = "vanilla" +gen_init = "normal" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -71,43 +75,67 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_leakgan.py b/run/run_leakgan.py index 327719c9..507d4057 100644 --- a/run/run_leakgan.py +++ b/run/run_leakgan.py @@ -16,24 +16,28 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and 
gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'leakgan' +run_model = "leakgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) @@ -41,18 +45,18 @@ MLE_train_epoch = 8 ADV_train_epoch = 200 inter_epoch = 10 -tips = 'LeakGAN experiments' +tips = "LeakGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.0015 @@ -86,58 +90,94 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--inter_epoch', inter_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--inter_epoch", + inter_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - '--goal_size', goal_size, - '--step_size', step_size, - + "--adv_g_step", + ADV_g_step, + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, + "--goal_size", + goal_size, + "--step_size", + step_size, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - 
'--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_maligan.py b/run/run_maligan.py index 22b021e9..20b7f49f 100644 --- a/run/run_maligan.py +++ b/run/run_maligan.py @@ -16,42 +16,46 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'maligan' +run_model = "maligan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 80 ADV_train_epoch = 200 -tips = 'MaliGAN experiments' +tips = "MaliGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -83,55 +87,88 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + 
"--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step[job_id], - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step[job_id], + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_relgan.py b/run/run_relgan.py index 72cff2ba..20f3e649 100644 --- a/run/run_relgan.py +++ b/run/run_relgan.py @@ -4,7 +4,7 @@ # @FileName : run_relgan.py # @Time : Created at 2019-05-28 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import sys @@ -16,45 +16,49 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'relgan' +run_model = "relgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 150 ADV_train_epoch = 3000 -tips = 'RelGAN experiments' +tips = "RelGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] -loss_type = 'rsgan' +dataset = ["oracle", "image_coco", "emnlp_news"] +loss_type = "rsgan" vocab_size = [5000, 0, 0] -temp_adpt = 'exp' +temp_adpt = "exp" temperature = [1, 100, 100] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'truncated_normal' -dis_init = 'uniform' +model_type = "vanilla" +gen_init = "truncated_normal" +dis_init = "uniform" samples_num = 10000 batch_size = 64 max_seq_len = 20 @@ -88,60 +92,98 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--loss_type', loss_type, - '--vocab_size', vocab_size[job_id], - '--temp_adpt', temp_adpt, - '--temperature', temperature[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--loss_type", + loss_type, + "--vocab_size", + vocab_size[job_id], + "--temp_adpt", + temp_adpt, + "--temperature", + temperature[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--samples_num', samples_num, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--gen_adv_lr', gen_adv_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--samples_num", + samples_num, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--gen_adv_lr", + gen_adv_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - '--mem_slots', mem_slots, - '--num_heads', num_heads, - '--head_size', head_size, - + "--adv_g_step", + ADV_g_step, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, + "--mem_slots", + mem_slots, + "--num_heads", + num_heads, + "--head_size", + head_size, # Discriminator - '--adv_d_step', ADV_d_step, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - '--num_rep', num_rep, - + "--adv_d_step", + ADV_d_step, + 
"--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, + "--num_rep", + num_rep, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_sentigan.py b/run/run_sentigan.py index 1be52a6b..8e5f6252 100644 --- a/run/run_sentigan.py +++ b/run/run_sentigan.py @@ -17,24 +17,28 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'sentigan' +run_model = "sentigan" k_label = 2 CUDA = int(True) oracle_pretrain = int(True) @@ -43,18 +47,18 @@ MLE_train_epoch = 120 clas_pre_epoch = 5 ADV_train_epoch = 100 -tips = 'SentiGAN experiments' +tips = "SentiGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'mr15', 'amazon_app_book'] +dataset = ["oracle", "mr15", "amazon_app_book"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -87,58 +91,94 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--k_label', k_label, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--k_label", + k_label, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--clas_pre_epoch', clas_pre_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--clas_pre_epoch", + clas_pre_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', 
pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_clas_acc', use_clas_acc, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, + "--use_clas_acc", + use_clas_acc, + "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/run/run_seqgan.py b/run/run_seqgan.py index 4a6f00da..701d9364 100644 --- a/run/run_seqgan.py +++ b/run/run_seqgan.py @@ -16,42 +16,46 @@ if len(sys.argv) > 2: job_id = int(sys.argv[1]) gpu_id = str(sys.argv[2]) - print('job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print("job_id: {}, gpu_id: {}".format(job_id, gpu_id)) elif len(sys.argv) > 1: job_id = int(sys.argv[1]) gpu_id = 0 - print('job_id: {}, missing gpu_id (use default {})'.format(job_id, gpu_id)) + print("job_id: {}, missing gpu_id (use default {})".format(job_id, gpu_id)) else: job_id = 0 gpu_id = 0 - print('Missing argument: job_id and gpu_id. Use default job_id: {}, gpu_id: {}'.format(job_id, gpu_id)) + print( + "Missing argument: job_id and gpu_id. 
Use default job_id: {}, gpu_id: {}".format( + job_id, gpu_id + ) + ) # Executables -executable = 'python' # specify your own python interpreter path here -rootdir = '../' -scriptname = 'main.py' +executable = "python" # specify your own python interpreter path here +rootdir = "../" +scriptname = "main.py" # ===Program=== if_test = int(False) -run_model = 'seqgan' +run_model = "seqgan" CUDA = int(True) oracle_pretrain = int(True) gen_pretrain = int(False) dis_pretrain = int(False) MLE_train_epoch = 120 ADV_train_epoch = 200 -tips = 'SeqGAN experiments' +tips = "SeqGAN experiments" # ===Oracle or Real=== if_real_data = [int(False), int(True), int(True)] -dataset = ['oracle', 'image_coco', 'emnlp_news'] +dataset = ["oracle", "image_coco", "emnlp_news"] vocab_size = [5000, 0, 0] # ===Basic Param=== data_shuffle = int(False) -model_type = 'vanilla' -gen_init = 'normal' -dis_init = 'uniform' +model_type = "vanilla" +gen_init = "normal" +dis_init = "uniform" batch_size = 64 max_seq_len = 20 gen_lr = 0.01 @@ -83,55 +87,88 @@ args = [ # Program - '--if_test', if_test, - '--run_model', run_model, - '--cuda', CUDA, + "--if_test", + if_test, + "--run_model", + run_model, + "--cuda", + CUDA, # '--device', gpu_id, # comment for auto GPU - '--ora_pretrain', oracle_pretrain, - '--gen_pretrain', gen_pretrain, - '--dis_pretrain', dis_pretrain, - '--mle_epoch', MLE_train_epoch, - '--adv_epoch', ADV_train_epoch, - '--tips', tips, - + "--ora_pretrain", + oracle_pretrain, + "--gen_pretrain", + gen_pretrain, + "--dis_pretrain", + dis_pretrain, + "--mle_epoch", + MLE_train_epoch, + "--adv_epoch", + ADV_train_epoch, + "--tips", + tips, # Oracle or Real - '--if_real_data', if_real_data[job_id], - '--dataset', dataset[job_id], - '--vocab_size', vocab_size[job_id], - + "--if_real_data", + if_real_data[job_id], + "--dataset", + dataset[job_id], + "--vocab_size", + vocab_size[job_id], # Basic Param - '--shuffle', data_shuffle, - '--model_type', model_type, - '--gen_init', gen_init, - '--dis_init', dis_init, - '--batch_size', batch_size, - '--max_seq_len', max_seq_len, - '--gen_lr', gen_lr, - '--dis_lr', dis_lr, - '--pre_log_step', pre_log_step, - '--adv_log_step', adv_log_step, - + "--shuffle", + data_shuffle, + "--model_type", + model_type, + "--gen_init", + gen_init, + "--dis_init", + dis_init, + "--batch_size", + batch_size, + "--max_seq_len", + max_seq_len, + "--gen_lr", + gen_lr, + "--dis_lr", + dis_lr, + "--pre_log_step", + pre_log_step, + "--adv_log_step", + adv_log_step, # Generator - '--adv_g_step', ADV_g_step, - '--rollout_num', rollout_num, - '--gen_embed_dim', gen_embed_dim, - '--gen_hidden_dim', gen_hidden_dim, - + "--adv_g_step", + ADV_g_step, + "--rollout_num", + rollout_num, + "--gen_embed_dim", + gen_embed_dim, + "--gen_hidden_dim", + gen_hidden_dim, # Discriminator - '--d_step', d_step, - '--d_epoch', d_epoch, - '--adv_d_step', ADV_d_step, - '--adv_d_epoch', ADV_d_epoch, - '--dis_embed_dim', dis_embed_dim, - '--dis_hidden_dim', dis_hidden_dim, - + "--d_step", + d_step, + "--d_epoch", + d_epoch, + "--adv_d_step", + ADV_d_step, + "--adv_d_epoch", + ADV_d_epoch, + "--dis_embed_dim", + dis_embed_dim, + "--dis_hidden_dim", + dis_hidden_dim, # Metrics - '--use_nll_oracle', use_nll_oracle, - '--use_nll_gen', use_nll_gen, - '--use_nll_div', use_nll_div, - '--use_bleu', use_bleu, - '--use_self_bleu', use_self_bleu, - '--use_ppl', use_ppl, + "--use_nll_oracle", + use_nll_oracle, + "--use_nll_gen", + use_nll_gen, + "--use_nll_div", + use_nll_div, + "--use_bleu", + use_bleu, + "--use_self_bleu", + use_self_bleu, 
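A note on how these run_*.py launchers are used: each script collects flag/value pairs into a flat args list, stringifies it with list(map(str, args)), and hands the result to the interpreter named in executable so that main.py receives them as ordinary command-line options. The exact launch line is outside the hunks shown here, so the subprocess call below is only an illustrative equivalent, with a made-up shortened flag list.

import subprocess

executable = "python"          # interpreter path, as defined at the top of each run script
scriptname = "main.py"
rootdir = "../"

# example flag/value pairs only; the real scripts build a much longer list
args = ["--run_model", "seqgan", "--cuda", 1, "--batch_size", 64]
args = list(map(str, args))    # every value must be a string before it reaches the command line

subprocess.run([executable, scriptname] + args, cwd=rootdir)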
+ "--use_ppl", + use_ppl, ] args = list(map(str, args)) diff --git a/utils/cat_data_loader.py b/utils/cat_data_loader.py index 07721db5..2881fcea 100644 --- a/utils/cat_data_loader.py +++ b/utils/cat_data_loader.py @@ -37,18 +37,24 @@ def __init__(self, samples_list, shuffle=None): dataset=GANDataset(self.__read_data__(samples_list)), batch_size=self.batch_size, shuffle=self.shuffle, - drop_last=True) + drop_last=True, + ) - self.input = self._all_data_('input') - self.target = self._all_data_('target') - self.label = self._all_data_('label') # from 0 to k-1, different from Discriminator label + self.input = self._all_data_("input") + self.target = self._all_data_("target") + self.label = self._all_data_( + "label" + ) # from 0 to k-1, different from Discriminator label def __read_data__(self, samples_list): """ input: same as target, but start with start_letter. """ inp, target, label = self.prepare(samples_list) - all_data = [{'input': i, 'target': t, 'label': l} for (i, t, l) in zip(inp, target, label)] + all_data = [ + {"input": i, "target": t, "label": l} + for (i, t, l) in zip(inp, target, label) + ] return all_data def random_batch(self): @@ -57,7 +63,9 @@ def random_batch(self): return list(self.loader)[idx] def _all_data_(self, col): - return torch.cat([data[col].unsqueeze(0) for data in self.loader.dataset.data], 0) + return torch.cat( + [data[col].unsqueeze(0) for data in self.loader.dataset.data], 0 + ) def prepare(self, samples_list, gpu=False): """Add start_letter to samples as inp, target same as samples""" @@ -65,12 +73,12 @@ def prepare(self, samples_list, gpu=False): target = all_samples inp = torch.zeros(all_samples.size()).long() inp[:, 0] = self.start_letter - inp[:, 1:] = target[:, :self.max_seq_len - 1] + inp[:, 1:] = target[:, : self.max_seq_len - 1] label = torch.zeros(all_samples.size(0)).long() for idx in range(len(samples_list)): start = sum([samples_list[i].size(0) for i in range(idx)]) - label[start: start + samples_list[idx].size(0)] = idx + label[start : start + samples_list[idx].size(0)] = idx # shuffle perm = torch.randperm(inp.size(0)) @@ -105,14 +113,15 @@ def __init__(self, samples_list, given_target=None, shuffle=None): dataset=GANDataset(self.__read_data__(samples_list, given_target)), batch_size=self.batch_size, shuffle=self.shuffle, - drop_last=True) + drop_last=True, + ) - self.input = self._all_data_('input') - self.target = self._all_data_('target') + self.input = self._all_data_("input") + self.target = self._all_data_("target") def __read_data__(self, samples_list, given_target=None): inp, target = self.prepare(samples_list, given_target) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] + all_data = [{"input": i, "target": t} for (i, t) in zip(inp, target)] return all_data def random_batch(self): @@ -121,7 +130,9 @@ def random_batch(self): # return next(iter(self.loader)) def _all_data_(self, col): - return torch.cat([data[col].unsqueeze(0) for data in self.loader.dataset.data], 0) + return torch.cat( + [data[col].unsqueeze(0) for data in self.loader.dataset.data], 0 + ) @staticmethod def prepare(samples_list, given_target=None, detach=True, gpu=False): @@ -135,7 +146,7 @@ def prepare(samples_list, given_target=None, detach=True, gpu=False): - inp: sentences - target: label index, 0-label_0, 1-label_1, ..., k-label_k """ - if type(samples_list[0][0][0]) == str: # directly generated text + if type(samples_list[0][0][0]) == str: # directly generated text inp = torch.zeros(1) target = torch.zeros(1) elif len(samples_list) == 
1 and given_target is not None: @@ -154,7 +165,7 @@ def prepare(samples_list, given_target=None, detach=True, gpu=False): inp = inp.long() for idx in range(1, len(samples_list)): start = sum([samples_list[i].size(0) for i in range(idx)]) - target[start: start + samples_list[idx].size(0)] = idx + target[start : start + samples_list[idx].size(0)] = idx # shuffle perm = torch.randperm(inp.size(0)) diff --git a/utils/create_embeddings.py b/utils/create_embeddings.py index 4145553c..2ebe9a39 100644 --- a/utils/create_embeddings.py +++ b/utils/create_embeddings.py @@ -11,7 +11,7 @@ def __init__(self, files): self.files = files def __iter__(self): - for file in tqdm(self.files, desc='iterating files'): + for file in tqdm(self.files, desc="iterating files"): for tokens in text_file_iterator(file): yield [cfg.padding_token] * 5 + tokens diff --git a/utils/data_loader.py b/utils/data_loader.py index e2e9333a..d67fd926 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -38,7 +38,7 @@ def __len__(self): class GenDataIter: def __init__(self, samples, if_test_data=False, shuffle=None): self.samples = samples - if type(self.samples) == str: # we received filename + if type(self.samples) == str: # we received filename self.samples = get_tokenlized(self.samples) self.shuffle = cfg.data_shuffle if not shuffle else shuffle @@ -52,26 +52,29 @@ def __init__(self, samples, if_test_data=False, shuffle=None): dataset=GANDataset(self.__read_data__(self.samples)), batch_size=cfg.batch_size, shuffle=self.shuffle, - drop_last=True) + drop_last=True, + ) - self.input = self._all_data_('input') - self.target = self._all_data_('target') + self.input = self._all_data_("input") + self.target = self._all_data_("target") def __read_data__(self, samples): """ input: same as target, but start with start_letter. 
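On the "input: same as target, but start with start_letter" convention stated here: the teacher-forcing input is simply the target sequence shifted one position to the right with the start token prepended, which is what prepare_for_NLL does a little further down in this file. A minimal stand-alone version of that construction (the helper name is invented; start_letter=1 follows the dictionary convention used elsewhere in this patch, where index 0 is padding and index 1 is the start letter):

import torch

def shift_right_with_start(target, start_letter=1):
    """Teacher-forcing input: start token followed by target[:, :-1]."""
    inp = torch.zeros_like(target)          # same shape and dtype as the target batch
    inp[:, 0] = start_letter                # every sequence begins with the start token
    inp[:, 1:] = target[:, :-1]             # the rest is the target shifted right by one
    return inp

target = torch.tensor([[5, 6, 7, 8]])
print(shift_right_with_start(target))       # tensor([[1, 5, 6, 7]])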
""" - if isinstance(samples[0], str) or isinstance(samples[0][0], str): # list of strings + if isinstance(samples[0], str) or isinstance( + samples[0][0], str + ): # list of strings # we directly generated string, skip NLL return [ - {'input': i, 'target': t} - for i, t in zip(torch.zeros(2), torch.zeros(2)) + {"input": i, "target": t} + for i, t in zip(torch.zeros(2), torch.zeros(2)) ] if isinstance(samples[0], list): # need to transform to indexes samples = tokens_to_tensor(samples, self.word2idx_dict) inp, target = self.prepare_for_NLL(samples) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] + all_data = [{"input": i, "target": t} for (i, t) in zip(inp, target)] return all_data def random_batch(self): @@ -80,13 +83,15 @@ def random_batch(self): return list(self.loader)[idx] def _all_data_(self, col): - return torch.cat([data[col].unsqueeze(0) for data in self.loader.dataset.data], 0) + return torch.cat( + [data[col].unsqueeze(0) for data in self.loader.dataset.data], 0 + ) @property def tokens(self): """Returns samples in form of list of tensors, if input tensor, or list of tokens in case if input string.""" - if type(self.samples[0]) == str: # we have list of strings + if type(self.samples[0]) == str: # we have list of strings return [smpl.split() for smpl in self.samples] return list(self.samples) @@ -96,7 +101,7 @@ def prepare_for_NLL(samples, gpu=False): inp = torch.zeros(samples.size()).long() target = samples inp[:, 0] = cfg.start_letter - inp[:, 1:] = target[:, :cfg.max_seq_len - 1] + inp[:, 1:] = target[:, : cfg.max_seq_len - 1] if gpu: return inp.cuda(), target.cuda() @@ -111,11 +116,12 @@ def __init__(self, pos_samples, neg_samples, shuffle=None): dataset=GANDataset(self.__read_data__(pos_samples, neg_samples)), batch_size=cfg.batch_size, shuffle=self.shuffle, - drop_last=True) + drop_last=True, + ) def __read_data__(self, pos_samples, neg_samples): inp, target = self.prepare(pos_samples, neg_samples) - all_data = [{'input': i, 'target': t} for (i, t) in zip(inp, target)] + all_data = [{"input": i, "target": t} for (i, t) in zip(inp, target)] return all_data def random_batch(self): @@ -124,9 +130,11 @@ def random_batch(self): def prepare(self, pos_samples, neg_samples, gpu=False): """Build inp and target""" - inp = torch.cat((pos_samples, neg_samples), dim=0).long().detach() # !!!need .detach() + inp = ( + torch.cat((pos_samples, neg_samples), dim=0).long().detach() + ) # !!!need .detach() target = torch.ones(inp.size(0)).long() - target[pos_samples.size(0):] = 0 + target[pos_samples.size(0) :] = 0 # shuffle perm = torch.randperm(inp.size(0)) @@ -140,18 +148,23 @@ def prepare(self, pos_samples, neg_samples, gpu=False): class DataSupplier: def __init__(self, tokenized, labels, w2v, batch_size, batches_per_epoch): - labels, tokenized = zip(*[ - (label, tokens) - for label, tokens in zip(labels, tokenized) - if all(token in w2v.wv for token in tokens) - ]) + labels, tokenized = zip( + *[ + (label, tokens) + for label, tokens in zip(labels, tokenized) + if all(token in w2v.wv for token in tokens) + ] + ) self.labels = torch.tensor(labels, dtype=int) self.tokenized = np.array(tokenized) self.batches_per_epoch = batches_per_epoch self.batch_size = batch_size self.w2v = w2v - self.texts = set(" ".join(tokens[-cfg.target_len:]) for tokens in tokenized) - print('dataset random texts examples\n', '\n'.join([txt for txt in self.texts][:5])) + self.texts = set(" ".join(tokens[-cfg.target_len :]) for tokens in tokenized) + print( + "dataset random texts examples\n", 
+ "\n".join([txt for txt in self.texts][:5]), + ) def vectorize_batch(self, tokenized): vectors = [ @@ -159,7 +172,7 @@ def vectorize_batch(self, tokenized): tokens, self.w2v, target_len=cfg.target_len, - padding_token = cfg.padding_token, + padding_token=cfg.padding_token, ) for tokens in tokenized ] @@ -178,15 +191,26 @@ def __iter__(self): if index > len(self): # concatenating beginning of self.vectors yield ( - torch.cat((self.labels[index - self.batch_size: index], self.labels[:index-len(self)])), - torch.cat(( - self.vectorize_batch(self.tokenized[index - self.batch_size: index]), - self.vectorize_batch(self.tokenized[:index-len(self)]) - )) + torch.cat( + ( + self.labels[index - self.batch_size : index], + self.labels[: index - len(self)], + ) + ), + torch.cat( + ( + self.vectorize_batch( + self.tokenized[index - self.batch_size : index] + ), + self.vectorize_batch(self.tokenized[: index - len(self)]), + ) + ), ) index = index % len(self) else: - yield self.labels[index - self.batch_size: index], self.vectorize_batch(self.tokenized[index - self.batch_size: index]) + yield self.labels[ + index - self.batch_size : index + ], self.vectorize_batch(self.tokenized[index - self.batch_size : index]) def __len__(self): return self.batches_per_epoch diff --git a/utils/data_utils.py b/utils/data_utils.py index 7cea33a1..164199fb 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -18,29 +18,45 @@ def create_multi_oracle(number): for i in range(number): - print('Creating Oracle %d...' % i) - oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, - cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + print("Creating Oracle %d..." % i) + oracle = Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) if cfg.CUDA: oracle = oracle.cuda() large_samples = oracle.sample(cfg.samples_num, 4 * cfg.batch_size) small_samples = oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size) torch.save(oracle.state_dict(), cfg.multi_oracle_state_dict_path.format(i)) - torch.save(large_samples, cfg.multi_oracle_samples_path.format(i, cfg.samples_num)) - torch.save(small_samples, cfg.multi_oracle_samples_path.format(i, cfg.samples_num // 2)) + torch.save( + large_samples, cfg.multi_oracle_samples_path.format(i, cfg.samples_num) + ) + torch.save( + small_samples, cfg.multi_oracle_samples_path.format(i, cfg.samples_num // 2) + ) oracle_data = GenDataIter(large_samples) mle_criterion = nn.NLLLoss() groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) - print('Oracle %d Groud Truth: %.4f' % (i, groud_truth)) + print("Oracle %d Groud Truth: %.4f" % (i, groud_truth)) -def create_specific_oracle(from_a, to_b, num=1, save_path='../pretrain/'): +def create_specific_oracle(from_a, to_b, num=1, save_path="../pretrain/"): for i in range(num): while True: - oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, - cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + oracle = Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) if cfg.CUDA: oracle = oracle.cuda() @@ -52,24 +68,36 @@ def create_specific_oracle(from_a, to_b, num=1, save_path='../pretrain/'): groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) if from_a <= groud_truth <= to_b: - dir_path = save_path + 'oracle_data_gt{:.2f}_{}'.format(groud_truth, - strftime("%m%d_%H%M%S", localtime())) + dir_path = save_path + "oracle_data_gt{:.2f}_{}".format( + groud_truth, 
strftime("%m%d_%H%M%S", localtime()) + ) if not os.path.exists(dir_path): os.mkdir(dir_path) - print('save ground truth: ', groud_truth) + print("save ground truth: ", groud_truth) # prefix = 'oracle{}_lstm_gt{:.2f}_{}'.format(i, groud_truth, strftime("%m%d", localtime())) - prefix = dir_path + '/oracle_lstm' - torch.save(oracle.state_dict(), '{}.pt'.format(prefix)) - torch.save(big_samples, '{}_samples_{}.pt'.format(prefix, cfg.samples_num)) - torch.save(small_samples, '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2)) + prefix = dir_path + "/oracle_lstm" + torch.save(oracle.state_dict(), "{}.pt".format(prefix)) + torch.save( + big_samples, "{}_samples_{}.pt".format(prefix, cfg.samples_num) + ) + torch.save( + small_samples, + "{}_samples_{}.pt".format(prefix, cfg.samples_num // 2), + ) break -def create_many_oracle(from_a, to_b, num=1, save_path='../pretrain/'): +def create_many_oracle(from_a, to_b, num=1, save_path="../pretrain/"): for i in range(num): while True: - oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, - cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + oracle = Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) if cfg.CUDA: oracle = oracle.cuda() @@ -81,33 +109,39 @@ def create_many_oracle(from_a, to_b, num=1, save_path='../pretrain/'): groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) if from_a <= groud_truth <= to_b: - print('save ground truth: ', groud_truth) - prefix = 'oracle_lstm' - torch.save(oracle.state_dict(), save_path + '{}.pt'.format(prefix)) - torch.save(big_samples, save_path + '{}_samples_{}.pt'.format(prefix, cfg.samples_num)) - torch.save(small_samples, save_path + '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2)) + print("save ground truth: ", groud_truth) + prefix = "oracle_lstm" + torch.save(oracle.state_dict(), save_path + "{}.pt".format(prefix)) + torch.save( + big_samples, + save_path + "{}_samples_{}.pt".format(prefix, cfg.samples_num), + ) + torch.save( + small_samples, + save_path + "{}_samples_{}.pt".format(prefix, cfg.samples_num // 2), + ) break def _save(data, filename): - with open(filename, 'w') as fout: + with open(filename, "w") as fout: for d in data: - fout.write(d['reviewText'] + '\n') - fout.write(str(d['overall']) + '\n') + fout.write(d["reviewText"] + "\n") + fout.write(str(d["overall"]) + "\n") def _count(filename): - with open(filename, 'r') as fin: - data = fin.read().strip().split('\n') + with open(filename, "r") as fin: + data = fin.read().strip().split("\n") return len(data) / 2 def clean_amazon_long_sentence(): - data_root = '/home/sysu2018/Documents/william/amazon_dataset/' + data_root = "/home/sysu2018/Documents/william/amazon_dataset/" all_files = os.listdir(data_root) - print('|\ttype\t|\torigin\t|\tclean_40\t|\tclean_20\t|\tfinal_40\t|\tfinal_20\t|') - print('|----------|----------|----------|----------|----------|----------|') + print("|\ttype\t|\torigin\t|\tclean_40\t|\tclean_20\t|\tfinal_40\t|\tfinal_20\t|") + print("|----------|----------|----------|----------|----------|----------|") for file in all_files: filename = data_root + file if os.path.isdir(filename): @@ -117,37 +151,44 @@ def clean_amazon_long_sentence(): clean_save_20 = [] final_save_40 = [] final_save_20 = [] - with open(filename, 'r') as fin: - raw_data = fin.read().strip().split('\n') + with open(filename, "r") as fin: + raw_data = fin.read().strip().split("\n") for line in raw_data: - review = eval(line)['reviewText'] + 
review = eval(line)["reviewText"] if len(review.split()) <= 40: clean_save_40.append(eval(line)) - if len(review.split('.')) <= 2: # one sentence + if len(review.split(".")) <= 2: # one sentence final_save_40.append(eval(line)) if len(review.split()) <= 20: clean_save_20.append(eval(line)) - if len(review.split('.')) <= 2: # one sentence + if len(review.split(".")) <= 2: # one sentence final_save_20.append(eval(line)) - save_filename = data_root + 'clean_40/' + file.lower().split('_5')[0] + '.txt' + save_filename = data_root + "clean_40/" + file.lower().split("_5")[0] + ".txt" _save(clean_save_40, save_filename) # a = _count(save_filename) - save_filename = data_root + 'clean_20/' + file.lower().split('_5')[0] + '.txt' + save_filename = data_root + "clean_20/" + file.lower().split("_5")[0] + ".txt" _save(clean_save_20, save_filename) # b = _count(save_filename) - save_filename = data_root + 'final_40/' + file.lower().split('_5')[0] + '.txt' + save_filename = data_root + "final_40/" + file.lower().split("_5")[0] + ".txt" _save(final_save_40, save_filename) # c = _count(save_filename) - save_filename = data_root + 'final_20/' + file.lower().split('_5')[0] + '.txt' + save_filename = data_root + "final_20/" + file.lower().split("_5")[0] + ".txt" _save(final_save_20, save_filename) # d = _count(save_filename) - print('|\t%s\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|' % ( - file.lower().split('_5')[0], len(raw_data), - len(clean_save_40), len(clean_save_20), - len(final_save_40), len(final_save_20))) + print( + "|\t%s\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|" + % ( + file.lower().split("_5")[0], + len(raw_data), + len(clean_save_40), + len(clean_save_20), + len(final_save_40), + len(final_save_20), + ) + ) # print('|\t%s\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|\t%d\t|' % ( # file.lower().split('_5')[0], len(raw_data), a, b, c, d)) @@ -163,5 +204,5 @@ def mean_list(x, y): return res -if __name__ == '__main__': +if __name__ == "__main__": pass diff --git a/utils/gan_loss.py b/utils/gan_loss.py index e799c17d..01603bbb 100644 --- a/utils/gan_loss.py +++ b/utils/gan_loss.py @@ -13,6 +13,7 @@ import config as cfg from utils.nn_helpers import DiversityLoss + class GANLoss(nn.Module): """Define different GAN Discriminator's objectives. @@ -20,8 +21,16 @@ class GANLoss(nn.Module): that has the same size as the input. """ - def __init__(self, loss_mode, which_net, which_D, target_real_label=1.0, target_fake_label=0.0, CUDA=False): - """ Initialize the GAN's Discriminator Loss class. + def __init__( + self, + loss_mode, + which_net, + which_D, + target_real_label=1.0, + target_fake_label=0.0, + CUDA=False, + ): + """Initialize the GAN's Discriminator Loss class. Parameters: loss_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. @@ -32,25 +41,25 @@ def __init__(self, loss_mode, which_net, which_D, target_real_label=1.0, target_ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 
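Regarding the remark above that LSGAN needs no sigmoid while the vanilla objective is handled by BCEWithLogitsLoss: BCEWithLogitsLoss folds the sigmoid into the loss (in a numerically stable form), so the discriminator can return raw logits in both cases, while LSGAN regresses those raw outputs against the 1.0/0.0 target tensors with mean squared error. A tiny self-contained comparison with made-up logits:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0], [-1.5], [0.3]])     # raw discriminator outputs, no sigmoid applied
real_target = torch.ones_like(logits)             # label tensor for "real" samples

vanilla = nn.BCEWithLogitsLoss()                  # sigmoid + binary cross-entropy in one step
lsgan = nn.MSELoss()                              # least-squares objective on the raw outputs

print(vanilla(logits, real_target).item())
print(lsgan(logits, real_target).item())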
""" super(GANLoss, self).__init__() - self.register_buffer('real_label', torch.tensor(target_real_label)) - self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.register_buffer("real_label", torch.tensor(target_real_label)) + self.register_buffer("fake_label", torch.tensor(target_fake_label)) self.loss_mode = loss_mode self.which_net = which_net self.which_D = which_D self.gpu = CUDA - if loss_mode == 'lsgan': + if loss_mode == "lsgan": self.loss = nn.MSELoss() - elif loss_mode in ['vanilla', 'ragan', 'rsgan']: + elif loss_mode in ["vanilla", "ragan", "rsgan"]: self.loss = nn.BCEWithLogitsLoss() - elif loss_mode in ['wgan', 'hinge']: + elif loss_mode in ["wgan", "hinge"]: self.loss = None - elif loss_mode == 'fixem': + elif loss_mode == "fixem": self.real_fake_criterion = nn.BCEWithLogitsLoss() self.label_criterion = nn.CrossEntropyLoss(label_smoothing=0.1) self.diversity_criterion = DiversityLoss() else: - raise NotImplementedError('gan mode %s not implemented' % loss_mode) + raise NotImplementedError("gan mode %s not implemented" % loss_mode) def get_target_tensor(self, prediction, target_is_real): """Create label tensors with the same size as the input. @@ -69,104 +78,127 @@ def get_target_tensor(self, prediction, target_is_real): return target_tensor.expand_as(prediction) def G_loss(self, Dreal, Dfake): - if self.loss_mode != 'rsgan' and cfg.d_out_mean: + if self.loss_mode != "rsgan" and cfg.d_out_mean: Dfake = torch.mean(Dfake.view(cfg.batch_size, -1), dim=-1) Dreal = torch.mean(Dreal.view(cfg.batch_size, -1), dim=-1) real_tensor = self.get_target_tensor(Dreal, True) fake_tensor = self.get_target_tensor(Dreal, False) - if self.which_D == 'S': + if self.which_D == "S": prediction_fake = Dfake - prediction_real = real_tensor if self.loss_mode in ['vanilla'] else fake_tensor - elif self.which_D == 'Ra': + prediction_real = ( + real_tensor if self.loss_mode in ["vanilla"] else fake_tensor + ) + elif self.which_D == "Ra": prediction_fake = Dfake - torch.mean(Dreal) prediction_real = Dreal - torch.mean(Dfake) else: - raise NotImplementedError('which_D name [%s] is not recognized' % self.which_D) + raise NotImplementedError( + "which_D name [%s] is not recognized" % self.which_D + ) - if self.loss_mode in ['lsgan', 'ragan']: + if self.loss_mode in ["lsgan", "ragan"]: loss_fake = self.loss(prediction_fake, real_tensor) loss_real = self.loss(prediction_real, fake_tensor) g_loss = loss_fake + loss_real - elif self.loss_mode == 'vanilla': + elif self.loss_mode == "vanilla": loss_fake = -self.loss(prediction_fake, fake_tensor) g_loss = loss_fake - elif self.loss_mode in ['wgan', 'hinge'] and self.which_D == 'S': + elif self.loss_mode in ["wgan", "hinge"] and self.which_D == "S": loss_fake = -prediction_fake.mean() loss_real = prediction_real.mean() g_loss = loss_fake + loss_real - elif self.loss_mode == 'hinge' and self.which_D == 'Ra': + elif self.loss_mode == "hinge" and self.which_D == "Ra": loss_fake = nn.ReLU()(1.0 - prediction_fake).mean() loss_real = nn.ReLU()(1.0 + prediction_real).mean() g_loss = loss_fake + loss_real - elif self.loss_mode == 'rsgan': + elif self.loss_mode == "rsgan": loss_fake = self.loss(Dfake - Dreal, real_tensor) g_loss = loss_fake else: - raise NotImplementedError('loss_mode name [%s] is not recognized' % self.loss_mode) + raise NotImplementedError( + "loss_mode name [%s] is not recognized" % self.loss_mode + ) return g_loss def D_loss(self, Dreal, Dfake): - if self.loss_mode != 'rsgan' and cfg.d_out_mean: + if self.loss_mode != "rsgan" and 
cfg.d_out_mean: Dfake = torch.mean(Dfake.view(cfg.batch_size, -1), dim=-1) Dreal = torch.mean(Dreal.view(cfg.batch_size, -1), dim=-1) real_tensor = self.get_target_tensor(Dreal, True) fake_tensor = self.get_target_tensor(Dreal, False) - if self.which_D == 'S': + if self.which_D == "S": prediction_fake = Dfake prediction_real = Dreal - elif self.which_D == 'Ra': + elif self.which_D == "Ra": prediction_fake = Dfake - torch.mean(Dreal) prediction_real = Dreal - torch.mean(Dfake) else: - raise NotImplementedError('which_D name [%s] is not recognized' % self.which_D) + raise NotImplementedError( + "which_D name [%s] is not recognized" % self.which_D + ) - if self.loss_mode in ['lsgan', 'ragan', 'vanilla']: + if self.loss_mode in ["lsgan", "ragan", "vanilla"]: loss_fake = self.loss(prediction_fake, fake_tensor) loss_real = self.loss(prediction_real, real_tensor) - elif self.loss_mode == 'wgan': + elif self.loss_mode == "wgan": loss_fake = prediction_fake.mean() loss_real = -prediction_real.mean() - elif self.loss_mode == 'hinge': + elif self.loss_mode == "hinge": loss_fake = nn.ReLU()(1.0 + prediction_fake).mean() loss_real = nn.ReLU()(1.0 - prediction_real).mean() - elif self.loss_mode == 'rsgan': - loss_fake = 0. + elif self.loss_mode == "rsgan": + loss_fake = 0.0 loss_real = self.loss(Dreal - Dfake, real_tensor) else: - raise NotImplementedError('loss_mode name [%s] is not recognized' % self.loss_mode) + raise NotImplementedError( + "loss_mode name [%s] is not recognized" % self.loss_mode + ) return loss_fake + loss_real - def G_loss_fixem(self, real_fake_predicts, label_predicts, target_labels, fakes): target_fake = self.get_target_tensor(real_fake_predicts, target_is_real=True) - real_fake_loss = cfg.real_fake_coeff * self.real_fake_criterion(real_fake_predicts, target_fake) - labels_loss = cfg.labels_coeff * self.label_criterion(label_predicts, target_labels) + real_fake_loss = cfg.real_fake_coeff * self.real_fake_criterion( + real_fake_predicts, target_fake + ) + labels_loss = cfg.labels_coeff * self.label_criterion( + label_predicts, target_labels + ) diversity_loss = cfg.diversity_coeff * self.diversity_criterion(fakes) loss = real_fake_loss + diversity_loss - loss = loss + labels_loss if cfg.run_model == 'cat_fixemgan' else loss + loss = loss + labels_loss if cfg.run_model == "cat_fixemgan" else loss return loss def D_loss_fixem(self, real_fake_predicts, label_predicts, target_labels): - target_real = self.get_target_tensor(real_fake_predicts.chunk(2)[0], target_is_real=True) - target_fake = self.get_target_tensor(real_fake_predicts.chunk(2)[1], target_is_real=False) + target_real = self.get_target_tensor( + real_fake_predicts.chunk(2)[0], target_is_real=True + ) + target_fake = self.get_target_tensor( + real_fake_predicts.chunk(2)[1], target_is_real=False + ) target_real_fake = torch.cat((target_real, target_fake)) - real_fake_loss = cfg.real_fake_coeff * self.real_fake_criterion(real_fake_predicts, target_real_fake) - labels_loss = cfg.labels_coeff * self.label_criterion(label_predicts, target_labels) + real_fake_loss = cfg.real_fake_coeff * self.real_fake_criterion( + real_fake_predicts, target_real_fake + ) + labels_loss = cfg.labels_coeff * self.label_criterion( + label_predicts, target_labels + ) loss = real_fake_loss - loss = loss + labels_loss if cfg.run_model == 'cat_fixemgan' else loss + loss = loss + labels_loss if cfg.run_model == "cat_fixemgan" else loss return loss def __call__(self, Dreal, Dfake): """Calculate loss given Discriminator's output and grount truth labels.""" 
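For orientation on the fixem branches above: the FixemGAN generator objective visible in this hunk is a weighted sum of a real/fake BCEWithLogitsLoss term, a label CrossEntropyLoss term that is only added in the categorical cat_fixemgan mode, and a diversity penalty over the generated batch. The sketch below rebuilds that combination in isolation; the function name, the default coefficients and the simplified mean-pairwise-cosine diversity term are illustrative stand-ins, not the repository's exact DiversityLoss or config values.

import torch
import torch.nn as nn
import torch.nn.functional as F

def fixem_generator_loss(real_fake_logits, label_logits, target_labels, fakes,
                         real_fake_coeff=1.0, labels_coeff=1.0, diversity_coeff=1.0,
                         categorical=False):
    """Sketch: weighted sum of real/fake, label and diversity terms for the generator."""
    bce = nn.BCEWithLogitsLoss()
    ce = nn.CrossEntropyLoss(label_smoothing=0.1)

    # the generator wants its fakes to be scored as real
    real_fake_loss = bce(real_fake_logits, torch.ones_like(real_fake_logits))

    # simple stand-in for the diversity term: mean pairwise cosine similarity in the batch
    flat = F.normalize(fakes.flatten(1), dim=1)
    sim = flat @ flat.t()                                  # [batch, batch] cosine similarities
    off_diag = 1.0 - torch.eye(sim.size(0), device=sim.device)
    diversity_loss = (sim * off_diag).mean()

    loss = real_fake_coeff * real_fake_loss + diversity_coeff * diversity_loss
    if categorical:                                        # label term only for cat_fixemgan
        loss = loss + labels_coeff * ce(label_logits, target_labels)
    return loss

# toy shapes: 8 fake samples, 3 classes, fakes of [batch, seq_len, emb]
loss = fixem_generator_loss(torch.randn(8, 1), torch.randn(8, 3),
                            torch.randint(0, 3, (8,)), torch.randn(8, 20, 32),
                            categorical=True)

In the patch itself the three weights come from the --real_fake_coeff, --labels_coeff and --diversity_coeff flags assembled by run_fixem.py.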
- if self.which_net == 'G': + if self.which_net == "G": return self.G_loss(Dreal, Dfake) - elif self.which_net == 'D': + elif self.which_net == "D": return self.D_loss(Dreal, Dfake) else: - raise NotImplementedError('which_net name [%s] is not recognized' % self.which_net) + raise NotImplementedError( + "which_net name [%s] is not recognized" % self.which_net + ) diff --git a/utils/helpers.py b/utils/helpers.py index 7183e2c0..42e4b5ef 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -23,11 +23,11 @@ def __init__(self, signal_file): def update(self): signal_dict = self.read_signal() - self.pre_sig = signal_dict['pre_sig'] - self.adv_sig = signal_dict['adv_sig'] + self.pre_sig = signal_dict["pre_sig"] + self.adv_sig = signal_dict["adv_sig"] def read_signal(self): - with open(self.signal_file, 'r') as fin: + with open(self.signal_file, "r") as fin: return eval(fin.read()) @@ -37,22 +37,26 @@ def create_logger(name, silent=False, to_disk=False, log_file=None): log = logging.getLogger(name) log.setLevel(logging.DEBUG) log.propagate = False - formatter = logging.Formatter(fmt='%(message)s', datefmt='%Y/%m/%d %I:%M:%S') + formatter = logging.Formatter(fmt="%(message)s", datefmt="%Y/%m/%d %I:%M:%S") if not silent: ch = logging.StreamHandler(sys.stdout) ch.setLevel(logging.DEBUG) ch.setFormatter(formatter) log.addHandler(ch) if to_disk: - log_file = log_file if log_file is not None else strftime("log/log_%m%d_%H%M.txt", gmtime()) + log_file = ( + log_file + if log_file is not None + else strftime("log/log_%m%d_%H%M.txt", gmtime()) + ) if type(log_file) == list: for filename in log_file: - fh = logging.FileHandler(filename, mode='w') + fh = logging.FileHandler(filename, mode="w") fh.setLevel(logging.INFO) fh.setFormatter(formatter) log.addHandler(fh) if type(log_file) == str: - fh = logging.FileHandler(log_file, mode='w') + fh = logging.FileHandler(log_file, mode="w") fh.setLevel(logging.INFO) fh.setFormatter(formatter) log.addHandler(fh) @@ -64,9 +68,15 @@ def create_oracle(): import config as cfg from models.Oracle import Oracle - print('Creating Oracle...') - oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size, - cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA) + print("Creating Oracle...") + oracle = Oracle( + cfg.gen_embed_dim, + cfg.gen_hidden_dim, + cfg.vocab_size, + cfg.max_seq_len, + cfg.padding_idx, + gpu=cfg.CUDA, + ) if cfg.CUDA: oracle = oracle.cuda() @@ -76,12 +86,14 @@ def create_oracle(): # large torch.save(big_samples, cfg.oracle_samples_path.format(cfg.samples_num)) # small - torch.save(oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size), - cfg.oracle_samples_path.format(cfg.samples_num // 2)) + torch.save( + oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size), + cfg.oracle_samples_path.format(cfg.samples_num // 2), + ) - #giant for W2V + # giant for W2V giant_samples = oracle.sample(cfg.w2v_samples_num, 4 * cfg.batch_size) - with open(cfg.oracle_samples_path.format(cfg.w2v_samples_num), 'w') as f: + with open(cfg.oracle_samples_path.format(cfg.w2v_samples_num), "w") as f: for sample in tqdm(giant_samples): f.write(" ".join(str(int(idx)) for idx in sample)) f.write("\n") @@ -89,26 +101,30 @@ def create_oracle(): oracle_data = GenDataIter(big_samples) mle_criterion = nn.NLLLoss() groud_truth = NLL.cal_nll(oracle, oracle_data.loader, mle_criterion) - print('NLL_Oracle Groud Truth: %.4f' % groud_truth) + print("NLL_Oracle Groud Truth: %.4f" % groud_truth) def get_fixed_temperature(temper, i, N, adapt): """A function to set up different temperature 
control policies""" N = 5000 - if adapt == 'no': + if adapt == "no": temper_var_np = 1.0 # no increase, origin: temper - elif adapt == 'lin': + elif adapt == "lin": temper_var_np = 1 + i / (N - 1) * (temper - 1) # linear increase - elif adapt == 'exp': + elif adapt == "exp": temper_var_np = temper ** (i / N) # exponential increase - elif adapt == 'log': - temper_var_np = 1 + (temper - 1) / np.log(N) * np.log(i + 1) # logarithm increase - elif adapt == 'sigmoid': - temper_var_np = (temper - 1) * 1 / (1 + np.exp((N / 2 - i) * 20 / N)) + 1 # sigmoid increase - elif adapt == 'quad': - temper_var_np = (temper - 1) / (N - 1) ** 2 * i ** 2 + 1 - elif adapt == 'sqrt': + elif adapt == "log": + temper_var_np = 1 + (temper - 1) / np.log(N) * np.log( + i + 1 + ) # logarithm increase + elif adapt == "sigmoid": + temper_var_np = (temper - 1) * 1 / ( + 1 + np.exp((N / 2 - i) * 20 / N) + ) + 1 # sigmoid increase + elif adapt == "quad": + temper_var_np = (temper - 1) / (N - 1) ** 2 * i**2 + 1 + elif adapt == "sqrt": temper_var_np = (temper - 1) / np.sqrt(N - 1) * np.sqrt(i) + 1 else: raise Exception("Unknown adapt type!") @@ -116,43 +132,43 @@ def get_fixed_temperature(temper, i, N, adapt): return temper_var_np -def get_losses(d_out_real, d_out_fake, loss_type='JS'): +def get_losses(d_out_real, d_out_fake, loss_type="JS"): """Get different adversarial losses according to given loss_type""" bce_loss = nn.BCEWithLogitsLoss() - if loss_type == 'standard': # the non-satuating GAN loss + if loss_type == "standard": # the non-satuating GAN loss d_loss_real = bce_loss(d_out_real, torch.ones_like(d_out_real)) d_loss_fake = bce_loss(d_out_fake, torch.zeros_like(d_out_fake)) d_loss = d_loss_real + d_loss_fake g_loss = bce_loss(d_out_fake, torch.ones_like(d_out_fake)) - elif loss_type == 'JS': # the vanilla GAN loss + elif loss_type == "JS": # the vanilla GAN loss d_loss_real = bce_loss(d_out_real, torch.ones_like(d_out_real)) d_loss_fake = bce_loss(d_out_fake, torch.zeros_like(d_out_fake)) d_loss = d_loss_real + d_loss_fake g_loss = -d_loss_fake - elif loss_type == 'KL': # the GAN loss implicitly minimizing KL-divergence + elif loss_type == "KL": # the GAN loss implicitly minimizing KL-divergence d_loss_real = bce_loss(d_out_real, torch.ones_like(d_out_real)) d_loss_fake = bce_loss(d_out_fake, torch.zeros_like(d_out_fake)) d_loss = d_loss_real + d_loss_fake g_loss = torch.mean(-d_out_fake) - elif loss_type == 'hinge': # the hinge loss + elif loss_type == "hinge": # the hinge loss d_loss_real = torch.mean(nn.ReLU(1.0 - d_out_real)) d_loss_fake = torch.mean(nn.ReLU(1.0 + d_out_fake)) d_loss = d_loss_real + d_loss_fake g_loss = -torch.mean(d_out_fake) - elif loss_type == 'tv': # the total variation distance + elif loss_type == "tv": # the total variation distance d_loss = torch.mean(nn.Tanh(d_out_fake) - nn.Tanh(d_out_real)) g_loss = torch.mean(-nn.Tanh(d_out_fake)) - elif loss_type == 'rsgan': # relativistic standard GAN + elif loss_type == "rsgan": # relativistic standard GAN d_loss = bce_loss(d_out_real - d_out_fake, torch.ones_like(d_out_real)) g_loss = bce_loss(d_out_fake - d_out_real, torch.ones_like(d_out_fake)) diff --git a/utils/nn_helpers.py b/utils/nn_helpers.py index ac25ccff..6dde8240 100644 --- a/utils/nn_helpers.py +++ b/utils/nn_helpers.py @@ -193,10 +193,11 @@ def __init__( nn.BatchNorm1d(2 * out_channels), nn.LeakyReLU(alpha), ) + def forward(self, x): x = torch.transpose(x, 1, 2) x, (hn, cn) = self.lstm(x) - x = torch.transpose(x, 1,2) + x = torch.transpose(x, 1, 2) x = self.layers(x) return x @@ 
-257,24 +258,26 @@ def __init__( nn.LeakyReLU(alpha), nn.Dropout(drop_rate), ) + def forward(self, x): x = torch.transpose(x, 1, 2) x, (hn, cn) = self.lstm(x) - x = torch.transpose(x, 1,2) + x = torch.transpose(x, 1, 2) x = self.layers(x) return x - def DiversityLoss(): cs2 = torch.nn.CosineSimilarity(dim=2) + def cos_sim_loss(generated): batch_size = generated.shape[0] generated = generated.repeat(batch_size, 1, 1, 1) generatedTranspose = torch.transpose(generated, 0, 1) loss = cs2(generated, generatedTranspose) ind = np.diag_indices(loss.shape[0]) - loss[ind[0], ind[1], :] = 0 # set 0 to similarity of message to itself + loss[ind[0], ind[1], :] = 0 # set 0 to similarity of message to itself loss = loss.mean(axis=2).max(axis=0).values.mean() return loss + return cos_sim_loss diff --git a/utils/rollout.py b/utils/rollout.py index 6e165b21..96c46e19 100644 --- a/utils/rollout.py +++ b/utils/rollout.py @@ -4,7 +4,7 @@ # @FileName : rollout.py # @Time : Created at 2019-03-15 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. import copy @@ -18,8 +18,8 @@ def __init__(self, gen, gpu=True): self.old_model = copy.deepcopy(gen) self.max_seq_len = gen.max_seq_len self.vocab_size = gen.vocab_size - self.step_size = gen.step_size if gen.name == 'leakgan' else 0 - self.goal_out_size = gen.goal_out_size if gen.name == 'leakgan' else 0 + self.step_size = gen.step_size if gen.name == "leakgan" else 0 + self.goal_out_size = gen.goal_out_size if gen.name == "leakgan" else 0 self.gpu = gpu def rollout_mc_search(self, sentences, given_num): @@ -73,7 +73,7 @@ def rollout_mc_search_leakgan(self, sentences, dis, given_num): for i in range(given_num): # Get feature. dis_inp = torch.zeros(batch_size, seq_len).long() - dis_inp[:, :i + 1] = sentences[:, :i + 1] # cut sentences + dis_inp[:, : i + 1] = sentences[:, : i + 1] # cut sentences leak_inp = sentences[:, i] if self.gpu: dis_inp = dis_inp.cuda() @@ -82,13 +82,14 @@ def rollout_mc_search_leakgan(self, sentences, dis, given_num): # Get output of one token # cur_goal: batch_size * 1 * goal_out_size - out, cur_goal, work_hidden, mana_hidden = self.gen(i, leak_inp, work_hidden, mana_hidden, - feature, real_goal, train=True) + out, cur_goal, work_hidden, mana_hidden = self.gen( + i, leak_inp, work_hidden, mana_hidden, feature, real_goal, train=True + ) # Save goal and update last_goal goal_array[:, i, :] = cur_goal.squeeze(1) if i > 0 and i % self.step_size == 0: - real_goal = torch.sum(goal_array[:, i - 3:i + 1, :], dim=1) + real_goal = torch.sum(goal_array[:, i - 3 : i + 1, :], dim=1) if i / self.step_size == 1: real_goal += self.gen.goal_init[:batch_size, :] @@ -98,7 +99,9 @@ def rollout_mc_search_leakgan(self, sentences, dis, given_num): # MC search for i in range(given_num, self.max_seq_len): # Sample one token - out = torch.multinomial(torch.exp(out), 1).view(-1) # [num_samples] (sampling from each row) + out = torch.multinomial(torch.exp(out), 1).view( + -1 + ) # [num_samples] (sampling from each row) samples[:, i] = out.data # Get feature @@ -110,13 +113,14 @@ def rollout_mc_search_leakgan(self, sentences, dis, given_num): # Get output of one token # cur_goal: batch_size * 1 * goal_out_size - out, cur_goal, work_hidden, mana_hidden = self.gen(i, leak_inp, work_hidden, mana_hidden, - feature, real_goal, train=True) + out, cur_goal, work_hidden, mana_hidden = self.gen( + i, leak_inp, work_hidden, mana_hidden, feature, real_goal, train=True + ) # Save goal and update last_goal goal_array[:, i, :] = 
cur_goal.squeeze(1) if i > 0 and i % self.step_size == 0: - real_goal = torch.sum(goal_array[:, i - 3:i + 1, :], dim=1) + real_goal = torch.sum(goal_array[:, i - 3 : i + 1, :], dim=1) if i / self.step_size == 1: real_goal += self.gen.goal_init[:batch_size, :] @@ -150,7 +154,9 @@ def get_reward(self, sentences, rollout_num, dis, current_k=0): idx += 1 # rewards = torch.mean(rewards, dim=0) - rewards = torch.mean(rewards.view(batch_size, self.max_seq_len, rollout_num), dim=-1) + rewards = torch.mean( + rewards.view(batch_size, self.max_seq_len, rollout_num), dim=-1 + ) return rewards def get_reward_leakgan(self, sentences, rollout_num, dis, current_k): @@ -165,7 +171,9 @@ def get_reward_leakgan(self, sentences, rollout_num, dis, current_k): """ with torch.no_grad(): batch_size = sentences.size(0) - rewards = torch.zeros([rollout_num * (self.max_seq_len // self.step_size), batch_size]).float() + rewards = torch.zeros( + [rollout_num * (self.max_seq_len // self.step_size), batch_size] + ).float() if self.gpu: rewards = rewards.cuda() idx = 0 @@ -179,7 +187,9 @@ def get_reward_leakgan(self, sentences, rollout_num, dis, current_k): rewards[idx] = reward idx += 1 - rewards = rewards.view(batch_size, self.max_seq_len // self.step_size, rollout_num) + rewards = rewards.view( + batch_size, self.max_seq_len // self.step_size, rollout_num + ) rewards = torch.mean(rewards, dim=-1) return rewards diff --git a/utils/text_process.py b/utils/text_process.py index 9d6053d3..889872fe 100644 --- a/utils/text_process.py +++ b/utils/text_process.py @@ -19,7 +19,7 @@ def text_file_iterator(file): with open(file) as raw: for line in raw.readlines(): - yield line.strip('\n').split() + yield line.strip("\n").split() def get_tokenlized(file): @@ -27,7 +27,7 @@ def get_tokenlized(file): tokenlized = list() with open(file) as raw: for text in raw: - text = text.strip('\n').lower().split() + text = text.strip("\n").lower().split() tokenlized.append(text) return tokenlized @@ -72,7 +72,9 @@ def text_process(train_text_loc, test_text_loc=None): if test_text_loc is None: sequence_len = len(max(train_tokens, key=len)) else: - sequence_len = max(len(max(train_tokens, key=len)), len(max(test_tokens, key=len))) + sequence_len = max( + len(max(train_tokens, key=len)), len(max(test_tokens, key=len)) + ) return sequence_len, len(word2idx_dict) @@ -83,29 +85,31 @@ def init_dict(dataset): Initialize dictionaries of dataset, please note that '0': padding_idx, '1': start_letter. Finally save dictionary files locally. 
""" - tokens = get_tokenlized('dataset/{}.txt'.format(dataset)) + tokens = get_tokenlized("dataset/{}.txt".format(dataset)) word_set = get_word_list(tokens) word2idx_dict, idx2word_dict = get_dict(word_set) - with open('dataset/{}_wi_dict.txt'.format(dataset), 'w') as dictout: + with open("dataset/{}_wi_dict.txt".format(dataset), "w") as dictout: dictout.write(str(word2idx_dict)) - with open('dataset/{}_iw_dict.txt'.format(dataset), 'w') as dictout: + with open("dataset/{}_iw_dict.txt".format(dataset), "w") as dictout: dictout.write(str(idx2word_dict)) - print('total tokens: ', len(word2idx_dict)) + print("total tokens: ", len(word2idx_dict)) def load_dict(dataset): """Load dictionary from local files""" - iw_path = 'dataset/{}_iw_dict.txt'.format(dataset) - wi_path = 'dataset/{}_wi_dict.txt'.format(dataset) + iw_path = "dataset/{}_iw_dict.txt".format(dataset) + wi_path = "dataset/{}_wi_dict.txt".format(dataset) - if not os.path.exists(iw_path) or not os.path.exists(iw_path): # initialize dictionaries + if not os.path.exists(iw_path) or not os.path.exists( + iw_path + ): # initialize dictionaries init_dict(dataset) - with open(iw_path, 'r') as dictin: + with open(iw_path, "r") as dictin: idx2word_dict = eval(dictin.read().strip()) - with open(wi_path, 'r') as dictin: + with open(wi_path, "r") as dictin: word2idx_dict = eval(dictin.read().strip()) return word2idx_dict, idx2word_dict @@ -115,7 +119,7 @@ def load_test_dict(dataset): """Build test data dictionary, extend from train data. For the classifier.""" word2idx_dict, idx2word_dict = load_dict(dataset) # train dict # tokens = get_tokenlized('dataset/testdata/{}_clas_test.txt'.format(dataset)) - tokens = get_tokenlized('dataset/testdata/{}_test.txt'.format(dataset)) + tokens = get_tokenlized("dataset/testdata/{}_test.txt".format(dataset)) word_set = get_word_list(tokens) index = len(word2idx_dict) # current index @@ -157,7 +161,7 @@ def tokens_to_tensor(tokens, dictionary): while i < cfg.max_seq_len - 1: sent_ten.append(cfg.padding_idx) i += 1 - tensor.append(sent_ten[:cfg.max_seq_len]) + tensor.append(sent_ten[: cfg.max_seq_len]) return torch.LongTensor(tensor) @@ -180,32 +184,32 @@ def padding_token(tokens): def write_tokens(filename, tokens): """Write word tokens to a local file (For Real data)""" - with open(filename, 'w') as fout: + with open(filename, "w") as fout: for sent in tokens: - fout.write(' '.join(sent)) - fout.write('\n') + fout.write(" ".join(sent)) + fout.write("\n") def write_tensor(filename, tensor): """Write Tensor to a local file (For Oracle data)""" - with open(filename, 'w') as fout: + with open(filename, "w") as fout: for sent in tensor: - fout.write(' '.join([str(i) for i in sent.tolist()])) - fout.write('\n') + fout.write(" ".join([str(i) for i in sent.tolist()])) + fout.write("\n") def process_cat_text(): import random - dataset = 'mr' + dataset = "mr" test_ratio = 0.3 seq_len = 15 - pos_file = 'dataset/{}/{}{}_cat1.txt'.format(dataset, dataset, seq_len) - neg_file = 'dataset/{}/{}{}_cat0.txt'.format(dataset, dataset, seq_len) - pos_sent = open(pos_file, 'r').readlines() - neg_sent = open(neg_file, 'r').readlines() + pos_file = "dataset/{}/{}{}_cat1.txt".format(dataset, dataset, seq_len) + neg_file = "dataset/{}/{}{}_cat0.txt".format(dataset, dataset, seq_len) + pos_sent = open(pos_file, "r").readlines() + neg_sent = open(neg_file, "r").readlines() pos_len = int(test_ratio * len(pos_sent)) neg_len = int(test_ratio * len(neg_sent)) @@ -218,10 +222,14 @@ def process_cat_text(): random.shuffle(all_sent_test) 
random.shuffle(all_sent_train) - f_pos_train = open('dataset/{}{}_cat1.txt'.format(dataset, seq_len), 'w') - f_neg_train = open('dataset/{}{}_cat0.txt'.format(dataset, seq_len), 'w') - f_pos_test = open('dataset/testdata/{}{}_cat1_test.txt'.format(dataset, seq_len), 'w') - f_neg_test = open('dataset/testdata/{}{}_cat0_test.txt'.format(dataset, seq_len), 'w') + f_pos_train = open("dataset/{}{}_cat1.txt".format(dataset, seq_len), "w") + f_neg_train = open("dataset/{}{}_cat0.txt".format(dataset, seq_len), "w") + f_pos_test = open( + "dataset/testdata/{}{}_cat1_test.txt".format(dataset, seq_len), "w" + ) + f_neg_test = open( + "dataset/testdata/{}{}_cat0_test.txt".format(dataset, seq_len), "w" + ) for p_s in pos_sent[:pos_len]: f_pos_test.write(p_s) @@ -232,10 +240,10 @@ def process_cat_text(): for n_s in neg_sent[neg_len:]: f_neg_train.write(n_s) - with open('dataset/testdata/{}{}_test.txt'.format(dataset, seq_len), 'w') as fout: + with open("dataset/testdata/{}{}_test.txt".format(dataset, seq_len), "w") as fout: for sent in all_sent_test: fout.write(sent) - with open('dataset/{}{}.txt'.format(dataset, seq_len), 'w') as fout: + with open("dataset/{}{}.txt".format(dataset, seq_len), "w") as fout: for sent in all_sent_train: fout.write(sent) @@ -246,20 +254,22 @@ def process_cat_text(): def combine_amazon_text(): - cat0_name = 'app' - cat1_name = 'book' - root_path = 'dataset/' - cat0_train = open(root_path + cat0_name + '.txt', 'r').readlines() - cat0_test = open(root_path + cat0_name + '_test.txt', 'r').readlines() - cat1_train = open(root_path + cat1_name + '.txt', 'r').readlines() - cat1_test = open(root_path + cat1_name + '_test.txt', 'r').readlines() - - with open(root_path + 'amazon_{}_{}.txt'.format(cat0_name, cat1_name), 'w') as fout: + cat0_name = "app" + cat1_name = "book" + root_path = "dataset/" + cat0_train = open(root_path + cat0_name + ".txt", "r").readlines() + cat0_test = open(root_path + cat0_name + "_test.txt", "r").readlines() + cat1_train = open(root_path + cat1_name + ".txt", "r").readlines() + cat1_test = open(root_path + cat1_name + "_test.txt", "r").readlines() + + with open(root_path + "amazon_{}_{}.txt".format(cat0_name, cat1_name), "w") as fout: for sent in cat0_train: fout.write(sent) for sent in cat1_train: fout.write(sent) - with open(root_path + 'testdata/amazon_{}_{}_test.txt'.format(cat0_name, cat1_name), 'w') as fout: + with open( + root_path + "testdata/amazon_{}_{}_test.txt".format(cat0_name, cat1_name), "w" + ) as fout: for sent in cat0_test: fout.write(sent) for sent in cat1_test: @@ -267,21 +277,23 @@ def combine_amazon_text(): def extend_clas_train_data(): - data_name = 'mr' - dataset = 'mr20' - neg_filter_file = 'dataset/{}/{}_cat0.txt'.format(data_name, dataset) # include train and test for generator - pos_filter_file = 'dataset/{}/{}_cat1.txt'.format(data_name, dataset) - neg_test_file = 'dataset/testdata/{}_cat0_test.txt'.format(dataset) - pos_test_file = 'dataset/testdata/{}_cat1_test.txt'.format(dataset) - neg_all_file = 'dataset/{}/{}_cat0.txt'.format(data_name, data_name) - pos_all_file = 'dataset/{}/{}_cat1.txt'.format(data_name, data_name) - - neg_filter = open(neg_filter_file, 'r').readlines() - pos_filter = open(pos_filter_file, 'r').readlines() - neg_test = open(neg_test_file, 'r').readlines() - pos_test = open(pos_test_file, 'r').readlines() - neg_all = open(neg_all_file, 'r').readlines() - pos_all = open(pos_all_file, 'r').readlines() + data_name = "mr" + dataset = "mr20" + neg_filter_file = "dataset/{}/{}_cat0.txt".format( + data_name, 
dataset + ) # include train and test for generator + pos_filter_file = "dataset/{}/{}_cat1.txt".format(data_name, dataset) + neg_test_file = "dataset/testdata/{}_cat0_test.txt".format(dataset) + pos_test_file = "dataset/testdata/{}_cat1_test.txt".format(dataset) + neg_all_file = "dataset/{}/{}_cat0.txt".format(data_name, data_name) + pos_all_file = "dataset/{}/{}_cat1.txt".format(data_name, data_name) + + neg_filter = open(neg_filter_file, "r").readlines() + pos_filter = open(pos_filter_file, "r").readlines() + neg_test = open(neg_test_file, "r").readlines() + pos_test = open(pos_test_file, "r").readlines() + neg_all = open(neg_all_file, "r").readlines() + pos_all = open(pos_all_file, "r").readlines() # print('neg filter:', len(neg_filter)) # print('neg test:', len(neg_test)) @@ -290,62 +302,67 @@ def extend_clas_train_data(): # print('pos test:', len(pos_test)) # print('pos all:', len(pos_all)) - print('neg before:', len(neg_test)) + print("neg before:", len(neg_test)) for line in neg_all: if line not in neg_filter: neg_test.append(line) - print('neg after:', len(neg_test)) + print("neg after:", len(neg_test)) - print('pos before:', len(pos_test)) + print("pos before:", len(pos_test)) for line in pos_all: if line not in pos_filter: pos_test.append(line) - print('pos after:', len(pos_test)) + print("pos after:", len(pos_test)) - with open('dataset/testdata/{}_cat0_clas_test.txt'.format(dataset), 'w') as fout: + with open("dataset/testdata/{}_cat0_clas_test.txt".format(dataset), "w") as fout: for line in neg_test: fout.write(line) - with open('dataset/testdata/{}_cat1_clas_test.txt'.format(dataset), 'w') as fout: + with open("dataset/testdata/{}_cat1_clas_test.txt".format(dataset), "w") as fout: for line in pos_test: fout.write(line) - with open('dataset/testdata/{}_clas_test.txt'.format(dataset), 'w') as fout: + with open("dataset/testdata/{}_clas_test.txt".format(dataset), "w") as fout: for line in neg_test: fout.write(line) for line in pos_test: fout.write(line) -def load_word_vec(path, word2idx_dict=None, type='glove'): +def load_word_vec(path, word2idx_dict=None, type="glove"): """Load word embedding from local file""" - fin = open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') - if type == 'glove': + fin = open(path, "r", encoding="utf-8", newline="\n", errors="ignore") + if type == "glove": word2vec_dict = {} for line in fin: tokens = line.rstrip().split() if word2idx_dict is None or tokens[0] in word2idx_dict.keys(): - word2vec_dict[tokens[0]] = np.asarray(tokens[1:], dtype='float32') - elif type == 'word2vec': + word2vec_dict[tokens[0]] = np.asarray(tokens[1:], dtype="float32") + elif type == "word2vec": import gensim - word2vec_dict = gensim.models.KeyedVectors.load_word2vec_format(path, binary=True) + + word2vec_dict = gensim.models.KeyedVectors.load_word2vec_format( + path, binary=True + ) else: - raise NotImplementedError('No such type: %s' % type) + raise NotImplementedError("No such type: %s" % type) return word2vec_dict def build_embedding_matrix(dataset): """Load or build Glove embedding matrix.""" - embed_filename = 'dataset/glove_embedding_300d_{}.pt'.format(dataset) + embed_filename = "dataset/glove_embedding_300d_{}.pt".format(dataset) if os.path.exists(embed_filename): - print('Loading embedding:', embed_filename) + print("Loading embedding:", embed_filename) embedding_matrix = torch.load(embed_filename) else: - print('Loading Glove word vectors...') + print("Loading Glove word vectors...") word2idx_dict, _ = load_dict(dataset) - embedding_matrix = 
np.random.random((len(word2idx_dict) + 2, 300)) # 2 for padding token and start token - fname = '../glove.42B.300d.txt' # Glove file + embedding_matrix = np.random.random( + (len(word2idx_dict) + 2, 300) + ) # 2 for padding token and start token + fname = "../glove.42B.300d.txt" # Glove file # fname = '../GoogleNews-vectors-negative300.bin' # Google Word2Vec file - word2vec_dict = load_word_vec(fname, word2idx_dict=word2idx_dict, type='glove') - print('Building embedding matrix:', embed_filename) + word2vec_dict = load_word_vec(fname, word2idx_dict=word2idx_dict, type="glove") + print("Building embedding matrix:", embed_filename) for word, i in word2idx_dict.items(): if word in word2vec_dict: # words not found in embedding index will be randomly initialized. @@ -354,8 +371,9 @@ def build_embedding_matrix(dataset): torch.save(embedding_matrix, embed_filename) return embedding_matrix + def pad_sequences( - sequence, w2v, target_len: int = 52, embedding_size: int = 300, padding_token = None + sequence, w2v, target_len: int = 52, embedding_size: int = 300, padding_token=None ) -> np.array: sequence = np.array(sequence) current_length = sequence.shape[0] @@ -363,11 +381,19 @@ def pad_sequences( if current_length >= target_len: return sequence[-target_len:] - padding = np.repeat(np.array([w2v.wv[padding_token]]), target_len - current_length,axis=0) if padding_token else np.zeros((target_len - current_length, embedding_size)) + padding = ( + np.repeat( + np.array([w2v.wv[padding_token]]), target_len - current_length, axis=0 + ) + if padding_token + else np.zeros((target_len - current_length, embedding_size)) + ) return np.concatenate((padding, sequence), axis=0) -def vectorize_sentence(tokens, w2v, target_len: int = 52, embedding_size: int = 300, padding_token=None): +def vectorize_sentence( + tokens, w2v, target_len: int = 52, embedding_size: int = 300, padding_token=None +): vectorized = pad_sequences( [w2v.wv[token] for token in tokens], w2v, diff --git a/utils/visualization.py b/utils/visualization.py index 10a05a83..b32ba562 100644 --- a/utils/visualization.py +++ b/utils/visualization.py @@ -4,40 +4,57 @@ # @FileName : visualization.py # @Time : Created at 2019-03-19 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import matplotlib.pyplot as plt title_dict = { - 'gen_pre_loss': 'pre_loss', - 'gen_adv_loss': 'g_loss', - 'gen_mana_loss': 'mana_loss', - 'gen_work_loss': 'work_loss', - 'dis_loss': 'd_loss', - 'dis_train_acc': 'train_acc', - 'dis_eval_acc': 'eval_acc', - 'NLL_oracle': 'NLL_oracle', - 'NLL_gen': 'NLL_gen', - 'BLEU-3': 'BLEU-3', + "gen_pre_loss": "pre_loss", + "gen_adv_loss": "g_loss", + "gen_mana_loss": "mana_loss", + "gen_work_loss": "work_loss", + "dis_loss": "d_loss", + "dis_train_acc": "train_acc", + "dis_eval_acc": "eval_acc", + "NLL_oracle": "NLL_oracle", + "NLL_gen": "NLL_gen", + "BLEU-3": "BLEU-3", } -color_list = ['#e74c3c', '#e67e22', '#f1c40f', '#8e44ad', '#2980b9', '#27ae60', '#16a085'] +color_list = [ + "#e74c3c", + "#e67e22", + "#f1c40f", + "#8e44ad", + "#2980b9", + "#27ae60", + "#16a085", +] def plt_data(data, step, title, c_id, savefig=False): x = [i for i in range(step)] plt.plot(x, data, color=color_list[c_id], label=title) if savefig: - plt.savefig('savefig/' + title + '.png') + plt.savefig("savefig/" + title + ".png") def get_log_data(filename): - with open(filename, 'r') as fin: - all_lines = fin.read().strip().split('\n') - data_dict = {'pre_loss': [], 'g_loss': [], 'mana_loss': [], 'work_loss': [], - 'd_loss': [], 'train_acc': [], 'eval_acc': [], 'NLL_oracle': [], - 'NLL_gen': [], 'BLEU-3': []} + with open(filename, "r") as fin: + all_lines = fin.read().strip().split("\n") + data_dict = { + "pre_loss": [], + "g_loss": [], + "mana_loss": [], + "work_loss": [], + "d_loss": [], + "train_acc": [], + "eval_acc": [], + "NLL_oracle": [], + "NLL_gen": [], + "BLEU-3": [], + } for line in all_lines: items = line.split() @@ -51,28 +68,33 @@ def get_log_data(filename): return data_dict -if __name__ == '__main__': - log_file_root = '../log/' +if __name__ == "__main__": + log_file_root = "../log/" # Custom your log files in lists, no more than len(color_list) - log_file_list = ['log_0604_2233', 'log_0605_0120', 'log_0531_1507'] - legend_text = ['SeqGAN', 'LeakGAN', 'RelGAN'] + log_file_list = ["log_0604_2233", "log_0605_0120", "log_0531_1507"] + legend_text = ["SeqGAN", "LeakGAN", "RelGAN"] color_id = 0 - data_name = 'NLL_oracle' + data_name = "NLL_oracle" if_save = False # legend_text = log_file_list - assert data_name in title_dict.keys(), 'Error data name' + assert data_name in title_dict.keys(), "Error data name" plt.clf() plt.title(data_name) all_data_list = [] for idx, item in enumerate(log_file_list): - log_file = log_file_root + item + '.txt' + log_file = log_file_root + item + ".txt" # save log file all_data = get_log_data(log_file) - plt_data(all_data[title_dict[data_name]], len(all_data[title_dict[data_name]]), - legend_text[idx], color_id, if_save) + plt_data( + all_data[title_dict[data_name]], + len(all_data[title_dict[data_name]]), + legend_text[idx], + color_id, + if_save, + ) color_id += 1 plt.legend() diff --git a/visual/visual_human.py b/visual/visual_human.py index 96bc6c43..73d1f8a8 100644 --- a/visual/visual_human.py +++ b/visual/visual_human.py @@ -25,48 +25,120 @@ bar_width = 0.5 opacity = 1.0 -error_config = {'ecolor': '0'} +error_config = {"ecolor": "0"} -rects1 = ax.bar(0, CSGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#8e44ad', error_kw=error_config, - label='CSGAN') +rects1 = ax.bar( + 0, + CSGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#8e44ad", + error_kw=error_config, + label="CSGAN", +) -rects2 = ax.bar(bar_width, SentiGAN, bar_width, linestyle='-', 
linewidth=1, edgecolor='black', - alpha=opacity, color='#27ae60', error_kw=error_config, - label='SentiGAN') +rects2 = ax.bar( + bar_width, + SentiGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#27ae60", + error_kw=error_config, + label="SentiGAN", +) -rects3 = ax.bar(0 + 2 * bar_width, CatGAN_m, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#d35400', error_kw=error_config, - label='CatGAN ($k=2$)') +rects3 = ax.bar( + 0 + 2 * bar_width, + CatGAN_m, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#d35400", + error_kw=error_config, + label="CatGAN ($k=2$)", +) gap = 1.2 -rects4 = ax.bar(3 * bar_width + gap, SeqGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#fd79a8', error_kw=error_config, - label='SeqGAN') +rects4 = ax.bar( + 3 * bar_width + gap, + SeqGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#fd79a8", + error_kw=error_config, + label="SeqGAN", +) -rects5 = ax.bar(4 * bar_width + gap, RankGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#34495e', error_kw=error_config, - label='RankGAN') +rects5 = ax.bar( + 4 * bar_width + gap, + RankGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#34495e", + error_kw=error_config, + label="RankGAN", +) -rects6 = ax.bar(0 + 5 * bar_width + gap, LeakGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#f1c40f', error_kw=error_config, - label='LeakGAN') -rects7 = ax.bar(6 * bar_width + gap, RelGAN, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#2980b9', error_kw=error_config, - label='RelGAN') -rects8 = ax.bar(7 * bar_width + gap, CatGAN_s, bar_width, linestyle='-', linewidth=1, edgecolor='black', - alpha=opacity, color='#c0392b', error_kw=error_config, - label='CatGAN ($k=1$)') +rects6 = ax.bar( + 0 + 5 * bar_width + gap, + LeakGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#f1c40f", + error_kw=error_config, + label="LeakGAN", +) +rects7 = ax.bar( + 6 * bar_width + gap, + RelGAN, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#2980b9", + error_kw=error_config, + label="RelGAN", +) +rects8 = ax.bar( + 7 * bar_width + gap, + CatGAN_s, + bar_width, + linestyle="-", + linewidth=1, + edgecolor="black", + alpha=opacity, + color="#c0392b", + error_kw=error_config, + label="CatGAN ($k=1$)", +) -ax.set_xlabel('Dataset') -ax.set_ylabel('Human Score') +ax.set_xlabel("Dataset") +ax.set_ylabel("Human Score") # ax.set_title('Scores by group and gender') len = ((0 + 3 * bar_width) / 3, 3 * bar_width + gap + 2 * bar_width) ax.set_xticks(len) -ax.set_xticklabels(('AR', 'EN')) +ax.set_xticklabels(("AR", "EN")) ax.legend(bbox_to_anchor=(1, 0), loc=3, borderaxespad=0.2) # plt.legend() fig.tight_layout() -plt.savefig('savefig/human.pdf') +plt.savefig("savefig/human.pdf") plt.show() # plt.savefig('C:/1123.pdf') diff --git a/visual/visual_metric.py b/visual/visual_metric.py index c4ca3eed..08107741 100644 --- a/visual/visual_metric.py +++ b/visual/visual_metric.py @@ -4,13 +4,13 @@ # @FileName : visual_metric.py # @Time : Created at 2019-11-26 # @Blog : http://zhiweil.ml/ -# @Description : +# @Description : # Copyrights (C) 2018. All Rights Reserved. 
import matplotlib.pyplot as plt import numpy as np -color_list = ['#2980b9', '#e74c3c', '#1abc9c', '#9b59b6'] +color_list = ["#2980b9", "#e74c3c", "#1abc9c", "#9b59b6"] def plt_x_y_data(x, y, title, c_id): @@ -18,9 +18,9 @@ def plt_x_y_data(x, y, title, c_id): def get_log_data(filename): - with open(filename, 'r') as fin: - all_lines = fin.read().strip().split('\n') - data_dict = {'NLL_oracle': [], 'NLL_gen': [], 'NLL_div': []} + with open(filename, "r") as fin: + all_lines = fin.read().strip().split("\n") + data_dict = {"NLL_oracle": [], "NLL_gen": [], "NLL_div": []} for line in all_lines: items = line.split() @@ -34,14 +34,14 @@ def get_log_data(filename): return data_dict -if __name__ == '__main__': - log_file_root = 'log/' +if __name__ == "__main__": + log_file_root = "log/" # Custom your log files in lists, no more than len(color_list) - log_file_list = ['jsdgan_vanilla_oracle', 'catgan_vanilla_oracle'] - legend_text = ['JSDGAN', 'CatGAN'] + log_file_list = ["jsdgan_vanilla_oracle", "catgan_vanilla_oracle"] + legend_text = ["JSDGAN", "CatGAN"] color_id = 0 - title = 'Synthetic data' + title = "Synthetic data" if_save = True length = 100 @@ -49,19 +49,23 @@ def get_log_data(filename): plt.title(title) all_data_list = [] for idx, item in enumerate(log_file_list): - log_file = log_file_root + item + '.txt' + log_file = log_file_root + item + ".txt" # save log file all_data = get_log_data(log_file) - idxs = np.argsort(-np.array(all_data['NLL_oracle'])) - plt_x_y_data(np.array(all_data['NLL_oracle'])[idxs][:length], np.array(all_data['NLL_div'])[idxs][:length], - legend_text[idx], color_id) + idxs = np.argsort(-np.array(all_data["NLL_oracle"])) + plt_x_y_data( + np.array(all_data["NLL_oracle"])[idxs][:length], + np.array(all_data["NLL_div"])[idxs][:length], + legend_text[idx], + color_id, + ) color_id += 1 plt.legend() # plt.tight_layout() - plt.xlabel(r'${\rm NLL_{\rm oracle}}$') - plt.ylabel(r'${\rm NLL_{\rm div}}$') + plt.xlabel(r"${\rm NLL_{\rm oracle}}$") + plt.ylabel(r"${\rm NLL_{\rm div}}$") if if_save: - plt.savefig('../savefig/synthetic_oracle_div.png') + plt.savefig("../savefig/synthetic_oracle_div.png") plt.show() diff --git a/visual/visual_temp_appendix.py b/visual/visual_temp_appendix.py index bea9d1ae..deb553eb 100644 --- a/visual/visual_temp_appendix.py +++ b/visual/visual_temp_appendix.py @@ -5,39 +5,39 @@ import os title_dict = { - 'NLL_oracle': 'NLL_oracle', - 'NLL_gen': 'NLL_gen', - 'NLL_div': 'NLL_div', - 'nll_oracle': 'nll_oracle', - 'nll_div': 'nll_div', - 'temp': 'temp', + "NLL_oracle": "NLL_oracle", + "NLL_gen": "NLL_gen", + "NLL_div": "NLL_div", + "nll_oracle": "nll_oracle", + "nll_div": "nll_div", + "temp": "temp", } -color_list = ['#2980b9', '#e74c3c', '#1abc9c', '#9b59b6'] -ls_list = ['--', '-'] +color_list = ["#2980b9", "#e74c3c", "#1abc9c", "#9b59b6"] +ls_list = ["--", "-"] marker_list = [None, None] def plt_data(data, length, title, c_id, ls, marker, start=0): x = np.arange(start, start + length, 1) - data = data[start:start + length] + data = data[start : start + length] plt.plot(x, data, color=color_list[c_id], label=title, lw=1.0, ls=ls, marker=marker) if length < 100: plt.xticks(np.arange(start, start + length + 1, 5)) def get_log_data(filename): - with open(filename, 'r') as fin: - all_lines = fin.read().strip().split('\n') - data_dict = {'NLL_oracle': [], 'NLL_gen': [], 'NLL_div': [], 'temp': []} + with open(filename, "r") as fin: + all_lines = fin.read().strip().split("\n") + data_dict = {"NLL_oracle": [], "NLL_gen": [], "NLL_div": [], "temp": []} for 
line in all_lines: items = line.split() try: for key in data_dict.keys(): - if '>>>' not in items and key in items: + if ">>>" not in items and key in items: target = items[items.index(key) + 2] - if ',' in target: + if "," in target: target = target[:-1] data_dict[key].append(float(target)) except: @@ -46,13 +46,13 @@ def get_log_data(filename): return data_dict -if __name__ == '__main__': - os.chdir('..') - log_file_root = 'savefig/figure_log/' - log_file_list = ['exp_temp5', 'evo_temp5_nll'] - legend_text = ['Exponential temperature', 'Evolutionary temperature'] +if __name__ == "__main__": + os.chdir("..") + log_file_root = "savefig/figure_log/" + log_file_list = ["exp_temp5", "evo_temp5_nll"] + legend_text = ["Exponential temperature", "Evolutionary temperature"] - data_name = 'temp' + data_name = "temp" if_save = True color_id = 0 all_data_list = [] @@ -62,23 +62,30 @@ def get_log_data(filename): plt.clf() if length < 100: plt.figure(figsize=(4, 3)) - assert data_name in title_dict.keys(), 'Error data name' + assert data_name in title_dict.keys(), "Error data name" plt.xticks(fontsize=7) plt.yticks(fontsize=7) for idx, item in enumerate(log_file_list): - log_file = log_file_root + item + '.txt' + log_file = log_file_root + item + ".txt" # save log file all_data = get_log_data(log_file) - plt_data(all_data[title_dict[data_name]], length, legend_text[idx], color_id, start=start, ls=ls_list[idx], - marker=marker_list[idx]) + plt_data( + all_data[title_dict[data_name]], + length, + legend_text[idx], + color_id, + start=start, + ls=ls_list[idx], + marker=marker_list[idx], + ) color_id += 1 if length > 100: - plt.legend(prop={'size': 7}) + plt.legend(prop={"size": 7}) plt.xlabel(r"training iterations", fontsize=7) plt.ylabel(r"temperature", fontsize=7) plt.tight_layout() if if_save: - plt.savefig('savefig/temp_curve_{}.pdf'.format(length)) + plt.savefig("savefig/temp_curve_{}.pdf".format(length)) plt.show() diff --git a/visual/visual_temp_compare.py b/visual/visual_temp_compare.py index a96ca64a..83b2171a 100644 --- a/visual/visual_temp_compare.py +++ b/visual/visual_temp_compare.py @@ -4,15 +4,15 @@ import numpy as np title_dict = { - 'NLL_oracle': 'NLL_oracle', - 'NLL_gen': 'NLL_gen', - 'NLL_div': 'NLL_div', - 'nll_oracle': 'nll_oracle', - 'nll_div': 'nll_div', - 'temp': 'temp', + "NLL_oracle": "NLL_oracle", + "NLL_gen": "NLL_gen", + "NLL_div": "NLL_div", + "nll_oracle": "nll_oracle", + "nll_div": "nll_div", + "temp": "temp", } -color_list = ['#e74c3c', '#f1c40f', '#1abc9c', '#9b59b6'] +color_list = ["#e74c3c", "#f1c40f", "#1abc9c", "#9b59b6"] def plt_data(data, title, c_id): @@ -25,17 +25,17 @@ def plt_data(data, title, c_id): def get_log_data(filename): - with open(filename, 'r') as fin: - all_lines = fin.read().strip().split('\n') - data_dict = {'NLL_oracle': [], 'NLL_gen': [], 'NLL_div': [], 'temp': []} + with open(filename, "r") as fin: + all_lines = fin.read().strip().split("\n") + data_dict = {"NLL_oracle": [], "NLL_gen": [], "NLL_div": [], "temp": []} for line in all_lines: items = line.split() try: for key in data_dict.keys(): - if '>>>' not in items and key in items: + if ">>>" not in items and key in items: target = items[items.index(key) + 2] - if ',' in target: + if "," in target: target = target[:-1] data_dict[key].append(float(target)) except: @@ -44,40 +44,49 @@ def get_log_data(filename): return data_dict -if __name__ == '__main__': +if __name__ == "__main__": # log_file_root = '../log/' - log_file_root = 'savefig/figure_log/' - log_file_list = ['catgan_temp1_final', 
'catgan_temp5_final', 'relgan_temp1_final', 'relgan_temp5_final'] - legend_text = [r'CatGAN ($\tau_{\rm{tar}}$=1)', r'CatGAN ($\tau_{\rm{tar}}$=5)', r'RelGAN ($\tau_{\rm{tar}}$=1)', - r'RelGAN ($\tau_{\rm{tar}}$=5)'] - - data_name_list = ['NLL_oracle', 'NLL_div'] + log_file_root = "savefig/figure_log/" + log_file_list = [ + "catgan_temp1_final", + "catgan_temp5_final", + "relgan_temp1_final", + "relgan_temp5_final", + ] + legend_text = [ + r"CatGAN ($\tau_{\rm{tar}}$=1)", + r"CatGAN ($\tau_{\rm{tar}}$=5)", + r"RelGAN ($\tau_{\rm{tar}}$=1)", + r"RelGAN ($\tau_{\rm{tar}}$=5)", + ] + + data_name_list = ["NLL_oracle", "NLL_div"] if_save = True plt.clf() plt.figure(figsize=(8, 3.5)) for cur_id, data_name in enumerate(data_name_list): - assert data_name in title_dict.keys(), 'Error data name' + assert data_name in title_dict.keys(), "Error data name" plt.subplot(12 * 10 + cur_id + 1) if cur_id == 0: # plt.title(r"$\rm{NLL}_{\rm{oracle}}$") plt.ylabel(r"$\rm{NLL}_{\rm{oracle}}$", fontsize=12) - plt.plot([150, 150], [8.3, 9.4], 'k--') + plt.plot([150, 150], [8.3, 9.4], "k--") else: # plt.title(r"$\rm{NLL}_{\rm{div}}$") plt.ylabel(r"$\rm{NLL}_{\rm{div}}$", fontsize=12) - plt.plot([150, 150], [3.3, 5], 'k--') + plt.plot([150, 150], [3.3, 5], "k--") plt.xlabel("training iterations", fontsize=12) color_id = 0 all_data_list = [] for idx, item in enumerate(log_file_list): - log_file = log_file_root + item + '.txt' + log_file = log_file_root + item + ".txt" # save log file all_data = get_log_data(log_file) - if 'catgan' in log_file or 'relgan' in log_file: + if "catgan" in log_file or "relgan" in log_file: temp = all_data[title_dict[data_name]] last = list(np.array(temp)[range(15, 108, 2)]) res = temp[:15] + last @@ -89,5 +98,5 @@ def get_log_data(filename): plt.legend() plt.tight_layout() if if_save: - plt.savefig('savefig/temp_figure.pdf') + plt.savefig("savefig/temp_figure.pdf") plt.show() From 0c413144eed9ec32a93feefe7345df32872e77bb Mon Sep 17 00:00:00 2001 From: salaxieb Date: Thu, 6 Jul 2023 14:59:05 +0100 Subject: [PATCH 81/81] updated README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7513d7e4..9d1c179c 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ To install, run `pip install -r requirements.txt`. In case of CUDA problems, con ### General Text Generation -- **FixemGAN** - [FixemGAN: Continious Space Text GAN on Fixed Embeddings](https://www.com) +- **FixemGAN** - [FixemGAN: Continious Space Text GAN on Fixed Embeddings](https://medium.com/@salaxieb.ildar/text-gan-on-embeddings-debb9a006fff) - **SeqGAN** - [SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient](https://arxiv.org/abs/1609.05473) - **LeakGAN** - [Long Text Generation via Adversarial Training with Leaked Information](https://arxiv.org/abs/1709.08624) - **MaliGAN** - [Maximum-Likelihood Augmented Discrete Generative Adversarial Networks](https://arxiv.org/abs/1702.07983)
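
For reference, the temperature-control and loss hunks earlier in this series are quote- and line-wrapping rewrites that keep behaviour unchanged; the sketch below exercises the annealing schedules on their own. It assumes the helper is importable from `utils/helpers.py` and keeps the `get_fixed_temperature(temper, i, N, adapt)` signature shown in the hunk context (the diff header naming that file falls outside this excerpt), so treat the import path, constant names, and sampling stride as illustrative assumptions rather than part of any patch.

```python
# Illustrative sketch only; not part of any patch in this series.
# Assumption: get_fixed_temperature lives in utils/helpers.py (the file
# header for that hunk is not shown above) and returns temper_var_np.
import numpy as np
from utils.helpers import get_fixed_temperature

TARGET_TEMP = 5.0   # the "temper" argument: temperature to anneal towards
TOTAL_ITERS = 5000  # note the function internally fixes N = 5000 as well

for adapt in ("no", "lin", "exp", "log", "sigmoid", "quad", "sqrt"):
    # sample each schedule every 1000 adversarial iterations
    curve = [get_fixed_temperature(TARGET_TEMP, i, TOTAL_ITERS, adapt)
             for i in range(0, TOTAL_ITERS + 1, 1000)]
    print(f"{adapt:>7}: {np.round(curve, 2)}")
```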