diff --git a/build.sbt b/build.sbt index 18006336..04ba44ea 100644 --- a/build.sbt +++ b/build.sbt @@ -88,7 +88,8 @@ lazy val saulCore = (project in file("saul-core")). settings( name := "saul", libraryDependencies ++= Seq( - "com.typesafe.play" % "play_2.11" % "2.4.3" + "com.typesafe.play" % "play_2.11" % "2.4.3", + "net.sf.meka" % "meka" % "1.9.0" ) ).enablePlugins(AutomateHeaderPlugin) diff --git a/saul-core/src/main/java/edu/illinois/cs/cogcomp/saul/learn/SaulMulanWrapper.java b/saul-core/src/main/java/edu/illinois/cs/cogcomp/saul/learn/SaulMulanWrapper.java new file mode 100644 index 00000000..51cddba8 --- /dev/null +++ b/saul-core/src/main/java/edu/illinois/cs/cogcomp/saul/learn/SaulMulanWrapper.java @@ -0,0 +1,615 @@ +/** This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.saul.learn; + +import edu.illinois.cs.cogcomp.core.datastructures.vectors.ExceptionlessInputStream; +import edu.illinois.cs.cogcomp.core.datastructures.vectors.ExceptionlessOutputStream; +import edu.illinois.cs.cogcomp.lbjava.classify.*; +import edu.illinois.cs.cogcomp.lbjava.learn.Learner; +import mulan.classifier.MultiLabelLearner; +import mulan.classifier.transformation.BinaryRelevance; +import mulan.data.LabelsMetaDataImpl; +import mulan.data.MultiLabelInstances; +import weka.classifiers.bayes.NaiveBayes; +import weka.core.Attribute; +import weka.core.FastVector; +import weka.core.Instance; +import weka.core.Instances; + +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.List; + + +/** + * Translates Saul's internal problem representation into that which can be handled by WEKA + * learning algorithms. This translation involves storing all lbjavaInstances in memory so they can be + * passed to WEKA at one time. + *

+ *

+ * WEKA must be available on your CLASSPATH in order to use this class. WEKA source + * code and pre-compiled jar distributions are available at: http://www.cs.waikato.ac.nz/ml/weka/ + *

+ *

+ * To use this class in a Java application, the following restrictions must be recognized: + *

+ * + * @author Taher Rahgooy + **/ + +public class SaulMulanWrapper extends Learner { + /** + * Default for the {@link #baseClassifier} field. + */ + public static final MultiLabelLearner defaultBaseClassifier = + new BinaryRelevance(new NaiveBayes()); + + /** + * Stores the instance of the WEKA classifier which we are training; default is + * weka.classifiers.bayes.NaiveBayes. + **/ + protected MultiLabelLearner baseClassifier; + /** + * Stores a fresh instance of the WEKA classifier for the purposes of forgetting. + **/ + protected MultiLabelLearner freshClassifier; + /** + * Information about the features this learner takes as input stored here. + **/ + protected FastVector attributeInfo; + /** + * The main collection of weka Instance objects. + */ + protected Instances wekaInstances; + /** + * A buffer for lbjava instances. + */ + protected List lbjavaInstances; + /** + * Indicates whether the {@link #doneLearning()} method has been called and the + * {@link #forget()} method has not yet been called. + **/ + protected boolean trained = false; + /** + * The label producing classifier's allowable values. + */ + protected String[] allowableValues; + + + /** + * Empty constructor. Instantiates this wrapper with the default learning algorithm: + * weka.classifiers.bayes.NaiveBayes. Attribute information must be provided before + * any learning can occur. + **/ + public SaulMulanWrapper() { + this(""); + } + + /** + * Partial constructor; attribute information must be provided before any learning can occur. + * + * @param base The classifier to be used in this system. + **/ + public SaulMulanWrapper(MultiLabelLearner base) { + this("", base); + } + + /** + * Empty constructor. Instantiates this wrapper with the default learning algorithm: + * weka.classifiers.bayes.NaiveBayes. Attribute information must be provided before + * any learning can occur. + * + * @param n The name of the classifier. + **/ + public SaulMulanWrapper(String n) { + this(n, defaultBaseClassifier); + } + + + /** + * Full Constructor. + * + * @param n The name of the classifier + * @param base The classifier to be used in this system. + * have. + **/ + public SaulMulanWrapper(String n, MultiLabelLearner base) { + super(n); + baseClassifier = base; + freshClassifier = base; + lbjavaInstances = new ArrayList<>(); + } + + + /** + * This learner's output type is "mixed%". + */ + public String getOutputType() { + return "mixed%"; + } + + + /** + * Sets the labeler. + * + * @param l A labeling classifier. + **/ + public void setLabeler(Classifier l) { + super.setLabeler(l); + allowableValues = l == null ? null : l.allowableValues(); + } + + + /** + * Returns the array of allowable values that a feature returned by this classifier may take. + * + * @return The allowable values of this learner's labeler, or an array of length zero if the + * labeler has not yet been established or does not specify allowable values. + **/ + public String[] allowableValues() { + if (allowableValues == null) + return new String[0]; + return allowableValues; + } + + + /** + * Since WEKA classifiers cannot learn online, this method causes no actual learning to occur, + * it simply creates an object from this example and adds it to a set of + * lbjavaInstances from which the classifier will be built once {@link #doneLearning()} is called. + **/ + public void learn(int[] exampleFeatures, double[] exampleValues, int[] exampleLabels, + double[] labelValues) { + checkIfCanTrain(); + lbjavaInstances.add(new LBJavaInstance(exampleFeatures, exampleValues, exampleLabels, labelValues)); + } + + + /** + * This method makes one or more decisions about a single object, returning those decisions as + * Features in a vector. + * + * @param exampleFeatures The example's array of feature indices. + * @param exampleValues The example's array of feature values. + * @return A feature vector with a single feature containing the prediction for this example. + **/ + public FeatureVector classify(int[] exampleFeatures, double[] exampleValues) { + if (!trained) { + System.err.println("WekaWrapper: Error - Cannot make a classification with an " + + "untrained classifier."); + new Exception().printStackTrace(); + System.exit(1); + } + + /* + * Assuming that the first Attribute in our attributeInfo vector is the class attribute, + * decide which case we are in + */ + Attribute classAtt = (Attribute) attributeInfo.elementAt(0); + + if (classAtt.isNominal() || classAtt.isString()) { + double[] dist = getDistribution(exampleFeatures, exampleValues); + int best = 0; + for (int i = 1; i < dist.length; ++i) + if (dist[i] > dist[best]) + best = i; + + Feature label = labelLexicon.lookupKey(best); + if (label == null) + return new FeatureVector(); + String value = label.getStringValue(); + + return new FeatureVector(new DiscretePrimitiveStringFeature(containingPackage, name, + "", value, valueIndexOf(value), (short) allowableValues().length)); + } else if (classAtt.isNumeric()) { + return new FeatureVector(new RealPrimitiveStringFeature(containingPackage, name, "", + getDistribution(exampleFeatures, exampleValues)[0])); + } else { + System.err.println("WekaWrapper: Error - illegal class type."); + new Exception().printStackTrace(); + System.exit(1); + } + + return new FeatureVector(); + } + + public String discreteValue(int[] f, double[] v) { + return classify(f, v).discreteValueArray()[0]; + } + + public double realValue(int[] f, double[] v) { + return classify(f, v).realValueArray()[0]; + } + + /** + * Returns a discrete distribution of the classifier's prediction values. + * + * @param exampleFeatures The example's array of feature indices. + * @param exampleValues The example's array of feature values. + **/ + protected double[] getDistribution(int[] exampleFeatures, double[] exampleValues) { + if (!trained) { + System.err.println("WekaWrapper: Error - Cannot make a classification with an " + + "untrained classifier."); + new Exception().printStackTrace(); + System.exit(1); + } + + Instance inQuestion = + makeInstance(new LBJavaInstance(exampleFeatures, exampleValues, new int[0], new double[0])); + + /* + * For Numerical class values, this will return an array of size 1, containing the class + * prediction. For Nominal classes, an array of size equal to that of the class list, + * representing probabilities. For String classes, ? + */ + double[] dist = null; + try { + dist = baseClassifier.makePrediction(inQuestion).getConfidences(); + } catch (Exception e) { + System.err.println("WekaWrapper: Error while computing distribution."); + e.printStackTrace(); + System.exit(1); + } + + if (dist.length == 0) { + System.err.println("WekaWrapper: Error - The base classifier returned an empty " + + "probability distribution when attempting to classify an " + "example."); + new Exception().printStackTrace(); + System.exit(1); + } + + return dist; + } + + + /** + * Destroys the learned version of the WEKA classifier and empties the {@link #wekaInstances} + * collection of wekaInstances. + **/ + public void forget() { + super.forget(); + + try { + baseClassifier = freshClassifier.makeCopy(); + } catch (Exception e) { + System.err.println("LBJava ERROR: WekaWrapper.forget: Can't copy classifier:"); + e.printStackTrace(); + System.exit(1); + } + + lbjavaInstances = new ArrayList<>(); + wekaInstances = new Instances(name, attributeInfo, 0); + wekaInstances.setClassIndex(0); + trained = false; + } + + + private void initializeAttributes() { + attributeInfo = new FastVector(lexicon.size() + 1); + /* + * Here, we assume that if either the labels FeatureVector is empty of features, or is null, + * then this example is to be considered unlabeled. + */ + if (labelLexicon.size() < 1) { + System.err.println("WekaWrapper: Error - Weka Instances may only take a single class " + + "value, "); + new Exception().printStackTrace(); + System.exit(1); + } else { + Feature label = labelLexicon.lookupKey(0); + if (!label.isDiscrete()) { + Attribute a = new Attribute(label.getStringIdentifier()); + attributeInfo.addElement(a); + } else { + FastVector valueVector = new FastVector(labelLexicon.size()); + for (int v = 0; v < labelLexicon.size(); v++) + valueVector.addElement(labelLexicon.lookupKey(v).getStringValue()); + Attribute a = new Attribute(label.getGeneratingClassifier(), valueVector); + attributeInfo.addElement(a); + } + } + /* + * Construct weka attribute for each lexicon entry. + * If entry is discrete use a binary attribute + * If it is real, use a numerical attribute + */ + FastVector binaryValues = new FastVector(2); + binaryValues.addElement("0"); + binaryValues.addElement("1"); + for (int featureIndex = 0; featureIndex < lexicon.size(); ++featureIndex) { + Feature f = lexicon.lookupKey(featureIndex); + Attribute a = f.isDiscrete() ? + new Attribute(f.toString(), binaryValues) : + new Attribute(f.toString()); + + attributeInfo.addElement(a); + } + + // The first attribute is the label + wekaInstances = new Instances(name, attributeInfo, 0); + wekaInstances.setClassIndex(0); + } + + + /** + * Creates a WEKA Instance object out of a {@link FeatureVector}. + **/ + private Instance makeInstance(LBJavaInstance instance) { + + // Initialize an Instance object + Instance inst = new Instance(attributeInfo.size()); + + // Acknowledge that this instance will be a member of our dataset 'wekaInstances' + inst.setDataset(wekaInstances); + + // set all nominal feature values to 0, which means those features are not used in this example + for(int i=1; i< attributeInfo.size();i++) + if(inst.attribute(i).isNominal()) + inst.setValue(i, "0"); + + // Assign values for its attributes + /* + * Since we are iterating through this example's feature list, which does not contain the + * label feature (the label feature is the first in the 'attribute' list), we set attIndex + * to at exampleFeatures[featureIndices] + 1, while we start featureIndices at 0. + */ + for (int featureIndex = 0; featureIndex < instance.featureIndices.length; ++featureIndex) { + int attIndex = instance.featureIndices[featureIndex] + 1; + Feature f = lexicon.lookupKey(instance.featureIndices[featureIndex]); + + // if the feature does not exist, do nothing. this may occur in test set. + if (f == null) + continue; + Attribute att = (Attribute) attributeInfo.elementAt(attIndex); + + // make sure the feature and the attribute match + if (!(att.name().equals(f.toString()))) { + System.err.println("WekaWrapper: Error - makeInstance encountered a misaligned " + + "attribute-feature pair."); + System.err.println(" " + att.name() + " and " + f.toString() + + " should have been identical."); + new Exception().printStackTrace(); + System.exit(1); + } + if (f.isDiscrete()) + inst.setValue(attIndex, "1"); // this feature is used in this example so we set it to "1" + else + inst.setValue(attIndex, instance.featureValues[featureIndex]); + + } + + /* + * Here, we assume that if either the labels FeatureVector is empty of features, or is null, + * then this example is to be considered unlabeled. + */ + if (instance.labelIndices.length == 0) { + inst.setClassMissing(); + } else if (instance.labelIndices.length > 1) { + System.err.println("WekaWrapper: Error - Weka Instances may only take a single class " + + "value, "); + new Exception().printStackTrace(); + System.exit(1); + } else { + Feature label = labelLexicon.lookupKey(instance.labelIndices[0]); + + // make sure the label feature matches the n 0'th attribute + if (!(label.getGeneratingClassifier().equals(((Attribute) attributeInfo.elementAt(0)) + .name()))) { + System.err.println("WekaWrapper: Error - makeInstance found the wrong label name."); + new Exception().printStackTrace(); + System.exit(1); + } + + if (!label.isDiscrete()) + inst.setValue(0, instance.labelValues[0]); + else + inst.setValue(0, label.getStringValue()); + } + + return inst; + } + + + /** + * Produces a set of scores indicating the degree to which each possible discrete classification + * value is associated with the given example object. + **/ + public ScoreSet scores(int[] exampleFeatures, double[] exampleValues) { + double[] dist = getDistribution(exampleFeatures, exampleValues); + + /* + * Assuming that the first Attribute in our attributeInfo vector is the class attribute, + * decide which case we are in + */ + Attribute classAtt = (Attribute) attributeInfo.elementAt(0); + + ScoreSet scores = new ScoreSet(); + + if (classAtt.isNominal() || classAtt.isString()) { + Enumeration enumeratedValues = classAtt.enumerateValues(); + + int i = 0; + while (enumeratedValues.hasMoreElements()) { + if (i >= dist.length) { + System.err + .println("WekaWrapper: Error - scores found more possible values than " + + "probabilities."); + new Exception().printStackTrace(); + System.exit(1); + } + double s = dist[i]; + String v = (String) enumeratedValues.nextElement(); + scores.put(v, s); + ++i; + } + } else if (classAtt.isNumeric()) { + System.err.println("WekaWrapper: Error - The 'scores' function should not be called " + + "when the class attribute is numeric."); + new Exception().printStackTrace(); + System.exit(1); + } else { + System.err.println("WekaWrapper: Error - ScoreSet: Class Types must be either " + + "Nominal, String, or Numeric."); + new Exception().printStackTrace(); + System.exit(1); + } + + return scores; + } + + + /** + * Indicates that the classifier is finished learning. This method must be called if the + * WEKA classifier is to learn anything. Since WEKA classifiers cannot learn online, all of the + * training lbjavaInstances must be gathered and committed to first. This method invokes the WEKA + * classifier's buildClassifier(Instances) method. + **/ + public void doneLearning() { + + checkIfCanTrain(); + /* + * System.out.println("\nWekaWrapper Data Summary:"); + * System.out.println(wekaInstances.toSummaryString()); + */ + + try { + initializeAttributes(); + for (LBJavaInstance i : lbjavaInstances) + wekaInstances.add(makeInstance(i)); + lbjavaInstances.clear(); + + baseClassifier.build(new MultiLabelInstances(wekaInstances, new LabelsMetaDataImpl())); + } catch (Exception e) { + System.err.println("WekaWrapper: Error - There was a problem building the classifier"); + if (baseClassifier == null) + System.out.println("WekaWrapper: baseClassifier was null."); + e.printStackTrace(); + System.exit(1); + } + + trained = true; + wekaInstances = new Instances(name, attributeInfo, 0); + wekaInstances.setClassIndex(0); + } + + private void checkIfCanTrain() { + if (trained) { + System.err.println("WekaWrapper: Error - Cannot call 'doneLearning()' or 'learn()' again without " + + "first calling 'forget()'"); + new Exception().printStackTrace(); + System.exit(1); + } + } + + + /** + * Writes the settings of the classifier in use, and a string describing the classifier, if + * available. + **/ + public void write(PrintStream out) { + out.print(name + ": "); + String[] options = baseClassifier.getOptions(); + for (int i = 0; i < options.length; ++i) + out.println(options[i]); + out.println(baseClassifier); + } + + + /** + * Writes the learned function's internal representation in binary form. + * + * @param out The output stream. + **/ + public void write(ExceptionlessOutputStream out) { + super.write(out); + out.writeBoolean(trained); + + if (allowableValues == null) + out.writeInt(0); + else { + out.writeInt(allowableValues.length); + for (int i = 0; i < allowableValues.length; ++i) + out.writeString(allowableValues[i]); + } + + ObjectOutputStream oos = null; + try { + oos = new ObjectOutputStream(out); + } catch (Exception e) { + System.err.println("Can't create object stream for '" + name + "': " + e); + System.exit(1); + } + + try { + oos.writeObject(baseClassifier); + oos.writeObject(freshClassifier); + oos.writeObject(attributeInfo); + oos.writeObject(wekaInstances); + } catch (Exception e) { + System.err.println("Can't write to object stream for '" + name + "': " + e); + System.exit(1); + } + } + + + /** + * Reads the binary representation of a learner with this object's run-time type, overwriting + * any and all learned or manually specified parameters as well as the label lexicon but without + * modifying the feature lexicon. + * + * @param in The input stream. + **/ + public void read(ExceptionlessInputStream in) { + super.read(in); + trained = in.readBoolean(); + allowableValues = new String[in.readInt()]; + for (int i = 0; i < allowableValues.length; ++i) + allowableValues[i] = in.readString(); + + ObjectInputStream ois = null; + try { + ois = new ObjectInputStream(in); + } catch (Exception e) { + System.err.println("Can't create object stream for '" + name + "': " + e); + System.exit(1); + } + + try { + baseClassifier = (MultiLabelLearner) ois.readObject(); + freshClassifier = (MultiLabelLearner) ois.readObject(); + attributeInfo = (FastVector) ois.readObject(); + wekaInstances = (Instances) ois.readObject(); + } catch (Exception e) { + System.err.println("Can't read from object stream for '" + name + "': " + e); + System.exit(1); + } + } + + private class LBJavaInstance{ + private final int[] featureIndices; + private final double[] featureValues; + private final int[] labelIndices; + private final double[] labelValues; + LBJavaInstance(int[] featureIndex, double[] featureValues, int[] labelIndex, double[] labelValues) { + this.featureIndices = featureIndex; + + this.featureValues = featureValues; + this.labelIndices = labelIndex; + this.labelValues = labelValues; + } + } +} +