diff --git a/src/main/scala/BIDMach/models/FM.scala b/src/main/scala/BIDMach/models/FM.scala index 2be63c6a..ef749868 100755 --- a/src/main/scala/BIDMach/models/FM.scala +++ b/src/main/scala/BIDMach/models/FM.scala @@ -174,6 +174,23 @@ object FM { } def learner(mat0:Mat, targ:Mat):(Learner, LearnOptions) = learner(mat0, targ, 0) + + + def predictor(model:Model, mat1:Mat, preds:Mat, d:Int):(Learner, LearnOptions) = { + val nopts = new LearnOptions; + nopts.batchSize = math.min(10000, mat1.ncols/30 + 1) + if (nopts.links == null) nopts.links = izeros(preds.nrows,1) + nopts.links.set(d) + nopts.putBack = 1 + val nn = new Learner( + new MatDS(Array(mat1, preds), nopts), + model.asInstanceOf[FM], + null, + null, + nopts) + (nn, nopts) +} + def learnBatch(mat0:Mat, d:Int) = { val opts = new LearnOptions