brainstorm/training/steppers.py (59 additions, 0 deletions)
@@ -143,3 +143,62 @@ def run(self):
        self.net.handler.mult_add_st(-learning_rate,
                                     self.net.buffer.gradients,
                                     self.net.buffer.parameters)


class AdamStepper(TrainingStepper):
    """
    Adam optimizer.

    The decay rate lamb (lambda) decays beta1 over time, slowly reducing
    the contribution of the momentum term. For details see "Adam: A Method
    for Stochastic Optimization" by Kingma and Ba.
    """
    __undescribed__ = {'m_0', 'v_0'}

    def __init__(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
                 lamb=1 - 1e-8):
        super(AdamStepper, self).__init__()
        self.m_0 = None
        self.v_0 = None
        self.time_step = None
        self.alpha = alpha
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.lamb = lamb

    def start(self, net):
        super(AdamStepper, self).start(net)
        # running estimates of the 1st and 2nd gradient moments
        # (m_t and v_t in the paper)
        self.m_0 = net.handler.zeros(net.buffer.parameters.shape)
        self.v_0 = net.handler.zeros(net.buffer.parameters.shape)
        self.time_step = 0

    def run(self):
        self.time_step += 1
        t = self.time_step
        # lambda decays the momentum coefficient without mutating self.beta1:
        # beta1_t = beta1 * lambda^(t-1)
        beta1_t = self.beta1 * self.lamb ** (t - 1)
        self.net.forward_pass(training_pass=True)
        self.net.backward_pass()

        gradient = self.net.buffer.gradients
        temp = self.net.handler.allocate(gradient.shape)
        temp_m0 = self.net.handler.allocate(self.m_0.shape)
        temp_v0 = self.net.handler.allocate(self.v_0.shape)

        # m_t <- beta1_t * m_{t-1} + (1 - beta1_t) * gradient
        self.net.handler.mult_st(beta1_t, self.m_0, out=self.m_0)
        self.net.handler.mult_add_st(1.0 - beta1_t, gradient, out=self.m_0)
        # v_t <- beta2 * v_{t-1} + (1 - beta2) * gradient^2
        self.net.handler.mult_st(self.beta2, self.v_0, out=self.v_0)
        self.net.handler.mult_tt(gradient, gradient, temp)  # gradient^2
        self.net.handler.mult_add_st(1.0 - self.beta2, temp, out=self.v_0)
        # bias-corrected estimates; as in the paper, the correction uses the
        # undecayed beta1:
        # m_hat_t <- m_t / (1 - beta1^t)
        self.net.handler.mult_st(1.0 / (1.0 - self.beta1 ** t), self.m_0,
                                 out=temp_m0)
        # v_hat_t <- v_t / (1 - beta2^t)
        self.net.handler.mult_st(1.0 / (1.0 - self.beta2 ** t), self.v_0,
                                 out=temp_v0)

        # theta_t <- theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
        self.net.handler.sqrt_t(temp_v0, temp_v0)
        self.net.handler.add_st(self.epsilon, temp_v0, out=temp_v0)
        self.net.handler.mult_st(self.alpha, temp_m0, out=temp_m0)
        self.net.handler.divide_tt(temp_m0, temp_v0, temp)
        self.net.handler.subtract_tt(self.net.buffer.parameters, temp,
                                     out=self.net.buffer.parameters)
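
As a sanity check of the update rule above, here is a minimal, self-contained NumPy sketch of one Adam step as described in the paper. It is illustrative only: the function name `adam_update` and the plain-array state are assumptions made for this example, not part of brainstorm's handler API.

```python
import numpy as np


def adam_update(params, grad, m, v, t, alpha=0.001, beta1=0.9,
                beta2=0.999, epsilon=1e-8, lamb=1 - 1e-8):
    """One Adam step on plain NumPy arrays (illustrative sketch)."""
    beta1_t = beta1 * lamb ** (t - 1)           # decayed momentum coefficient
    m = beta1_t * m + (1.0 - beta1_t) * grad    # 1st moment estimate
    v = beta2 * v + (1.0 - beta2) * grad ** 2   # 2nd moment estimate
    m_hat = m / (1.0 - beta1 ** t)              # bias correction (undecayed beta1)
    v_hat = v / (1.0 - beta2 ** t)
    params = params - alpha * m_hat / (np.sqrt(v_hat) + epsilon)
    return params, m, v


# Example: a few steps on a toy quadratic loss f(w) = 0.5 * ||w||^2,
# whose gradient is simply w.
w = np.ones(4)
m = np.zeros_like(w)
v = np.zeros_like(w)
for t in range(1, 4):
    w, m, v = adam_update(w, grad=w, m=m, v=v, t=t)
```

Feeding the same per-step gradients to `AdamStepper` and to this sketch should produce matching parameter updates up to floating-point error.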