Add a CrossEntropyError class. #93
Conversation
This PR gives us another way to evaluate how well predictions do against the actual known distribution. The iris example has been ported to demonstrate this method in practice. There is also a small refactoring of the local trainer's validate method, and some small refactors of other error classes.
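For readers unfamiliar with the metric, here is a minimal sketch of the computation being added (names like `CrossEntropySketch` are illustrative only, not the PR's API): cross entropy between the known label distribution p and the predicted distribution q is H(p, q) = -sum over x of p(x) * log2 q(x).

```scala
// Illustrative sketch, not Brushfire's implementation.
object CrossEntropySketch {
  // Turn raw label counts into a probability distribution.
  def normalize(counts: Map[String, Long]): Map[String, Double] = {
    val total = counts.values.sum.toDouble
    counts.map { case (k, v) => k -> (v / total) }
  }

  // H(p, q) = -sum_x p(x) * log2(q(x)); terms with p(x) == 0 contribute nothing.
  def crossEntropy(actual: Map[String, Double], predicted: Map[String, Double]): Double =
    actual.map { case (label, p) =>
      val q = predicted.getOrElse(label, 0.0)
      if (p == 0.0) 0.0 else -p * (math.log(q) / math.log(2.0))
    }.sum
}
```

When `predicted` equals `actual`, this reduces to the plain entropy of the actuals.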
Review by @tixxit, @avibryant, and/or @johnynek.

There are some problems here -- please wait to merge until I fix them (Travis should notice them too).
- def tee[A](fn: ((TypedPipe[Instance[K, V, T]], Sampler[K], Iterable[(Int, Tree[K, V, T])])) => Execution[A]): Trainer[K, V, T] = {
+ def tee(fn: ((TypedPipe[Instance[K, V, T]], Sampler[K], Iterable[(Int, Tree[K, V, T])])) => Execution[_]): Trainer[K, V, T] = {
Just in the interests of increasing my Scala knowledge: what difference does this make?
We talked about removing this, but I balked at fixing all the code that would have to be adjusted.
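For what it's worth, a toy sketch of the difference between the two signatures (`Execution` here is a stand-in trait, not Brushfire's): with `Execution[A]`, `tee` gains a type parameter it never actually uses, and every caller's result type must be threaded through it; with the existential `Execution[_]`, any result type fits without parameterizing `tee` itself.

```scala
// Simplified stand-ins, for illustration only.
trait Execution[+A]
case class Const[A](a: A) extends Execution[A]

object TeeSketch {
  // Before: A is a type parameter of tee, even though tee never uses it.
  def teeTyped[A](fn: Int => Execution[A]): Unit = { fn(1); () }

  // After: Execution[_] means "some Execution, result type unknown",
  // so callbacks with different result types all fit the same method.
  def teeWild(fn: Int => Execution[_]): Unit = { fn(1); () }
}
```

Both compile for the same call sites; the existential just avoids carrying a phantom type parameter on the `Trainer` side.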
  tree.targetFor(features)
}.toVector

error.create(instance.target, voter.combine(predictions))
I think this is going to give you the wrong answer when predictions is empty (e.g. because this instance is not in the validation set). Who knows what voter.combine(Vector.empty) is going to produce, but it's totally possible (even likely) that it will produce a non-zero error, which means we're accumulating all kinds of extra error for stuff that should have been filtered out.
I think we had convinced ourselves that there wouldn't be error in this case. But it's easy to add that special-case back in if need be.
I don't see how you can know that? There's no obvious law that Voter and Error need to conform to which would lead to

error.create(t, voter.combine(Vector.empty[T])) == error.monoid.zero
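One way to add that special case back, sketched here with toy stand-ins for Voter and Error (the real Brushfire types differ; the point is only the guard): skip instances with no predictions rather than trusting `voter.combine(Vector.empty)` to produce a zero error.

```scala
// Toy stand-ins, for illustration only.
object EmptyPredictionsSketch {
  // A toy "voter": average the predictions.
  def combine(predictions: Vector[Double]): Double =
    if (predictions.isEmpty) 0.0 else predictions.sum / predictions.size

  // A toy "error": absolute difference between actual and predicted.
  def create(actual: Double, predicted: Double): Double =
    math.abs(actual - predicted)

  val zero = 0.0 // the error monoid's identity

  // The guard: only accumulate error for instances that were actually scored.
  def contribution(actual: Double, predictions: Vector[Double]): Double =
    if (predictions.isEmpty) zero
    else create(actual, combine(predictions))
}
```

With the guard, an unscored instance contributes exactly the monoid zero instead of whatever `create(actual, combine(empty))` happens to return.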
val totalNormalized: Map[String, Double] = CrossEntropyError.normalize(totalData)
val totalEntropy: Double = CrossEntropyError.entropy(totalNormalized)

def relativeEntropy(xn: (Double, Long)): Double = {
this should really be called normalizedInformation or something (how much of the total mutual information have we learned).
After 10 runs we got to like 60% or something as I recall.
If relativeEntropy is the metric we really care about, can we build that into the CrossEntropyError? It seems like the only thing we need to keep track of to compute this is the total distribution of actuals, which would be easy to include in the error monoid.
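A sketch of that idea, with hypothetical names (`XEntState` is not in the PR): fold the actuals distribution into the error value itself, so the normalized-information "present" step needs no data from outside the monoid.

```scala
// Hypothetical error state: accumulated cross entropy, instance count,
// and the distribution of actual labels seen so far.
case class XEntState(sumXEnt: Double, count: Long, actuals: Map[String, Long]) {
  // The semigroup/monoid plus: everything is summed pointwise.
  def plus(that: XEntState): XEntState =
    XEntState(
      sumXEnt + that.sumXEnt,
      count + that.count,
      (actuals.keySet ++ that.actuals.keySet).map { k =>
        k -> (actuals.getOrElse(k, 0L) + that.actuals.getOrElse(k, 0L))
      }.toMap)

  // The "present" step: compare mean cross entropy against the entropy of
  // the actuals accumulated alongside it. 1.0 = perfect, 0.0 = learned nothing.
  def normalizedInformation: Double = {
    val total = actuals.values.sum.toDouble
    val entropy = actuals.values.map { c =>
      val p = c / total
      -p * (math.log(p) / math.log(2.0))
    }.sum
    1.0 - (sumXEnt / count) / entropy
  }
}
```

This stays a lawful semigroup (everything sums), and the final transformation to the metric you care about happens once, at the end.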
Having other people extend Brushfire really makes me feel the lack of having written tests. This makes me feel bad but is good for improving the code. Here's a law that I think we want to hold for all Errors:

error.semigroup.plus(error.create(a1, p), error.create(a2, p)) == error.create(Semigroup.plus(a1, a2), p)

Please note, this does not hold for predictions, that is:

error.semigroup.plus(error.create(a, p1), error.create(a, p2)) != error.create(a, Semigroup.plus(p1, p2))
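The law above can be checked directly. Here is a self-contained sketch against a toy cross-entropy error (names illustrative; a real test would use ScalaCheck rather than fixed inputs). The law holds here because the error sums -a(x) * log2 q(x) over unnormalized actual counts, making `create` linear in the actuals, which is exactly what the law demands.

```scala
// Toy model of the Error law, for illustration only.
object ErrorLawSketch {
  type Actuals = Map[String, Long]
  case class Err(total: Double, weight: Long)

  // The error semigroup: sum both components.
  def plus(e1: Err, e2: Err): Err = Err(e1.total + e2.total, e1.weight + e2.weight)

  // create is linear in the actual counts, so creating from summed
  // actuals equals summing the created errors (the law above).
  def create(actual: Actuals, predicted: Map[String, Double]): Err = {
    val total = actual.map { case (label, n) =>
      -n * (math.log(predicted.getOrElse(label, 1e-9)) / math.log(2.0))
    }.sum
    Err(total, actual.values.sum)
  }

  // Semigroup.plus on actuals: merge the count maps.
  def sumActuals(a1: Actuals, a2: Actuals): Actuals =
    (a1.keySet ++ a2.keySet).map { k =>
      k -> (a1.getOrElse(k, 0L) + a2.getOrElse(k, 0L))
    }.toMap
}
```

The corresponding law over predictions fails for the same reason: cross entropy is not linear in q.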
  Instance(line, 0L, Map(cols.zip(values): _*), Map(label -> 1L))
}.toIterable

val totalData: Map[String, Long] = Monoid.sum(trainingData.iterator.map(_.target))
It seems wrong to me that we are using the full training set, in any way, in an error computation; I feel like we should only be basing the error computation on the validation set.
yeah, that's fine. We just need an estimate of the "true" entropy. We should use the exact same set that we are measuring the error on below.
I guess we could fit it into the semigroup approach by aggregating the total label distribution as we go. (In that light, by the way, it seems like Error should really be a semigroup and a function, maybe a full aggregator.)
Agreed that we should do that (that's what I meant in https://github.com/stripe/brushfire/pull/93/files/6aa9b7fe7adaaf9dbbce7f79abd767e56c4f2ecd#r68490840). I take your point about Error needing a present function. Error[T,P,E] could equally be Aggregator[(T,P),E,E2] where E2 is the thing you actually care about. Or you could decide you just cared about having an Ordering[E] (which comes up a lot in practice) and which would do the transformation to Double or whatever internally.
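A sketch of that Aggregator shape (simplified; Algebird's Aggregator is the real-world analogue with the same prepare/reduce/present structure): `create` becomes `prepare`, the semigroup plus becomes `reduce`, and the E => E2 transformation is `present`.

```scala
// Sketch of Error-as-Aggregator, for illustration only.
trait ErrorAggregator[T, P, E, E2] {
  def prepare(input: (T, P)): E // Error.create
  def reduce(l: E, r: E): E     // Error.semigroup.plus
  def present(e: E): E2         // the final transformation discussed above

  // Run the whole pipeline over a non-empty collection of (actual, predicted) pairs.
  def apply(inputs: Iterable[(T, P)]): E2 =
    present(inputs.iterator.map(prepare).reduce(reduce))
}
```

For instance, a mean-absolute-error instance accumulates (sum, count) as E and presents their ratio as E2, keeping the accumulated state a semigroup while the metric you actually report lives in `present`.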