From 539e40ef19050d0ab052a0b7035e48539822e342 Mon Sep 17 00:00:00 2001 From: keorn Date: Mon, 5 Oct 2015 14:48:37 +0100 Subject: [PATCH 1/7] functions for continous labels, and generalised label translation --- src/synaptic/datasets.clj | 38 ++++++++++++++++----------------- src/synaptic/net.clj | 2 +- src/synaptic/util.clj | 21 ++++++++++++++++-- test/synaptic/datasets_test.clj | 12 +++++------ test/synaptic/util_test.clj | 9 +++++++- 5 files changed, 53 insertions(+), 29 deletions(-) diff --git a/src/synaptic/datasets.clj b/src/synaptic/datasets.clj index 0fd3ace..8e59c5d 100644 --- a/src/synaptic/datasets.clj +++ b/src/synaptic/datasets.clj @@ -2,7 +2,8 @@ ^{:doc "synaptic - data sets" :author "Antoine Choppin"} synaptic.datasets - (:require [clatrix.core :as m] + (:require [clojure.set :refer [rename-keys]] + [clatrix.core :as m] [synaptic.util :as u]) (:gen-class)) @@ -80,10 +81,8 @@ (defn count-labels "Create a map with number of occurrence of each label." - [uniquelabels binlabels] - (let [binlb2cnt (reduce (fn [lbmap lb] (assoc lbmap lb (inc (get lbmap lb 0)))) - {} binlabels)] - (zipmap (u/frombinary uniquelabels (keys binlb2cnt)) (vals binlb2cnt)))) + [uniquelabelmap encodedlabels] + (rename-keys (frequencies encodedlabels) uniquelabelmap)) (defn training-set "Create a training set from samples and associated labels. @@ -91,22 +90,23 @@ It also has a map that will allow converting y's back to the original labels. Options: - :name - a name for the training set - :type - the type of training data (e.g. :binary-image, :grayscale-image ...) - :fieldsize - [width height] of each sample data (for images) - :nvalid - size of the validation set (default is 0, i.e. no validation set) - :batch - size of a mini-batch (default is the number of samples, after - having set apart the validation set) - :online true - set this flag for online training (same as batch size = 1) - :rand false - unset this flag to keep original ordering (by default, samples - will be shuffled before partitioning)." + :name - a name for the training set + :type - the type of training data (e.g. :binary-image, :grayscale-image ...) + :continous true - set this flag to use continous labels (auto-scaled to between 0 and 1) + :fieldsize - [width height] of each sample data (for images) + :nvalid - size of the validation set (default is 0, i.e. no validation set) + :batch - size of a mini-batch (default is the number of samples, after + having set apart the validation set) + :online true - set this flag for online training (same as batch size = 1) + :rand false - unset this flag to keep original ordering (by default, samples + will be shuffled before partitioning)." [samples labels & [options]] {:pre [(= (count samples) (count labels))]} (let [batchsize (if (:online options) 1 (:batch options)) trainsize (if (:nvalid options) (- (count samples) (:nvalid options))) randomize (if (nil? (:rand options)) true (:rand options)) - [binlb uniquelb] (u/tobinary labels) - [smp lb] (if randomize (shuffle-vecs samples binlb) [samples binlb]) + [reglb uniquelbmap] (if (:continous options) (u/tocontinous labels) (u/tobinary labels)) + [smp lb] (if randomize (shuffle-vecs samples reglb) [samples reglb]) [trainsmp validsmp] (if trainsize (split-at trainsize smp) [smp nil]) [trainlb validlb] (if trainsize (split-at trainsize lb) [lb nil]) [batchsmp batchlb] (partition-vecs batchsize trainsmp trainlb) @@ -118,9 +118,9 @@ :type (:type options) :fieldsize (or (:fieldsize options) (u/divisors (count (first samples)))) - :batches (mapv (partial count-labels uniquelb) batchlb) - :valid (count-labels uniquelb validlb) - :labels uniquelb}] + :batches (mapv (partial count-labels uniquelbmap) batchlb) + :valid (count-labels uniquelbmap validlb) + :labelmap uniquelbmap}] (TrainingSet. header trainsets validset))) (defn test-set diff --git a/src/synaptic/net.clj b/src/synaptic/net.clj index 50476a9..214a555 100644 --- a/src/synaptic/net.clj +++ b/src/synaptic/net.clj @@ -264,7 +264,7 @@ y (m/dense (:a (last (net-activities nn x)))) n (count (first y)) ci (mapv #(apply max-key % (range n)) y) - cs (-> nn :header :labels)] + cs (vec (keys (-> nn :header :labelmap)))] (if cs (mapv #(get cs %) ci) ci))) diff --git a/src/synaptic/util.clj b/src/synaptic/util.clj index b92dce3..387e40c 100644 --- a/src/synaptic/util.clj +++ b/src/synaptic/util.clj @@ -146,13 +146,13 @@ (vec (for [i (range n)] (assoc (vec (repeat n 0)) i 1)))) (defn tobinary - "Encode labels to a vector with 0 and 1. Also returns the vector of + "Encode labels to a vector with 0 and 1. Also returns the map of unique labels to decode them." [labels] (let [uniquelabels (unique labels) lbcodes (bincodes (count uniquelabels)) lb2code (zipmap uniquelabels lbcodes)] - [(mapv lb2code labels) uniquelabels])) + [(mapv lb2code labels) (zipmap lbcodes uniquelabels)])) (defn frombinary "Decode a vector of 0 and 1 to the original label, based on a vector @@ -162,6 +162,23 @@ code2lb (zipmap lbcodes uniquelabels)] (mapv code2lb encodedlabels))) +; Continous scaling + +(defn scale-labels + "Scale set of labels to range 0 to 1." + [mat] + (let [smallest-element (apply min (flatten mat)) + largest-element (apply max (flatten mat)) + shifted-representation (m/- (m/matrix mat) smallest-element)] + (m/to-vecs (m/div shifted-representation (m/- largest-element smallest-element))))) + +(defn tocontinous + "Encode labels to a vector with numbers in range 0 to 1. Also returns the map of + unique labels to decode them." + [labels] + (let [uniquelabels (unique labels)] + [(scale-labels labels) (zipmap (scale-labels uniquelabels) uniquelabels)])) + ; Make clatrix matrices printable and readable in EDN format (defmethod print-method diff --git a/test/synaptic/datasets_test.clj b/test/synaptic/datasets_test.clj index fce0b8b..abea16f 100644 --- a/test/synaptic/datasets_test.clj +++ b/test/synaptic/datasets_test.clj @@ -31,12 +31,12 @@ (is (= TrainingSet (type ts))) (let [bs (:batches ts) vs (:valid ts) - ulbs (-> ts :header :labels)] + ulbs (-> ts :header :labelmap)] (is (vector? bs)) (is (= 5 (count bs))) (is (every? #(= DataSet (type %)) bs)) (is (nil? vs)) - (is (= ["a" "b"] ulbs)) + (is (= {[0 1] "b", [1 0] "a"} ulbs)) (let [x (:x (first bs)) y (:y (first bs))] (is (m/matrix? x)) @@ -49,12 +49,12 @@ (is (= TrainingSet (type ts))) (let [bs (:batches ts) vs (:valid ts) - ulbs (-> ts :header :labels)] + ulbs (-> ts :header :labelmap)] (is (vector? bs)) (is (= 1 (count bs))) (is (= DataSet (type (first bs)))) (is (= DataSet (type vs))) - (is (= ["0" "1" "2" "3"] ulbs)) + (is (= {[0 0 0 1] "3", [0 0 1 0] "2", [0 1 0 0] "1", [1 0 0 0] "0"} ulbs)) (let [x (:x (first bs)) y (:y (first bs))] (is (m/matrix? x)) @@ -72,12 +72,12 @@ (is (= TrainingSet (type ts))) (let [bs (:batches ts) vs (:valid ts) - ulbs (-> ts :header :labels)] + ulbs (-> ts :header :labelmap)] (is (vector? bs)) (is (= 5 (count bs))) (is (= DataSet (type (first bs)))) (is (nil? vs)) - (is (= ["+" "-"] ulbs)) + (is (= {[0 1] "-", [1 0] "+"} ulbs)) (is (every? true? (map #(= [(map double %1)] (m/dense (:x %2))) smp bs))))))) diff --git a/test/synaptic/util_test.clj b/test/synaptic/util_test.clj index 5380a94..42dc777 100644 --- a/test/synaptic/util_test.clj +++ b/test/synaptic/util_test.clj @@ -104,7 +104,7 @@ (testing "tobinary should return the vector of unique labels and all labels encoded to binary vectors" (is (= [[[0 0 0 1] [1 0 0 0] [0 1 0 0] [0 0 0 1] [0 0 1 0] [1 0 0 0] [0 0 1 0]] - ["1" "2" "3" "8"]] + {[0 0 0 1] "8", [0 0 1 0] "3", [0 1 0 0] "2", [1 0 0 0] "1"}] (tobinary ["8" "1" "2" "8" "3" "1" "3"])))) (testing "frombinary should decode each label to its original value, based on a vector of unique labels" @@ -113,6 +113,13 @@ [[0 0 0 1] [1 0 0 0] [0 1 0 0] [0 0 0 1] [0 0 1 0] [1 0 0 0] [0 0 1 0]]))))) +(deftest test-continous-scaling + (testing "tocontinous should return the vector of unique labels and all labels + scaled to vectors with values in range 0 to 1" + (is (= [[[0.4 0.6] [0.8 1.0] [0.0 0.2]] + {[0.8 1.0] [3 4], [0.4 0.6] [1 2], [0.0 0.2] [-1 0]}] + (tocontinous [[1 2] [3 4] [-1 0]]))))) + (deftest test-data-manipulation (testing "unique should return a sorted vector of unique values" (is (= ["a" "b" "c" "d" "x" "y" "z"] From 1b470682217b706c003268b7eabb411450eec6c0 Mon Sep 17 00:00:00 2001 From: keorn Date: Tue, 6 Oct 2015 16:32:50 +0100 Subject: [PATCH 2/7] fix new tests, make functions use :labeltranslator --- src/synaptic/datasets.clj | 17 +++++++++-------- src/synaptic/net.clj | 22 ++++++++++------------ src/synaptic/util.clj | 20 ++++++++------------ test/synaptic/datasets_test.clj | 6 +++--- test/synaptic/util_test.clj | 9 +++++---- 5 files changed, 35 insertions(+), 39 deletions(-) diff --git a/src/synaptic/datasets.clj b/src/synaptic/datasets.clj index 8e59c5d..6b08661 100644 --- a/src/synaptic/datasets.clj +++ b/src/synaptic/datasets.clj @@ -2,8 +2,7 @@ ^{:doc "synaptic - data sets" :author "Antoine Choppin"} synaptic.datasets - (:require [clojure.set :refer [rename-keys]] - [clatrix.core :as m] + (:require [clatrix.core :as m] [synaptic.util :as u]) (:gen-class)) @@ -81,8 +80,10 @@ (defn count-labels "Create a map with number of occurrence of each label." - [uniquelabelmap encodedlabels] - (rename-keys (frequencies encodedlabels) uniquelabelmap)) + [labeltranslator encodedlabels] + (let [translate-keys #(zipmap (mapv labeltranslator (keys %)) + (vals %))] + (translate-keys (frequencies encodedlabels)))) (defn training-set "Create a training set from samples and associated labels. @@ -105,7 +106,7 @@ (let [batchsize (if (:online options) 1 (:batch options)) trainsize (if (:nvalid options) (- (count samples) (:nvalid options))) randomize (if (nil? (:rand options)) true (:rand options)) - [reglb uniquelbmap] (if (:continous options) (u/tocontinous labels) (u/tobinary labels)) + [reglb lbtranslator] (if (:continous options) (u/tocontinous labels) (u/tobinary labels)) [smp lb] (if randomize (shuffle-vecs samples reglb) [samples reglb]) [trainsmp validsmp] (if trainsize (split-at trainsize smp) [smp nil]) [trainlb validlb] (if trainsize (split-at trainsize lb) [lb nil]) @@ -118,9 +119,9 @@ :type (:type options) :fieldsize (or (:fieldsize options) (u/divisors (count (first samples)))) - :batches (mapv (partial count-labels uniquelbmap) batchlb) - :valid (count-labels uniquelbmap validlb) - :labelmap uniquelbmap}] + :batches (mapv (partial count-labels lbtranslator) batchlb) + :valid (count-labels lbtranslator validlb) + :labeltranslator lbtranslator}] (TrainingSet. header trainsets validset))) (defn test-set diff --git a/src/synaptic/net.clj b/src/synaptic/net.clj index 214a555..0b53ad0 100644 --- a/src/synaptic/net.clj +++ b/src/synaptic/net.clj @@ -256,16 +256,14 @@ as)))) (defn estimate - "Estimate classes for a given data set, by computing network output for each - sample of the data set, and returns the most probable class (label) - or its - index if labels are not defined." + "Estimate labels for a given data set, by computing network output for each + sample of the data set, and returns appropriately transformed result + - or its index if labels are not defined." [^Net nn ^DataSet dset] - (let [x (:x dset) - y (m/dense (:a (last (net-activities nn x)))) - n (count (first y)) - ci (mapv #(apply max-key % (range n)) y) - cs (vec (keys (-> nn :header :labelmap)))] - (if cs - (mapv #(get cs %) ci) - ci))) - + (let [x (:x dset) + y (m/dense (:a (last (net-activities nn x)))) + label-size (count (first y)) + lbtranslator (-> dset :header :labeltranslator)] + (if lbtranslator + (mapv lbtranslator y) + (mapv #(apply max-key % (range label-size)) y)))) diff --git a/src/synaptic/util.clj b/src/synaptic/util.clj index 387e40c..8a1101a 100644 --- a/src/synaptic/util.clj +++ b/src/synaptic/util.clj @@ -164,20 +164,16 @@ ; Continous scaling -(defn scale-labels - "Scale set of labels to range 0 to 1." - [mat] - (let [smallest-element (apply min (flatten mat)) - largest-element (apply max (flatten mat)) - shifted-representation (m/- (m/matrix mat) smallest-element)] - (m/to-vecs (m/div shifted-representation (m/- largest-element smallest-element))))) - (defn tocontinous - "Encode labels to a vector with numbers in range 0 to 1. Also returns the map of - unique labels to decode them." + "Encode labels to vectors with numbers in range 0 to 1 + and a function to decode them." [labels] - (let [uniquelabels (unique labels)] - [(scale-labels labels) (zipmap (scale-labels uniquelabels) uniquelabels)])) + (let [smallest-element (apply min (flatten labels)) + largest-element (apply max (flatten labels)) + scaling-factor (m/- largest-element smallest-element) + shifted-representation (m/- (m/matrix labels) smallest-element)] + [(m/to-vecs (m/div shifted-representation scaling-factor)) + #(m/+ smallest-element (m/mult (m/matrix %) scaling-factor))])) ; Make clatrix matrices printable and readable in EDN format diff --git a/test/synaptic/datasets_test.clj b/test/synaptic/datasets_test.clj index abea16f..79f1e47 100644 --- a/test/synaptic/datasets_test.clj +++ b/test/synaptic/datasets_test.clj @@ -31,7 +31,7 @@ (is (= TrainingSet (type ts))) (let [bs (:batches ts) vs (:valid ts) - ulbs (-> ts :header :labelmap)] + ulbs (-> ts :header :labeltranslator)] (is (vector? bs)) (is (= 5 (count bs))) (is (every? #(= DataSet (type %)) bs)) @@ -49,7 +49,7 @@ (is (= TrainingSet (type ts))) (let [bs (:batches ts) vs (:valid ts) - ulbs (-> ts :header :labelmap)] + ulbs (-> ts :header :labeltranslator)] (is (vector? bs)) (is (= 1 (count bs))) (is (= DataSet (type (first bs)))) @@ -72,7 +72,7 @@ (is (= TrainingSet (type ts))) (let [bs (:batches ts) vs (:valid ts) - ulbs (-> ts :header :labelmap)] + ulbs (-> ts :header :labeltranslator)] (is (vector? bs)) (is (= 5 (count bs))) (is (= DataSet (type (first bs)))) diff --git a/test/synaptic/util_test.clj b/test/synaptic/util_test.clj index 42dc777..4bc6c4a 100644 --- a/test/synaptic/util_test.clj +++ b/test/synaptic/util_test.clj @@ -115,10 +115,11 @@ (deftest test-continous-scaling (testing "tocontinous should return the vector of unique labels and all labels - scaled to vectors with values in range 0 to 1" - (is (= [[[0.4 0.6] [0.8 1.0] [0.0 0.2]] - {[0.8 1.0] [3 4], [0.4 0.6] [1 2], [0.0 0.2] [-1 0]}] - (tocontinous [[1 2] [3 4] [-1 0]]))))) + scaled to vectors with values in range 0 to 1, and a function to scale them back" + (is (= [[0.4 0.6] [0.8 1.0] [0.0 0.2]] + (first (tocontinous [[1 2] [3 4] [-1 0]])))) + (is (= (m/matrix [[-1 4]]) + ((second (tocontinous [[1 2] [3 4] [-1 0]])) [[0 1]]))))) (deftest test-data-manipulation (testing "unique should return a sorted vector of unique values" From 0988df0562e34f601a6899f94f1370fdbdfcc819 Mon Sep 17 00:00:00 2001 From: keorn Date: Tue, 6 Oct 2015 18:11:13 +0100 Subject: [PATCH 3/7] make estimate function work for all label types --- src/synaptic/net.clj | 2 +- src/synaptic/training.clj | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/synaptic/net.clj b/src/synaptic/net.clj index 0b53ad0..66d9034 100644 --- a/src/synaptic/net.clj +++ b/src/synaptic/net.clj @@ -263,7 +263,7 @@ (let [x (:x dset) y (m/dense (:a (last (net-activities nn x)))) label-size (count (first y)) - lbtranslator (-> dset :header :labeltranslator)] + lbtranslator (-> nn :arch :labeltranslator)] (if lbtranslator (mapv lbtranslator y) (mapv #(apply max-key % (range label-size)) y)))) diff --git a/src/synaptic/training.clj b/src/synaptic/training.clj index e1bd3d6..cefbfc1 100644 --- a/src/synaptic/training.clj +++ b/src/synaptic/training.clj @@ -623,12 +623,18 @@ as the training progresses." (fn [net _ _] (-> @net :training :algo))) +(defn initialize-train + "First step in train procedure" + [net ^TrainingSet trset] + (swap! net assoc-in [:training :state :state] :training) + (swap! net init-stats) + (swap! net assoc-in [:arch :labeltranslator] (-> trset :header :labeltranslator))) + (defmethod train :lbfgs [net ^TrainingSet trset nepochs] (future - (swap! net assoc-in [:training :state :state] :training) - (swap! net init-stats) + (initialize-train net trset) (let [l (-> @net :arch :layers) b (d/merge-batches (:batches trset)) w0 (weights-to-double-array (:weights @net)) @@ -654,8 +660,7 @@ :default [net ^TrainingSet trset nepochs] (future - (swap! net assoc-in [:training :state :state] :training) - (swap! net init-stats) + (initialize-train net trset) (let [maxep (+ nepochs (-> @net :training :stats :epochs)) all-batches (:batches trset)] (loop [batches all-batches] From 059e119ea0b11874d33b08119550e019b83dab64 Mon Sep 17 00:00:00 2001 From: keorn Date: Tue, 6 Oct 2015 18:43:35 +0100 Subject: [PATCH 4/7] fix spelling --- src/synaptic/datasets.clj | 4 ++-- src/synaptic/util.clj | 6 +++--- test/synaptic/util_test.clj | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/synaptic/datasets.clj b/src/synaptic/datasets.clj index 6b08661..a03e853 100644 --- a/src/synaptic/datasets.clj +++ b/src/synaptic/datasets.clj @@ -93,7 +93,7 @@ Options: :name - a name for the training set :type - the type of training data (e.g. :binary-image, :grayscale-image ...) - :continous true - set this flag to use continous labels (auto-scaled to between 0 and 1) + :continuous true - set this flag to use continuous labels (auto-scaled to between 0 and 1) :fieldsize - [width height] of each sample data (for images) :nvalid - size of the validation set (default is 0, i.e. no validation set) :batch - size of a mini-batch (default is the number of samples, after @@ -106,7 +106,7 @@ (let [batchsize (if (:online options) 1 (:batch options)) trainsize (if (:nvalid options) (- (count samples) (:nvalid options))) randomize (if (nil? (:rand options)) true (:rand options)) - [reglb lbtranslator] (if (:continous options) (u/tocontinous labels) (u/tobinary labels)) + [reglb lbtranslator] (if (:continuous options) (u/tocontinuous labels) (u/tobinary labels)) [smp lb] (if randomize (shuffle-vecs samples reglb) [samples reglb]) [trainsmp validsmp] (if trainsize (split-at trainsize smp) [smp nil]) [trainlb validlb] (if trainsize (split-at trainsize lb) [lb nil]) diff --git a/src/synaptic/util.clj b/src/synaptic/util.clj index 8a1101a..4927099 100644 --- a/src/synaptic/util.clj +++ b/src/synaptic/util.clj @@ -162,11 +162,11 @@ code2lb (zipmap lbcodes uniquelabels)] (mapv code2lb encodedlabels))) -; Continous scaling +; continuous scaling -(defn tocontinous +(defn tocontinuous "Encode labels to vectors with numbers in range 0 to 1 - and a function to decode them." + and return a function to decode them." [labels] (let [smallest-element (apply min (flatten labels)) largest-element (apply max (flatten labels)) diff --git a/test/synaptic/util_test.clj b/test/synaptic/util_test.clj index 4bc6c4a..e111945 100644 --- a/test/synaptic/util_test.clj +++ b/test/synaptic/util_test.clj @@ -104,7 +104,7 @@ (testing "tobinary should return the vector of unique labels and all labels encoded to binary vectors" (is (= [[[0 0 0 1] [1 0 0 0] [0 1 0 0] [0 0 0 1] [0 0 1 0] [1 0 0 0] [0 0 1 0]] - {[0 0 0 1] "8", [0 0 1 0] "3", [0 1 0 0] "2", [1 0 0 0] "1"}] + {[0 0 0 1] "8", [0 0 1 0] "3", [0 1 0 0] "2", [1 0 0 0] "1"}] (tobinary ["8" "1" "2" "8" "3" "1" "3"])))) (testing "frombinary should decode each label to its original value, based on a vector of unique labels" @@ -113,13 +113,13 @@ [[0 0 0 1] [1 0 0 0] [0 1 0 0] [0 0 0 1] [0 0 1 0] [1 0 0 0] [0 0 1 0]]))))) -(deftest test-continous-scaling - (testing "tocontinous should return the vector of unique labels and all labels +(deftest test-continuous-scaling + (testing "tocontinuous should return the vector of unique labels and all labels scaled to vectors with values in range 0 to 1, and a function to scale them back" - (is (= [[0.4 0.6] [0.8 1.0] [0.0 0.2]] - (first (tocontinous [[1 2] [3 4] [-1 0]])))) + (is (= [[0.4 0.6] [0.8 1.0] [0.0 0.2]] + (first (tocontinuous [[1 2] [3 4] [-1 0]])))) (is (= (m/matrix [[-1 4]]) - ((second (tocontinous [[1 2] [3 4] [-1 0]])) [[0 1]]))))) + ((second (tocontinuous [[1 2] [3 4] [-1 0]])) [[0 1]]))))) (deftest test-data-manipulation (testing "unique should return a sorted vector of unique values" From c6eda2f79a2ab12b0e7e67f761c65a99ea2321e2 Mon Sep 17 00:00:00 2001 From: keorn Date: Wed, 7 Oct 2015 18:46:00 +0100 Subject: [PATCH 5/7] change name, extend estimate functionality --- project.clj | 2 +- src/synaptic/datasets.clj | 2 +- src/synaptic/net.clj | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/project.clj b/project.clj index 80c8f7d..4ba06c2 100644 --- a/project.clj +++ b/project.clj @@ -1,4 +1,4 @@ -(defproject synaptic "0.3.0-SNAPSHOT" +(defproject keorn/synaptic "0.3.0-SNAPSHOT" :description "Neural Networks in Clojure" :url "https://github.com/japonophile/synaptic" :license {:name "Eclipse Public License" diff --git a/src/synaptic/datasets.clj b/src/synaptic/datasets.clj index a03e853..31493d2 100644 --- a/src/synaptic/datasets.clj +++ b/src/synaptic/datasets.clj @@ -130,7 +130,7 @@ Options: :name - a name for the test set - :type - the type of training data (e.g. :binary-image, :grayscale-image ...) + :type - the type of test data (e.g. :binary-image, :grayscale-image ...) :fieldsize - [width height] of each sample data (for images) :rand true - set this flag to shuffle samples." [samples & [options]] diff --git a/src/synaptic/net.clj b/src/synaptic/net.clj index 66d9034..7e8d42b 100644 --- a/src/synaptic/net.clj +++ b/src/synaptic/net.clj @@ -259,8 +259,8 @@ "Estimate labels for a given data set, by computing network output for each sample of the data set, and returns appropriately transformed result - or its index if labels are not defined." - [^Net nn ^DataSet dset] - (let [x (:x dset) + [^Net nn dset] + (let [x (if (contains? dset :x) (:x dset) dset) y (m/dense (:a (last (net-activities nn x)))) label-size (count (first y)) lbtranslator (-> nn :arch :labeltranslator)] From 72c3d404825468a86c56efe66b7b22a1bc269c28 Mon Sep 17 00:00:00 2001 From: keorn Date: Tue, 2 Feb 2016 13:06:42 +0000 Subject: [PATCH 6/7] add relu --- .swp | Bin 0 -> 40960 bytes src/synaptic/core.clj | 3 +++ src/synaptic/net.clj | 9 +++++++-- test/synaptic/net_test.clj | 3 +++ 4 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 .swp diff --git a/.swp b/.swp new file mode 100644 index 0000000000000000000000000000000000000000..2c83b2b90800df32823930657dfdd6141be749d5 GIT binary patch literal 40960 zcmeI536x}4dEXz1RRKf{MmP!KsYmg24_#HWNJx~xW;Bb`LbK4*Vt|oQy{>xQUDH+d zih5Pu(-K17I1Vuo_H1xs12HCe3&y|+2#g?!S+W=rV+hy*+sAQojED)?IL74n|L(o- zz3Lfh3?|_mr(60@cfIxA?|%2&@2%y%&s#VaUsJy-;^*=D$Njvm8%67BHX4Kj5AwU0 zwUgdjmX9Cgnw?WS6xgA__nHEazU9FFeS6i4T{Bn2Px!uDzSj-fdH)Utb||nzfgK9$ zP+*4wI~3TVzzzj=DDeM~0^{MMqPuzDL;X9K{qKit`+bT3yy|~`-8b0)H(jkl>I%ho0RpN16P8@c-X^ zSrm=HE-(Ww1^?l3QS|fRrQjsE68!pOqv##rC&4;+3b_96mu1L~j#zHn(2eIEP<_)p+Rz}?_DI0l{$E&~q*zl%eEFL)I=2L3L%k08J+ zz)yh%a4Gn=1P*=;d&>gI%gSb1$$4Rf3cDO1Y zrlWYgmcjBUPvTQw0#_?LV84ulM zd7O;WxUJjMPCbfdSF%x@o=f_}UOFGQM(MEEiaFF;OMAoCT-;6uanj4Pc$vFgQ&vV< zKTZZ)ai=@t#o1`99v!3Lq|)wjlCS4+Ht21|eKjRd2YE_w)c&}g4aRC}wn7tqv+D7w ztWVdHK_~8Kqm-7#d*ah;iS>q?`U!Url75;GlXj|m^{8Yxbl0VHYMjL@UGC4b$%so< zy1g_HPl*rpva=H|+DJ!vH=`k0Jjv7Ibg#QSN=Drj((zc|sJN33(?Lfa(S@GS9L?LK z?r_Xws)yCMC@RM;S*x(lud@BU7*IS=snEx zyt~X$-IQ!3b5rD-XY4b$Y_mIF<5A09jasFTo7sqQwzGSDNAf-O_+UXn+LDxgvY}@n zjd6{SOLBSY=1eUZWTm@08NuP%fF{t!bc5LdEg7E)ed{L78%EPOlNWDvJL2XJBI)*=z5%b@X_S~?i$yE$RkRPyjf)pAxA>&~8I36a0Zo6JyC)T9n-PHk4 zG9J+5edooK-8?#qC>c|YW;0f&2kEBfo`!(-!{b@JmeBbDBg~~4L&Vb#4xs!fKE@dd z!6+*GH6M>A1LL4^nve5&d(`5QlQ%m?8E3<8o4J(sB;fQasME`~97)usl7oyc%gp)} z8w)ickUAfaChai-W4=ON8ym?a>s!1OIPp*>slq^wv#eLIOb^C)xjTTF7AMPFH>CZ% zas6b_>7}E_q6uimZrB|*jwUps?PhDv+~}pzjO%(oO~<)WL-cj>1o6ah&A{YPNDsDA zJh{dBLA;Xn*7NyjVa0};_ag@-u8rPI@0c0e>&#aMv7JqN9Z7dRe|61?WjcevgV1q% zf-FJYm#pWf!>BtR+u)SKM01Gdb9C67tkzIob9&5hvdmOkV?Gr!He-MgJ6%m>$P*^! z8Ofz&l;1YHHXaX~jYivju4kiFdz2=OrlEDVyo>S?F z`CQA9>Otqjk-PHpa4is_xaa5y?JsQ@&)%|}j^-Q}?^+3mYZH#@U~z{o;{Xr#r*wuE zn|zVj%A8E?j=5x$j@8Clt&`Ox2JL_PWOVuo*>Y8<>&Hkk@v+R%O!SuEokU*BdJz zi)&5|l3~7v?ejcmn;EhFv6>_+RzOL&gQIbGC25c2g~JH{8)VV&TRdF39qncp)w-R$ zW>!yLP8ppvwA%U=J7Da6p2pp~6!!Eo8UUWNY~6J)pYea>v!EekVJr_lL_ADjK$Z#34@*@W0e9xXZ-V;Qq?PXtc@zd@Yf72rnj5O5E%fOFs$@JV6-p8y{K?+2S;17u(UJQaMCn82sN>p=$` z0<+*@AOhba9`Js!0UF?s@c-WdI-m_sfkkix90q?%9N-h+R}Gb(34T&z3#YG5hN|Sb(q8BY}#9_$2iEQk}V`KXk(aSQZUWq_L{78dt$3u6;rK> zC8}A4r17p8<}OB#v%^9vY=$en3so!*mSO& zGmY~t+`4miUA5D}ZQiulO1|D5>JmaityXk~A*ji+jA!=NpHbg8bDIv^xaeOy#Q?}@ zhn`(*XA3i63vFWUs#UZyw4<9hqb|nzX1eSwMwAjajN*-M;@w0?ARNUbi^p2^_=FpF zZ{uRcGiqvEAley!{eu|64R^^8zxY|9y>o;yXqjA-UGc?rBLdX30)Berts}-mzF0Zl z=8^5e{y{yhbYY!6@oZ43v8huuZqUQs&YaY56qzA1Z=-g~0wiW)6FtGiR=mf~JEI2w zbBxEMI7U6(w`|dp6t6ghoKSoC^Mg2JJY=TRPZvp19DxWy5eu;Z$QG2CU@FS3cK5YV z%|zE){>`)MKs<*oGSrex$VRcv87$-9bpZN+Etf#O#Ej_xo#%Di2f z{o^sZh|Y}?bI0>x+J<6n#9e}4mTRiJ)*FGDs1rs#Z|W4rpL<(93br+zmJsd9<+ykj(?dq6o6lr|=F8sICHSnBYv? zEY3F*!bFURqI^MglwqTKv`uK#rwyrTL_%kOP4pi=gPO_&M?Q zY4EsRqy$dy=|1ZsGP#HswjBH87p`R^xlXcRB+HSFUP#4plLooDa%a=~2JxNE1OvGV zJKlAdEUP8~9IUO{Fd;BE<32S0UOuB!7Ak#FQ#7mTuKpH&?_^~Kn`cDCGoDfs@tHF? z*GvNokD9|Rs?O}IUo8Q78`qaVqHjFoee=!D&3fAI5z|T~{Tsb>HR(0*ziS)&_SVUxuN}=$JG$`vV@K?Y3N7OqJKQwk z%w#C3JQK}!(v>i*q5Bg0`|5jZr;Z;ydFsX!OEZP=M*tJ>p3=gizhekD&^XeCY7$SG zFk=2QllT)=prCx^3D^&&))a0vKIq43FHcc%BRnrp!R21ITEBL@*6R{?v5vMg420U* zjF?el|K5Fj8~gS*_Ff%EMZT)@V+t1`$CrQc1Ye*=G`rUA4`D0wbOyN%!JM0CISRSd zX2G0x=0h@we8ij?F2RaIkq55{m~LVxBmbR)4MWtKtf!Jd`2nNvi&ZM%X|BM;*C1)EqDTDyjU!&#~~7#`)E3JtY($O>6${h z_lXRIOHGy}d0z6o1KJTStVG^N@oNllesd1I}h)rb~XvFEU zBhq${SjM(DVCxam(h>n~+%*zLDa1sANN-`FkN_jH6wF7oS3c>E&DE5RCX-}1M3lMk zdyp*z7oBOM^WK2k!qV$w82HA;EpBW@3z7mnH5mQ$n!cOu=9;3bF4PQW4ml{;OwB-^ z#&Nx2be|fm+jU4IcA+aO$NZlc9PA66D-ZdiNqQ*RXd+kdVMygzdD!l5I_ zPaW}}@WUID1sN*8JmmkXuVQ7d4xE2~-^r-j3ZHu*uX&m1*wa4}v{yT~SnOl7!I zwj306k~rQCsdpEouDi!viKT(mQhI{Zm3K@%U#G*^YcR2wn$a0Y1Si&WqSka{ZL(|v z9b1{sr)P@MdrtTn7F@0PF+*9Y6i8paqoYzY$ymo(}#5pZ!-s25R8H;iKOR{tkAXunZQ!Gr?}~ zXz(a-4}SXxzx|25g*j+qx*eEV>Mw{NLM~B1v z&%2YnZ*kLpe-k%^h_0LCYRvR3*OBw)c@!<-zR$;|woRfHkE!YhpQy;U`s#(WNqCg0 z1Pcz--Hh}kvp$Gs(;`l1&m-J-i{{@s5q$rWvd=*lwF-J{PyJisf8H6BdDe{NbRwNwrz{0_rN;YkJj%iO2zU+v_}iGOT8D8BR-L z$I4jsCRW;je8crqqzG+G__>R-qVZDt%B8_A6xEZ*g`0$CP*T~wOl2g5ZF-TaI@AlKFl z4h8Y~4?-fVH7;o;5mqli~oXRm5D#97TsWwBR>?N*2l_epbx6Udy z8;pXbOr!hFC_P8`Luy~dD<~^|CF@q?@6Ec!(vriHtHyxJkzqER^pa6DB6&`+gpp%j zVXEWE)B4$*{vSVi%*RATAq7@6%!Acx3r8t-WFD;~{cf*IFvlFAlSpaqKXmfo@%h*# z<@IqbO}TarW}i*gsoB~(*`1rW?EUUMss?CWLyPjAJw=$U=pl5X#Ad+-f9aN+9*dRC znk%DfHJv-yiiv`$rF9JVfg%_pU@%r`95z{{NM<>OeXz=X#M5L)3(hq!vlY3Rjtp?0 zkqQLXoWSo&^Le$i0=0 z>+y}*rdl~)j*?;!aBJ0tH${WXS7@tJPx&m1Wi6H`tMUX126fviK?~!m+fn%y5-m@< zy-qz|WZ)#CDwL_H#DYe~QFwV`hM6)zrKv?j`)M*z=7W*&N;LC2@l+Vl=ehI_MIT;U zAzZ{G1@H5iY(@1tI%?u9X^*h~G$Gxw$#Rk0tw%^l3cuX#@oXs1X^rSh?JUvPtr|6m zL&{`Y1WeUIX+JNTDI(=a%Q#V?{0wJ!sxcc}@P`yxx#a)havTr>>kR>!+ zl(e>biD{H2->_DXk60c6(r%e9wmg=iji|@0UvH-!bTEymV`x#WsK6=i;aZo9kb$aO zX7<$&)UT>!r(EzvQwF9bPO(LsE3rqP%?ge3Y~@O7y_^@%8<$xw&F0UHQ8cDPVrD)% zjmzpKkxKy1$1{PRGkWkSrIJWBrJ7-`^D|)-c*YDy`x=fj9A?4-j~6029cAK(;!LEi z!gS_i7ZQkuS@at-nd~GY&X>J73zN(8mJ|6fkvb5$DZqqsR;>~|3W66JanrS{ujYN2hRuJ5AMUie+9S!TmtUFum4%_FTic!2f%OQ)4v*Y z!5X*{{8#+?Ujn}f-U{9VUId;5K8cV26W}@ECh!#SRebzUgZsd%!9N6w0X!GXgG<4q z!AJ4;mB)V__;K)K;0UOJdGL3@$A|&E5Bw~+8!UoZ@O5GU9|x}juLM5^eiS?td=3Bq zmp~u%z>Pq)|4#*v1*-l3^uuU7cr|zlNWcNGAAI=|;vL|V;AKGZfdo7Q%mKv*{tVy$ zv*5SDuYz}j4D1K{zzldiP@F(x^XZGn1^meI>!jH+5|RW)Ykj<>saT{jXUUXIFRR)w zqJ7i=_sEHzR^8?m#&JGWg&#%~#wXL2NE)H?&0gZqU^O^3(BfKxi&Xn*D(9$oYEYXY zXO;Ansvs>#iDY<$Yb!#L@a<%P3BrVPSJ6N$5lsf0*Unc*QjebXL-BPMw zQj{v_vv0$!_bWo|&w}6$yKF9s%1lS-mRBAvcc#^hu19r`DPHBRRl8WKY6I)eLc*5g z=RC4Wl#tOS`K92l3OtY=E)+jXsUaRvf7zR)@$8`6n?ru$lB4|UrcfkyBvN4vBett9 zSSDqP{p*NDer1S#F}s~;db6YKY!hFtostT2C182n#1|h^@Yz(#6de~gm$QzlPb+l7 z{ghZ&)+6`jf4g>zh>dV?>T;-ETpe2mtD)FJ$Oaa{P5G8#MCtIO4z~t&DBd}B2dwabX%blvRht+W3Dp%_hNO~&Xl~Df zU_!!lr2`%Z|8EAr-HEHF^Bvm}(nP=9c?%$r7RtXyS@~La!~M zu~b5=84fGLJ3$J_lj;wxz6J}T;8l|t$%r)_DqXgTU!BZ;(k{B3cQZa!_@q=5=kcC> z^YNbj;DDw8oZn7uMU48fM_Fd~bA;*Nv`Rf(k*6J&lRYbuFlT?^Q(BcG*UCm>LN(*r zv+ggPy^W-nEo*I4EgRLWl!MB`k>^zvL@5)&9ez@W8C{bdH+5xb&8!mo6w$UR0mpc$ zWh>>iGeHQ&@!_mZiN>t1d!T2An2DvyBRyZVuWiM9G2Qmsfq{*Ns~RpATp&i3qK&V< z`s$}&E4M(^9`U~Y*VOm&-#!XY2l%5sLN&#)rd?}~43eKKP@Iy-wG`Qc;m&YNZw%*} z(UE~wTM>Dh46noNQDFy;qd7NDyu)rvoxxKUvGc_&R)4NE_F|<<7>PM)D`ckyDymqL zA6Dm9XP^1P2nBdlHL$-$X^12NFH{vVx=#7fgeMFsmBIt1_C}svX1t zt*qLemO0+@z4h<7b!M@ftP{(Sytvh=faxpDo^ATJY^J02JE}zUkExUiG~1(O(+Q~Q z2l+=>NE>&YB|7GT%2&ea*92nDrsvW&l{dCpElxTcqt+#hO#RvPtUl~kp;uks`K8EZ4(H7>*2%eK8rmBByzt5{vTM!&I z7K_j!StbPJNM&Xt)wGe;lD@(bu~mmjeDj28prU6QQty)*DO{w{nG^nPSjYr>G=w|{ zK2!DfviNuPnpVZ~Cncg;sCl72i2|dvudKa`bQU2YxlLlsqE{A*K4o_iyH&lv zW{*mx29%gQ-m*(DjA{9@u=torpsAPzhI?+INlY3S}#zxrv;o!i;B8*idYv`+(iT>(Mei&Rz^15N+bn^tnNyraX$E# zbmaD9QZyFyD+X)J`NW#N;uy$o^lQDeQSa|>MF)w0*ddi0xOF6O(UgCtDJAixGA-;f z!pr;{Ol&H{q}U0ugbYgDWt~c=%S`DJw14ny$aNOVa)}!Y<#?qP(lYkK<#akq#jHsR z#*!Mh0TN?X+UNMFl9!2S=|)#eIw@kb42st2X(g*$wBc?Row;yDp|-+#E^CRSYQ8G) zICf|#gN6FrXr~#A-Jx0KBKvA1Wi$*kc@Sn?&(1}J8596S?T2fj{>;9X!HTnQct9sxc~oZwDy4fx;00)85t15XE6gFnLe ze+NjwQ@|hK^K0$@ZQxdL5}W{!1OE}9|K;F%@J)RC4}gCFUJed{zrd&e5RjjLHTcRU z^b!0!@UMVs`Pab`&^rGa@Cfi}eE0jnE5K3kEodw{zxUhmG^%{NzQt8p}(2=;Em0MM{W{!u6FEXVvN>OG0Q#pOmRB!KK=I8E#xAo4Q%2^w=e8 zgM#UY6_{LtAT^$LYcIXTJXA_d&Ajdg63wsN_3G3qfbkZh&!U+*GNn2G%QOF?wPl6f%he?Sx==JV4iVZn_dnGQ&cLMJ==$XPEhN011dA zL(Aw&cwIAcL#7BVvD!q+(M%4hR4`r}VKH~QT0W_?_iBbA%TZ;pblMR@nK#>`K(8#< z>p9o|=)vd4N}{?D@XVMpdgw2$M+`yn6dVrprr9Z2Ur0Gs2_XoNOL;SvgZ5fK@dkzo zzK{;36ILEvk7r3oM-~Wn&W7wFnb!U89rg+ozwIHQ^;pgYAq!`3q$8501WBznmqq(H3%bsy&+}ad>1?Csk3YH7O3GPj>rY+jTT1etuKE;lqhMTMMcDxmvR~A{r zP%)0%?B*2rm7HL6(9E(a+B!3gXrpp2I7o1To5DuI&_aS@Oi%4ofns%A4n)v{B?*`* zykGoa#4xATYP2n4L>krUMyuOqVFFLVe6j{gYa<8iy~pqi^T(cdZmgXd=(1T#@H%VjT}K^>-78#w>8%2% zj+&Zs^ub9l3NFHOW~D6fZ;BVvq@wp6wdoli!txY|&vqKxQAHjHbFQ1=k(fNgme5iM zWsBx@hE&Fs9WSS&%a#gxk6M<_4COZ0r!H5EcyR`U+-%%H*v#5#AlF!SgNB!>wp4tPfU#=zK>I{0y2O`3wU0zuJ4E*FPlIAH- z@&0;E3{>kA1HWhmP`E=EQ?)`LEa+~H#-eu%;Ku2y+kyn^l>&w&FWywwn;JCFUdei8 zU(oL#OW??r*!a)4%gSpgskb+j@^5>>T;wOXt$%k`r<`(Eh9Bsl(L+SAHvt=l-YXAgW(f8|2K83pJUAg*^dnYf{J=B}ntAT5V!$zF5UT~&@boJ4Q!5WJJEgAF zpbN=aizxNgD^mW>&{-xDwypvblexN4`au-Y-zW(28^Mk1_|wbt>JkAdb?f1EUe?=P zQ{iG8XYUm#f=g<+UFn^|N55JJaH z%$}w&^I9YA+HMc=5*7mbcoWi){Ugv%p2f; zcV-dmZJ)iQsz5kILN#S=AI4Zc_$y?SSXoC%R8HgUvv|@zpRpLWY8lwK%JIgdEV;e3f723?qyoW=pHW zxQG_`TC1c-9gI|$W}CK3kypUXEEcFa|C;A3;4WcZz0@9Qarg7glta#7pQb6X+o4u0 zv7HihsG-!tXgT{mzIyN8y#!23)u7nFZxM5a%%TO~Y~FH}Oc&HB1$}riSUg7BHRd|o zab#{$UinI;GsAKXACA;Wl;q_x=EZ0Ss~Qcoc-@gD{1-^#={vq%>&TloB+0j~ki1y2T-gU{ptzZtvGkDpary6;B&+SUIT6dmx7O8LLC7(2~L2=fsf<+e+=9U-U8kPUIw~g5BNUt zIG}a@f5tdyEZ(B=updb`+fajv)mTINHqH^>YG^%8Sg^Q89xEGe`&e^1p+hz)B`Bt^ zj`UnFsk}=~uUajRyXBnImzSw3-ACTs)^H~5jJKB7{rlZ7Hq^j)ra)*v|Eq|PC7pNO z71>?kpF5{wpq#5NK!x{hb2Sdv>-C5fzB93>8>Ba|wsJ;o)*>WxAh1!@QyD2!;YXJ% zw7Gqw(W(-fW4iRw*KEMnw7KhFGavga&t}W20;MR9*^210a#U;w)$_))BAvG0aW<~R zy z9@T-gaCsPh`PP&^khfI1AbMsam0sQzA1yjqZq+1KaV6vG0^6mz#S|4~F%@XDJv!>q z39Xl%P(kP`-9{FbK&4j;YD#_N69h_bq3~oru;Yl`-&-k)U47VjHOsV-rCW+&(FVfz zE~KS0?usF;&`Yq|HQcUH+iW`5BQr}NZV`(p7B5I?%h&L;M`gY?gEYgOuWDQ-gP`RU z|LL^H%dt%}RJ|BS8v6?I1w;@pnI=~a4+fzXZ@3NT>P})^X-WOKrsoG!q2^X)%Pa4w zMuGx-&Q|ZWWG2(VvOW10#t$B?O?*vGTmMUS07;q9)4&4M99B5K9^ayfU}3glh?tE5 z6Iwx3+_H``^@FiCPf>-OP23BnD5R24f$Rv+hgV8|mft0-s|=h-@QEIdbkka6$ZnFR zHDvGd)O5h&TenQV;aBsA&z?YG**X_hDTQ1nt7c}cy5F;0tx5eF?q+?zsJZQ$;QX3H zdcpHSLCMg%V-~n1zT>0JVYWcu2YARp6jV~JmWpz*91>Ll!)+AV9#HRxs) zd$yEO?$+zFWYo5C=$)lZ?L2$abd5qULwRr(xq%%kl7hC`IuE@6o82zKB-eeutyOIq zMNVw&KxKxN{YkUI*{y*#QbB|l8{}jbl-0t27+o(`Jc1Y7K0Zd&l2vWaFqaTHK9XD5 z#38Yb5Gf61laNl~dCw-KVZHibHl9ZTO~){PwqM4XKwT>o{b+lQc$T8V zGq>Ny&H|HkZkwfKS2txl@TDZV238j~> z#GXx|(mULCy@IpkZj?l`XspG>4xU(wXAjywIT$BPn0<2*YTZ_$E7a`_6c>fJ*3dw+ z^>|lxZ>})&lBhC;UB#1G&xHoGZG@%*dF-AnMdT|IoeWbwj)^p0NB9gX=#MgES4||E z%H@inZJp6}i`y-eC_ZMLJs=RgMa(c18}jZy2j7bcZ{+IH@?N#eIU zmjN3oE5;OMqA3>R%VV^LOryzyGS4=QkcP*UT+l(5;BYw(UqoMAF6J(mG-;CB-L}cH zi3{UNn*g|(HkoLPIVOHLHqD-*(uDOLI4;&Cnmg={SI|*3$$G;+DkfsolX41KZq9A# zPRUt>g(|gHMB)AFAUesJO?ofyv-BwsEg)4kx}$6$Rb?e>N*GzKs9VamvSJZG(tk9J zSeT3h=!ZAWQ^H)c*L7(62+w#mYBjA|j0*K(rlQbdPx)4p#dLb!51Ycpr*@g&mUJTi zx-5`ZWNpQAQjwy!Y%r3S>6gVmtq>^WFSR0DMM#!)O>A*yaP*7hv6}44!Wt!`R7$%` zYC#T6mWT=-EGgTE#Uy#q>UvPwQ2l9HI(PsHF(xr*^;ec77)8`shj6JJ4=uN|t1AzX zM`mGEJYhCgR3#d#T!S<=)^Uk!m)GY#ipWpN-|}g4c*IQR#YP8JiA&}%?ITLv=(~|~ zQApnym{$rVgvlpGHeD;#VIz2!HdjBSk!Gt6 zIEiZAK>`XPW+p#PpD+zsb3kCYO5e@J zFuo_8s2m931N7RvoG1jvD}#xKw`vjlatvo<6yZAwL08}D9FFRYQX?h8O4@P2Io(b) zGk;A`hCrQUbP~-LBA^b_S;)^KXlyZ1%9m=xX#mUn;{`1#K16!Uc3yC-Zqvl-M5hD- z@39AHTLz6+m4Fi!>Ju2)7EP?!4;S4AXtTTCZ7lkykUtu4k&7mcshJhI$yuq|^kv1a z4Kx_bfon#R2Vwz7UFK(~06Lt=UP%I2}R16J|5U*;>1)ETqhwQt_(tcPGhOlq4VHOs#Xi(q3gi%`)5>SWY!pG| zg0fOtgltuGP3;v~B642r?8~=BP-bd!+=ZF(p%aZeor542bJm}TEZyg8w%{H(M0r4} zz`5U9&j<)Nr`_K6B{w^%kAASKLNAJW$yRixCwtebk=*M`^SHv^n<@FSv$7*bw#${> zK#|N*N&?(OCJkxP^XP(W$T5K#h59CezggMwl{hfnve@mMkLjPDK3yaNTxXya4+`D& zbd$!j1yeU^<+__)XULSS6m=Wfklp-{4Ji4BPBd6PU$_r_oC?1ABxj$H=**Dmkh0mj zFZfuStYW1M2<2+;vasowyNHLtvPA`rH;`;7BVV0>h<5}=l=&>e1?jte7qJ@G&qXj6 zVF+TON<=g9N&n~zF|@ajsNlt;Upm1W96VI|!Fd)~pvv8*qXXrs?-8~d+{_g>rF ze{J)?UN)aUdLmv?Mei~8Z0WMw<6^1|)=75lP`ls!WJzxKWVHK6?eoOf(h_kRNX zP4G$l{*Qor!OwxSU;$hSeh;5twg3MH+y$NoRNMbq@T1^XunYV-KL2~cYk~6p*MO&k zKgRF>1@L0Po_~J`?gbwP*51Mi-}z&Q0y`Agp}-CWb||nz zf&afKpgj939z4#6LK3#|oH+smZEE4Yfh%P6Q#4}Reeun0W%Z4!j?G$k{(_n7Z24&Z zZV{>V>)z};9xNZ^*Wf5Zqi?#!5LZo+ImO!M6d%P$ZU{HJ9R=MYqUel!jXV68^1Nsf zr5e@!n<@5h9A@u_6E9#(YvMP2O&#Cupzr!rduTakvhDt-+E14$vdD>ofZCNqt!2|| zg>al#c;<$aHxv%$>R>WlbrCr(l#}CXY@&ide6h@Fw?;!ArYNS89=G41NI;lEfmWO@ zzgb*m6%=6t3WwcxffkqJ9g6U#H6r`AUdi|Hen2>EhqU z7|xH94@K+cmv7(bRE*OwEB+Ktq*@XRSjw6k+i|rPJiYq4K7|1L*2_A}zrv%ZzjdN< zu{W9T{ovhLOxf8XHOK3F+7n-}0uVVk+s}E2HdSL(FtFEc-yQPN~oMgUy5L;V$}&Sg}?n8 YX#{1+GM8Pe!~wU5?HJ()RUXm*0fMBN5dZ)H literal 0 HcmV?d00001 diff --git a/src/synaptic/core.clj b/src/synaptic/core.clj index b4dd330..1e919a2 100644 --- a/src/synaptic/core.clj +++ b/src/synaptic/core.clj @@ -39,6 +39,9 @@ (def load-training-set-header d/load-training-set-header) (alter-meta! #'load-training-set-header merge (select-keys (meta #'d/load-training-set-header) [:doc :arglists])) +(def dataset d/dataset) +(alter-meta! #'dataset merge + (select-keys (meta #'d/dataset) [:doc :arglists])) (def training-set d/training-set) (alter-meta! #'training-set merge (select-keys (meta #'d/training-set) [:doc :arglists])) diff --git a/src/synaptic/net.clj b/src/synaptic/net.clj index 7e8d42b..2b94e53 100644 --- a/src/synaptic/net.clj +++ b/src/synaptic/net.clj @@ -111,6 +111,11 @@ [zs] (m/map (fn [z] (if (>= z 0) 1 -1)) zs)) +(defn relu + "Rectified linear unit." + [zs] + (m/map #(if (<= % 0.) 0. %) zs)) + (defn sigmoid "Sigmoid activation function. Computed as 1 / (1 + e^(-z))." @@ -259,8 +264,8 @@ "Estimate labels for a given data set, by computing network output for each sample of the data set, and returns appropriately transformed result - or its index if labels are not defined." - [^Net nn dset] - (let [x (if (contains? dset :x) (:x dset) dset) + [^Net nn ^DataSet dset] + (let [x (:x dset) y (m/dense (:a (last (net-activities nn x)))) label-size (count (first y)) lbtranslator (-> nn :arch :labeltranslator)] diff --git a/test/synaptic/net_test.clj b/test/synaptic/net_test.clj index 6c1c3e3..f2aee1b 100644 --- a/test/synaptic/net_test.clj +++ b/test/synaptic/net_test.clj @@ -39,6 +39,9 @@ (testing "sigmoid" (is (m-quasi-equal? [[0.25 0.2]] (sigmoid (m/matrix [[(Math/log 1/3) (Math/log 1/4)]]))))) + (testing "relu" + (is (m-quasi-equal? [[0. 0.2]] + (relu (m/matrix [[-10. 0.2]]))))) (testing "hyperbolic-tangent" (is (m-quasi-equal? [[0.7615942 0.9640276 0.9950548]] (hyperbolic-tangent (m/matrix [[1 2 3]]))))) From 37629a719c4daece2d9d54e5177b8b2b9262ad27 Mon Sep 17 00:00:00 2001 From: keorn Date: Tue, 2 Feb 2016 13:18:35 +0000 Subject: [PATCH 7/7] add relu deriv, untested --- src/synaptic/training.clj | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/synaptic/training.clj b/src/synaptic/training.clj index cefbfc1..97d9665 100644 --- a/src/synaptic/training.clj +++ b/src/synaptic/training.clj @@ -111,7 +111,8 @@ (case actkind :softmax (fn [ys] (m/mult ys (m/- 1.0 ys))) :sigmoid (fn [ys] (m/mult ys (m/- 1.0 ys))) - :hyperbolic-tangent (fn [ys] (m/- 1.0 (m/mult ys ys))))) + :hyperbolic-tangent (fn [ys] (m/- 1.0 (m/mult ys ys))) + :relu (fn [ys] (m/map #(if (<= % 0) 0. 1.))))) (defn output-layer-error-deriv "Returns the function to compute the error derivative of the output layer