From 2078c667a40e1814df206d72bedfb1980afa4a8d Mon Sep 17 00:00:00 2001
From: tomjur
Date: Wed, 8 Nov 2017 10:25:15 +0200
Subject: [PATCH] different learn rates for each component

---
 config/gan.yml            | 5 +++--
 v1_embedding/gan_model.py | 9 +++++----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/config/gan.yml b/config/gan.yml
index 631c936..f8d209b 100644
--- a/config/gan.yml
+++ b/config/gan.yml
@@ -13,7 +13,8 @@ model:
 #  encoder_hidden_states: [100]
   decoder_hidden_states: [1000, 1500]
 #  decoder_hidden_states: [50]
-  learn_rate: 0.0035
+  generator_learn_rate: 0.0035
+  discriminator_learn_rate: 0.0035
 #  optimizer: 'gd'
 #  optimizer: 'adam'
   optimizer: 'rmsp'
@@ -60,6 +61,6 @@ embedding:
 sentence:
   limit: 300000
 #  limit: 150
-  min_length: 6
+  min_length: 15
 #  min_length: 3
   max_length: 15
diff --git a/v1_embedding/gan_model.py b/v1_embedding/gan_model.py
index 4b54a21..6fd1b03 100644
--- a/v1_embedding/gan_model.py
+++ b/v1_embedding/gan_model.py
@@ -163,8 +163,9 @@ def _init_discriminator(self):
             self.config['discriminator_content']['hidden_states'],
             self.discriminator_dropout_placeholder)
 
-    def _get_optimizer(self):
-        learn_rate = self.config['model']['learn_rate']
+    def _get_optimizer(self, is_generator):
+        learn_rate = self.config['model']['generator_learn_rate'] if is_generator else self.config['model'][
+            'discriminator_learn_rate']
         if self.config['model']['optimizer'] == 'gd':
             return tf.train.GradientDescentOptimizer(learn_rate)
         if self.config['model']['optimizer'] == 'adam':
@@ -215,7 +216,7 @@
     def _get_discriminator_train_step(self):
         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
         with tf.variable_scope('TrainDiscriminatorSteps'):
-            discriminator_optimizer = self._get_optimizer()
+            discriminator_optimizer = self._get_optimizer(False)
             discriminator_var_list = self.discriminator.get_trainable_parameters()
 
             discriminator_grads_and_vars = discriminator_optimizer.compute_gradients(
@@ -233,7 +234,7 @@
     def _get_generator_train_step(self):
         with tf.variable_scope('TrainGeneratorSteps'):
-            generator_optimizer = self._get_optimizer()
+            generator_optimizer = self._get_optimizer(True)
             generator_var_list = self.encoder.get_trainable_parameters() + self.decoder.get_trainable_parameters() + \
                 self.embedding_container.get_trainable_parameters()
 
             generator_grads_and_vars = generator_optimizer.compute_gradients(
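For context, a minimal standalone sketch (TensorFlow 1.x) of the behaviour this patch introduces: each component builds its optimizer from its own learn-rate key instead of the shared `learn_rate`. The `config` dict and `get_optimizer` helper below are illustrative stand-ins for `config/gan.yml` and `GanModel._get_optimizer`, not part of the patch, and the `'rmsp'` branch is assumed to map to `tf.train.RMSPropOptimizer` since the hunk is truncated before that branch.

```python
import tensorflow as tf

# Illustrative stand-in for config/gan.yml after this patch.
config = {
    'model': {
        'generator_learn_rate': 0.0035,
        'discriminator_learn_rate': 0.0035,
        'optimizer': 'rmsp',
    }
}

def get_optimizer(config, is_generator):
    # Pick the per-component rate first, then dispatch on optimizer type,
    # mirroring the patched GanModel._get_optimizer.
    learn_rate = (config['model']['generator_learn_rate'] if is_generator
                  else config['model']['discriminator_learn_rate'])
    if config['model']['optimizer'] == 'gd':
        return tf.train.GradientDescentOptimizer(learn_rate)
    if config['model']['optimizer'] == 'adam':
        return tf.train.AdamOptimizer(learn_rate)
    # Assumed mapping for 'rmsp' (not shown in the truncated hunk).
    return tf.train.RMSPropOptimizer(learn_rate)

generator_optimizer = get_optimizer(config, True)       # generator rate
discriminator_optimizer = get_optimizer(config, False)  # discriminator rate
```

Splitting the rate this way allows the discriminator to be trained slower or faster than the generator, a common knob for stabilizing GAN training; both keys start at the old shared value of 0.0035, so the optimizer behaviour is unchanged until they are tuned apart.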