diff --git a/brainstorm/training/steppers.py b/brainstorm/training/steppers.py
index 286e547..13070ed 100644
--- a/brainstorm/training/steppers.py
+++ b/brainstorm/training/steppers.py
@@ -143,3 +143,66 @@ def run(self):
         self.net.handler.mult_add_st(-learning_rate,
                                      self.net.buffer.gradients,
                                      self.net.buffer.parameters)
+
+
+class AdamStepper(TrainingStepper):
+    """
+    Adam optimizer.
+
+    The decay rate lamb (lambda) decays beta1 towards zero, slowly reducing
+    the momentum term over time. For details see "Adam: A Method for
+    Stochastic Optimization" by Kingma and Ba.
+    """
+    __undescribed__ = {'m_0', 'v_0'}
+
+    def __init__(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
+                 lamb=1 - 1e-8):
+        super(AdamStepper, self).__init__()
+        self.m_0 = None
+        self.v_0 = None
+        self.time_step = None
+        self.alpha = alpha
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.lamb = lamb
+
+    def start(self, net):
+        super(AdamStepper, self).start(net)
+        self.m_0 = net.handler.zeros(net.buffer.parameters.shape)
+        self.v_0 = net.handler.zeros(net.buffer.parameters.shape)
+        self.time_step = 0
+
+    def run(self):
+        self.time_step += 1
+        self.beta1 *= self.lamb  # decay beta1 by lambda each step
+        t = self.time_step
+        learning_rate = self.alpha
+        self.net.forward_pass(training_pass=True)
+        self.net.backward_pass()
+
+        gradient = self.net.buffer.gradients
+        temp = self.net.handler.allocate(gradient.shape)
+        temp_m0 = self.net.handler.allocate(self.m_0.shape)
+        temp_v0 = self.net.handler.allocate(self.v_0.shape)
+
+        # m_t <- beta1 * m_{t-1} + (1 - beta1) * gradient
+        self.net.handler.mult_st(self.beta1, self.m_0, out=self.m_0)
+        self.net.handler.mult_add_st(1.0 - self.beta1, gradient, out=self.m_0)
+        # v_t <- beta2 * v_{t-1} + (1 - beta2) * gradient^2
+        self.net.handler.mult_st(self.beta2, self.v_0, out=self.v_0)
+        self.net.handler.mult_tt(gradient, gradient, temp)  # gradient^2
+        self.net.handler.mult_add_st(1.0 - self.beta2, temp, out=self.v_0)
+        # m_hat_t <- m_t / (1 - beta1^t)  (bias correction)
+        self.net.handler.mult_st(1.0 / (1.0 - pow(self.beta1, t)), self.m_0,
+                                 out=temp_m0)
+        # v_hat_t <- v_t / (1 - beta2^t)  (bias correction)
+        self.net.handler.mult_st(1.0 / (1.0 - pow(self.beta2, t)), self.v_0,
+                                 out=temp_v0)
+        # parameters <- parameters - alpha * m_hat_t / (sqrt(v_hat_t) + eps)
+        self.net.handler.sqrt_t(temp_v0, temp_v0)
+        self.net.handler.add_st(self.epsilon, temp_v0, out=temp_v0)
+        self.net.handler.mult_st(learning_rate, temp_m0, out=temp_m0)
+        self.net.handler.divide_tt(temp_m0, temp_v0, temp)
+        self.net.handler.subtract_tt(self.net.buffer.parameters, temp,
+                                     out=self.net.buffer.parameters)
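
For reference, one step of Adam as given in Kingma and Ba (Algorithm 1), which the handler calls above implement term by term (g_t is the gradient, theta the parameters):

```latex
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1-\beta_1^t), \qquad \hat{v}_t = v_t / (1-\beta_2^t) \\
\theta_t &= \theta_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)
\end{aligned}
```

One deviation worth noting: because `run` decays `self.beta1` in place, the decayed value is also used in the `1 - beta1^t` bias-correction term, whereas the paper corrects with the undecayed beta1.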
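
A minimal NumPy sketch of the same update sequence may help when cross-checking the handler calls; `adam_step` and its signature are illustrative only, not part of brainstorm's API:

```python
import numpy as np

def adam_step(params, grad, m, v, t, alpha=0.001, beta1=0.9,
              beta2=0.999, epsilon=1e-8):
    """One Adam update mirroring AdamStepper.run (illustrative sketch)."""
    m = beta1 * m + (1.0 - beta1) * grad        # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad ** 2   # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)              # bias-corrected moments
    v_hat = v / (1.0 - beta2 ** t)
    params = params - alpha * m_hat / (np.sqrt(v_hat) + epsilon)
    return params, m, v

# Toy usage: minimize ||theta||^2 for a few steps (gradient is 2*theta).
theta = np.array([1.0, -2.0])
m, v = np.zeros_like(theta), np.zeros_like(theta)
for t in range(1, 4):
    theta, m, v = adam_step(theta, 2.0 * theta, m, v, t)
```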
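
If the new stepper is wired up like the existing ones, trainer-side usage would presumably look as follows; this assumes `AdamStepper` is exposed under `bs.training` and that `bs.Trainer` takes a stepper as its first argument, as with `SgdStepper` and `MomentumStepper`:

```python
import brainstorm as bs

# Assumed wiring: Trainer(stepper), matching the existing steppers.
trainer = bs.Trainer(bs.training.AdamStepper(alpha=0.001))
```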