diff --git a/src/main/scala/BIDMach/Learner.scala b/src/main/scala/BIDMach/Learner.scala index 5648164d..93a85ac3 100755 --- a/src/main/scala/BIDMach/Learner.scala +++ b/src/main/scala/BIDMach/Learner.scala @@ -135,6 +135,9 @@ case class Learner( if (opts.updateAll) { model.dobatchg(mats, ipass, here); if (mixins != null) mixins map (_ compute(mats, here)); + if (opts.observer != null) { + opts.observer.notify(ipass, model, mats) + } if (updater != null) updater.update(ipass, here, gprogress); } val scores = model.evalbatchg(mats, ipass, here); @@ -144,6 +147,9 @@ case class Learner( } else { model.dobatchg(mats, ipass, here) if (mixins != null) mixins map (_ compute(mats, here)) + if (opts.observer != null) { + opts.observer.notify(ipass, model, mats) + } if (updater != null) updater.update(ipass, here, gprogress) } if (datasource.opts.putBack >= 0) datasource.putBack(mats, datasource.opts.putBack) @@ -814,6 +820,11 @@ class ParLearnerF( } object Learner { + trait LearnerObserver { + def init = {} + def cleanup = {} + def notify(ipass:Int, model:Model, minibatch:Array[Mat]) = {} + } class Options extends BIDMat.Opts { var npasses = 2; @@ -827,6 +838,7 @@ object Learner { var cumScore = 0; var checkPointFile:String = null; var checkPointInterval = 0f; + var observer: LearnerObserver = null; } def numBytes(mat:Mat):Long = {