diff --git a/build.sbt b/build.sbt index 18006336..04ba44ea 100644 --- a/build.sbt +++ b/build.sbt @@ -88,7 +88,8 @@ lazy val saulCore = (project in file("saul-core")). settings( name := "saul", libraryDependencies ++= Seq( - "com.typesafe.play" % "play_2.11" % "2.4.3" + "com.typesafe.play" % "play_2.11" % "2.4.3", + "net.sf.meka" % "meka" % "1.9.0" ) ).enablePlugins(AutomateHeaderPlugin) diff --git a/saul-core/src/main/java/edu/illinois/cs/cogcomp/saul/learn/SaulMulanWrapper.java b/saul-core/src/main/java/edu/illinois/cs/cogcomp/saul/learn/SaulMulanWrapper.java new file mode 100644 index 00000000..51cddba8 --- /dev/null +++ b/saul-core/src/main/java/edu/illinois/cs/cogcomp/saul/learn/SaulMulanWrapper.java @@ -0,0 +1,615 @@ +/** This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.saul.learn; + +import edu.illinois.cs.cogcomp.core.datastructures.vectors.ExceptionlessInputStream; +import edu.illinois.cs.cogcomp.core.datastructures.vectors.ExceptionlessOutputStream; +import edu.illinois.cs.cogcomp.lbjava.classify.*; +import edu.illinois.cs.cogcomp.lbjava.learn.Learner; +import mulan.classifier.MultiLabelLearner; +import mulan.classifier.transformation.BinaryRelevance; +import mulan.data.LabelsMetaDataImpl; +import mulan.data.MultiLabelInstances; +import weka.classifiers.bayes.NaiveBayes; +import weka.core.Attribute; +import weka.core.FastVector; +import weka.core.Instance; +import weka.core.Instances; + +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.List; + + +/** + * Translates Saul's internal problem representation into that which can be handled by 
WEKA + * learning algorithms. This translation involves storing all lbjavaInstances in memory so they can be + * passed to WEKA at one time. + *
+ *
+ * WEKA must be available on your CLASSPATH in order to use this class. WEKA source
+ * code and pre-compiled jar distributions are available at: http://www.cs.waikato.ac.nz/ml/weka/
+ *
+ *
+ * To use this class in a Java application, the following restrictions must be recognized: + *
weka.classifiers.bayes.NaiveBayes.
+ **/
+ protected MultiLabelLearner baseClassifier;
+ /**
+ * Stores the learner that {@link #forget()} copies in order to replace
+ * {@link #baseClassifier}.
+ **/
+ protected MultiLabelLearner freshClassifier;
+ /**
+ * Describes the WEKA attributes used by this learner: element 0 is the class (label)
+ * attribute, followed by one attribute per entry of the feature lexicon.
+ **/
+ protected FastVector attributeInfo;
+ /**
+ * The WEKA dataset that created instances are attached to; its class index is 0.
+ */
+ protected Instances wekaInstances;
+ /**
+ * A buffer for lbjava instances.
+ */
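+ protected List<LBJavaInstance> lbjavaInstances;
+ // NOTE(review): further declarations (e.g. allowableValues, trained, defaultBaseClassifier,
+ // all referenced below) appear to have been lost here during extraction - restore them from
+ // the original patch.
+
+ /**
+ * Empty constructor. Instantiates this wrapper with the default learning algorithm:
+ * <code>weka.classifiers.bayes.NaiveBayes</code>. Attribute information must be provided before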
+ * any learning can occur.
+ **/
+ public SaulMulanWrapper() {
+ this(""); // delegates to the name-only constructor with an empty name
+ }
+
+ /**
+ * Partial constructor; attribute information must be provided before any learning can occur.
+ *
+ * @param base The classifier to be used in this system.
+ **/
+ public SaulMulanWrapper(MultiLabelLearner base) {
+ this("", base); // empty name; the supplied learner is passed to the full constructor
+ }
+
+ /**
+ * Name-only constructor. Instantiates this wrapper with the given name and the default
+ * learning algorithm: weka.classifiers.bayes.NaiveBayes. Attribute information must be
+ * provided before any learning can occur.
+ *
+ * @param n The name of the classifier.
+ **/
+ public SaulMulanWrapper(String n) {
+ this(n, defaultBaseClassifier); // defaultBaseClassifier is declared outside the visible part of this file
+ }
+
+
+ /**
+ * Full Constructor.
+ *
+ * @param n The name of the classifier
+ * @param base The classifier to be used in this system.
+ **/
+ public SaulMulanWrapper(String n, MultiLabelLearner base) {
+     super(n);
+     // Nothing is buffered until learn() is called.
+     this.lbjavaInstances = new ArrayList<>();
+     // Both references start out pointing at the supplied learner.
+     this.baseClassifier = base;
+     this.freshClassifier = base;
+ }
+
+
+ /**
+ * This learner's output type is "mixed%".
+ */
+ public String getOutputType() {
+ return "mixed%"; // constant output type reported by this learner
+ }
+
+
+ /**
+ * Sets the labeler.
+ *
+ * @param l A labeling classifier.
+ **/
+ public void setLabeler(Classifier l) {
+     super.setLabeler(l);
+     // Cache the labeler's allowable values; a null labeler clears the cache.
+     if (l != null) {
+         allowableValues = l.allowableValues();
+     } else {
+         allowableValues = null;
+     }
+ }
+
+
+ /**
+ * Returns the array of allowable values that a feature returned by this classifier may take.
+ *
+ * @return The allowable values of this learner's labeler, or an array of length zero if the
+ * labeler has not yet been established or does not specify allowable values.
+ **/
+ public String[] allowableValues() {
+     // Never hand out null: fall back to an empty array when no values are cached.
+     return allowableValues != null ? allowableValues : new String[0];
+ }
+
+
+ /**
+ * Since WEKA classifiers cannot learn online, this method causes no actual learning to occur,
+ * it simply creates an object from this example and adds it to a set of
+ * lbjavaInstances from which the classifier will be built once {@link #doneLearning()} is called.
+ **/
+ public void learn(int[] exampleFeatures, double[] exampleValues, int[] exampleLabels,
+         double[] labelValues) {
+     checkIfCanTrain();
+     // Only buffer the example; the actual training happens in doneLearning().
+     final LBJavaInstance buffered =
+             new LBJavaInstance(exampleFeatures, exampleValues, exampleLabels, labelValues);
+     lbjavaInstances.add(buffered);
+ }
+
+
+ /**
+ * This method makes one or more decisions about a single object, returning those decisions as
+ * Features in a vector.
+ *
+ * @param exampleFeatures The example's array of feature indices.
+ * @param exampleValues The example's array of feature values.
+ * @return A feature vector with a single feature containing the prediction for this example.
+ **/
+ public FeatureVector classify(int[] exampleFeatures, double[] exampleValues) {
+     if (!trained) {
+         System.err.println("WekaWrapper: Error - Cannot make a classification with an "
+                 + "untrained classifier.");
+         new Exception().printStackTrace();
+         System.exit(1);
+     }
+
+     // Attribute 0 of attributeInfo is taken to be the class attribute; its type selects
+     // which kind of prediction is produced.
+     final Attribute labelAttribute = (Attribute) attributeInfo.elementAt(0);
+     final boolean discreteLabel = labelAttribute.isNominal() || labelAttribute.isString();
+
+     if (!discreteLabel) {
+         if (labelAttribute.isNumeric()) {
+             // Real-valued prediction: the first entry of the distribution is the value.
+             final double predicted = getDistribution(exampleFeatures, exampleValues)[0];
+             return new FeatureVector(
+                     new RealPrimitiveStringFeature(containingPackage, name, "", predicted));
+         }
+         System.err.println("WekaWrapper: Error - illegal class type.");
+         new Exception().printStackTrace();
+         System.exit(1);
+         return new FeatureVector();
+     }
+
+     // Discrete prediction: pick the first index holding the largest confidence.
+     final double[] confidences = getDistribution(exampleFeatures, exampleValues);
+     int argMax = 0;
+     for (int candidate = 1; candidate < confidences.length; ++candidate) {
+         if (confidences[candidate] > confidences[argMax]) {
+             argMax = candidate;
+         }
+     }
+
+     final Feature predictedLabel = labelLexicon.lookupKey(argMax);
+     if (predictedLabel == null) {
+         return new FeatureVector();
+     }
+     final String value = predictedLabel.getStringValue();
+     return new FeatureVector(new DiscretePrimitiveStringFeature(containingPackage, name, "",
+             value, valueIndexOf(value), (short) allowableValues().length));
+ }
+
+ public String discreteValue(int[] f, double[] v) {
+     // The first discrete value of the classification result is the prediction.
+     final FeatureVector prediction = classify(f, v);
+     return prediction.discreteValueArray()[0];
+ }
+
+ public double realValue(int[] f, double[] v) {
+     // The first real value of the classification result is the prediction.
+     final FeatureVector prediction = classify(f, v);
+     return prediction.realValueArray()[0];
+ }
+
+ /**
+ * Returns a discrete distribution of the classifier's prediction values.
+ *
+ * @param exampleFeatures The example's array of feature indices.
+ * @param exampleValues The example's array of feature values.
+ **/
+ protected double[] getDistribution(int[] exampleFeatures, double[] exampleValues) {
+     // Asks the base classifier for its per-label confidences on a single unlabeled example.
+     // Terminates the JVM (exit code 1) if the learner is untrained, if prediction fails, or
+     // if no confidences are produced.
+     if (!trained) {
+         System.err.println("WekaWrapper: Error - Cannot make a classification with an "
+                 + "untrained classifier.");
+         new Exception().printStackTrace();
+         System.exit(1);
+     }
+
+     // The example carries no labels, so makeInstance() marks its class as missing.
+     Instance inQuestion =
+             makeInstance(new LBJavaInstance(exampleFeatures, exampleValues, new int[0], new double[0]));
+
+     double[] dist = null;
+     try {
+         dist = baseClassifier.makePrediction(inQuestion).getConfidences();
+     } catch (Exception e) {
+         System.err.println("WekaWrapper: Error while computing distribution.");
+         e.printStackTrace();
+         System.exit(1);
+     }
+
+     // A prediction that carries no confidences yields null here; report it the same way as
+     // an empty array instead of failing with an unexplained NullPointerException.
+     if (dist == null || dist.length == 0) {
+         System.err.println("WekaWrapper: Error - The base classifier returned an empty "
+                 + "probability distribution when attempting to classify an " + "example.");
+         new Exception().printStackTrace();
+         System.exit(1);
+     }
+
+     return dist;
+ }
+
+
+ /**
+ * Replaces the learned classifier with a copy of {@link #freshClassifier}, discards any buffered
+ * examples and resets the {@link #wekaInstances} collection to an empty dataset.
+ **/
+ public void forget() {
+     // Resets this learner to its untrained state: the base classifier is replaced by a copy
+     // of freshClassifier and all buffered examples are dropped.
+     super.forget();
+
+     try {
+         baseClassifier = freshClassifier.makeCopy();
+     } catch (Exception e) {
+         System.err.println("LBJava ERROR: WekaWrapper.forget: Can't copy classifier:");
+         e.printStackTrace();
+         System.exit(1);
+     }
+
+     lbjavaInstances = new ArrayList<>();
+     // attributeInfo is only assigned by initializeAttributes() and read(). When forget() is
+     // called before either of them there is no attribute description to build an empty
+     // dataset from, and constructing Instances with a null attribute list would throw.
+     if (attributeInfo != null) {
+         wekaInstances = new Instances(name, attributeInfo, 0);
+         wekaInstances.setClassIndex(0);
+     }
+     trained = false;
+ }
+
+
+ private void initializeAttributes() {
+     // Builds attributeInfo (class attribute first, then one attribute per lexicon entry) and
+     // an empty wekaInstances dataset that uses it.
+     attributeInfo = new FastVector(lexicon.size() + 1);
+
+     /*
+      * The class attribute is derived from the label lexicon, so at least one label must
+      * have been observed.
+      */
+     if (labelLexicon.size() < 1) {
+         System.err.println("WekaWrapper: Error - Cannot build the class attribute because "
+                 + "the label lexicon is empty.");
+         new Exception().printStackTrace();
+         System.exit(1);
+     } else {
+         Feature label = labelLexicon.lookupKey(0);
+         if (!label.isDiscrete()) {
+             // Real-valued label: a numeric class attribute.
+             Attribute a = new Attribute(label.getStringIdentifier());
+             attributeInfo.addElement(a);
+         } else {
+             // Discrete label: a nominal class attribute listing every label value, in
+             // label-lexicon order.
+             FastVector valueVector = new FastVector(labelLexicon.size());
+             for (int v = 0; v < labelLexicon.size(); v++)
+                 valueVector.addElement(labelLexicon.lookupKey(v).getStringValue());
+             Attribute a = new Attribute(label.getGeneratingClassifier(), valueVector);
+             attributeInfo.addElement(a);
+         }
+     }
+     /*
+      * Construct weka attribute for each lexicon entry.
+      * If entry is discrete use a binary attribute
+      * If it is real, use a numerical attribute
+      */
+     FastVector binaryValues = new FastVector(2);
+     binaryValues.addElement("0");
+     binaryValues.addElement("1");
+     for (int featureIndex = 0; featureIndex < lexicon.size(); ++featureIndex) {
+         Feature f = lexicon.lookupKey(featureIndex);
+         Attribute a = f.isDiscrete() ?
+                 new Attribute(f.toString(), binaryValues) :
+                 new Attribute(f.toString());
+
+         attributeInfo.addElement(a);
+     }
+
+     // The first attribute is the label
+     wekaInstances = new Instances(name, attributeInfo, 0);
+     wekaInstances.setClassIndex(0);
+ }
+
+
+ /**
+ * Creates a WEKA Instance object out of a buffered {@link LBJavaInstance}.
+ **/
+ private Instance makeInstance(LBJavaInstance instance) {
+
+ // Initialize an Instance object
+ Instance inst = new Instance(attributeInfo.size());
+
+ // Acknowledge that this instance will be a member of our dataset 'wekaInstances'
+ inst.setDataset(wekaInstances);
+
+ // set all nominal feature values to 0, which means those features are not used in this example
+ for(int i=1; i< attributeInfo.size();i++)
+ if(inst.attribute(i).isNominal())
+ inst.setValue(i, "0");
+
+ // Assign values for its attributes
+ /*
+ * Since we are iterating through this example's feature list, which does not contain the
+ * label feature (the label feature is the first in the 'attribute' list), we set attIndex
+ * to at exampleFeatures[featureIndices] + 1, while we start featureIndices at 0.
+ */
+ for (int featureIndex = 0; featureIndex < instance.featureIndices.length; ++featureIndex) {
+ int attIndex = instance.featureIndices[featureIndex] + 1;
+ Feature f = lexicon.lookupKey(instance.featureIndices[featureIndex]);
+
+ // if the feature does not exist, do nothing. this may occur in test set.
+ if (f == null)
+ continue;
+ Attribute att = (Attribute) attributeInfo.elementAt(attIndex);
+
+ // make sure the feature and the attribute match
+ if (!(att.name().equals(f.toString()))) {
+ System.err.println("WekaWrapper: Error - makeInstance encountered a misaligned "
+ + "attribute-feature pair.");
+ System.err.println(" " + att.name() + " and " + f.toString()
+ + " should have been identical.");
+ new Exception().printStackTrace();
+ System.exit(1);
+ }
+ if (f.isDiscrete())
+ inst.setValue(attIndex, "1"); // this feature is used in this example so we set it to "1"
+ else
+ inst.setValue(attIndex, instance.featureValues[featureIndex]);
+
+ }
+
+ /*
+ * Here, we assume that if either the labels FeatureVector is empty of features, or is null,
+ * then this example is to be considered unlabeled.
+ */
+ if (instance.labelIndices.length == 0) {
+ inst.setClassMissing();
+ } else if (instance.labelIndices.length > 1) {
+ System.err.println("WekaWrapper: Error - Weka Instances may only take a single class "
+ + "value, ");
+ new Exception().printStackTrace();
+ System.exit(1);
+ } else {
+ Feature label = labelLexicon.lookupKey(instance.labelIndices[0]);
+
+ // make sure the label feature matches the n 0'th attribute
+ if (!(label.getGeneratingClassifier().equals(((Attribute) attributeInfo.elementAt(0))
+ .name()))) {
+ System.err.println("WekaWrapper: Error - makeInstance found the wrong label name.");
+ new Exception().printStackTrace();
+ System.exit(1);
+ }
+
+ if (!label.isDiscrete())
+ inst.setValue(0, instance.labelValues[0]);
+ else
+ inst.setValue(0, label.getStringValue());
+ }
+
+ return inst;
+ }
+
+
+ /**
+ * Produces a set of scores indicating the degree to which each possible discrete classification
+ * value is associated with the given example object.
+ **/
+ public ScoreSet scores(int[] exampleFeatures, double[] exampleValues) {
+     final double[] confidences = getDistribution(exampleFeatures, exampleValues);
+
+     // Attribute 0 of attributeInfo is taken to be the class attribute.
+     final Attribute labelAttribute = (Attribute) attributeInfo.elementAt(0);
+     final ScoreSet result = new ScoreSet();
+
+     if (labelAttribute.isNominal() || labelAttribute.isString()) {
+         // Pair the i-th class value with the i-th confidence.
+         final Enumeration values = labelAttribute.enumerateValues();
+         for (int position = 0; values.hasMoreElements(); ++position) {
+             if (position >= confidences.length) {
+                 System.err
+                         .println("WekaWrapper: Error - scores found more possible values than "
+                                 + "probabilities.");
+                 new Exception().printStackTrace();
+                 System.exit(1);
+             }
+             final double score = confidences[position];
+             final String value = (String) values.nextElement();
+             result.put(value, score);
+         }
+         return result;
+     }
+
+     // Any other class type is an error; only the message differs.
+     if (labelAttribute.isNumeric()) {
+         System.err.println("WekaWrapper: Error - The 'scores' function should not be called "
+                 + "when the class attribute is numeric.");
+     } else {
+         System.err.println("WekaWrapper: Error - ScoreSet: Class Types must be either "
+                 + "Nominal, String, or Numeric.");
+     }
+     new Exception().printStackTrace();
+     System.exit(1);
+
+     return result;
+ }
+
+
+ /**
+ * Indicates that the classifier is finished learning. This method must be called if the
+ * classifier is to learn anything. Since the wrapped learner cannot learn online, all of the
+ * buffered lbjavaInstances are first converted to WEKA instances. This method then invokes the
+ * base classifier's build(MultiLabelInstances) method.
+ **/
+ public void doneLearning() {
+
+ checkIfCanTrain();
+ /*
+ * System.out.println("\nWekaWrapper Data Summary:");
+ * System.out.println(wekaInstances.toSummaryString());
+ */
+
+ try {
+ initializeAttributes();
+ for (LBJavaInstance i : lbjavaInstances)
+ wekaInstances.add(makeInstance(i));
+ lbjavaInstances.clear();
+
+ baseClassifier.build(new MultiLabelInstances(wekaInstances, new LabelsMetaDataImpl()));
+ } catch (Exception e) {
+ System.err.println("WekaWrapper: Error - There was a problem building the classifier");
+ if (baseClassifier == null)
+ System.out.println("WekaWrapper: baseClassifier was null.");
+ e.printStackTrace();
+ System.exit(1);
+ }
+
+ trained = true;
+ wekaInstances = new Instances(name, attributeInfo, 0);
+ wekaInstances.setClassIndex(0);
+ }
+
+ private void checkIfCanTrain() {
+     // Training is only allowed while the learner is untrained.
+     if (!trained) {
+         return;
+     }
+     System.err.println("WekaWrapper: Error - Cannot call 'doneLearning()' or 'learn()' again without "
+             + "first calling 'forget()'");
+     new Exception().printStackTrace();
+     System.exit(1);
+ }
+
+
+ /**
+ * Writes the settings of the classifier in use, and a string describing the classifier, if
+ * available.
+ **/
+ public void write(PrintStream out) {
+     out.print(name + ": ");
+     // One option per line, followed by the classifier's own string form.
+     for (String option : baseClassifier.getOptions()) {
+         out.println(option);
+     }
+     out.println(baseClassifier);
+ }
+
+
+ /**
+ * Writes the learned function's internal representation in binary form.
+ *
+ * @param out The output stream.
+ **/
+ public void write(ExceptionlessOutputStream out) {
+ super.write(out);
+ out.writeBoolean(trained);
+
+ if (allowableValues == null)
+ out.writeInt(0);
+ else {
+ out.writeInt(allowableValues.length);
+ for (int i = 0; i < allowableValues.length; ++i)
+ out.writeString(allowableValues[i]);
+ }
+
+ ObjectOutputStream oos = null;
+ try {
+ oos = new ObjectOutputStream(out);
+ } catch (Exception e) {
+ System.err.println("Can't create object stream for '" + name + "': " + e);
+ System.exit(1);
+ }
+
+ try {
+ oos.writeObject(baseClassifier);
+ oos.writeObject(freshClassifier);
+ oos.writeObject(attributeInfo);
+ oos.writeObject(wekaInstances);
+ } catch (Exception e) {
+ System.err.println("Can't write to object stream for '" + name + "': " + e);
+ System.exit(1);
+ }
+ }
+
+
+ /**
+ * Reads the binary representation of a learner with this object's run-time type, overwriting
+ * any and all learned or manually specified parameters as well as the label lexicon but without
+ * modifying the feature lexicon.
+ *
+ * @param in The input stream.
+ **/
+ public void read(ExceptionlessInputStream in) {
+     super.read(in);
+     trained = in.readBoolean();
+
+     // Length-prefixed list of allowable values.
+     final int valueCount = in.readInt();
+     allowableValues = new String[valueCount];
+     for (int k = 0; k < valueCount; ++k) {
+         allowableValues[k] = in.readString();
+     }
+
+     ObjectInputStream objectStream = null;
+     try {
+         objectStream = new ObjectInputStream(in);
+     } catch (Exception e) {
+         System.err.println("Can't create object stream for '" + name + "': " + e);
+         System.exit(1);
+     }
+
+     // Objects are consumed in the order in which write() stored them.
+     try {
+         baseClassifier = (MultiLabelLearner) objectStream.readObject();
+         freshClassifier = (MultiLabelLearner) objectStream.readObject();
+         attributeInfo = (FastVector) objectStream.readObject();
+         wekaInstances = (Instances) objectStream.readObject();
+     } catch (Exception e) {
+         System.err.println("Can't read from object stream for '" + name + "': " + e);
+         System.exit(1);
+     }
+ }
+
+ /**
+ * Holds the raw arrays of a single LBJava example until it is converted to a WEKA instance.
+ * The arrays are stored by reference, not copied.
+ *
+ * Declared static so that buffered examples do not each keep a hidden reference to the
+ * enclosing wrapper; the class uses no state of the wrapper.
+ */
+ private static class LBJavaInstance {
+     /** Lexicon indices of the features present in the example. */
+     private final int[] featureIndices;
+     /** Values of the features, parallel to {@link #featureIndices}. */
+     private final double[] featureValues;
+     /** Label-lexicon indices of the example's labels; empty for an unlabeled example. */
+     private final int[] labelIndices;
+     /** Values of the labels, parallel to {@link #labelIndices}. */
+     private final double[] labelValues;
+
+     LBJavaInstance(int[] featureIndex, double[] featureValues, int[] labelIndex, double[] labelValues) {
+         this.featureIndices = featureIndex;
+         this.featureValues = featureValues;
+         this.labelIndices = labelIndex;
+         this.labelValues = labelValues;
+     }
+ }
+}
+