diff --git a/.gitignore b/.gitignore index b5d4ce2d6..c04ba7bae 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ core/*.iml .DS_Store Makefile +output/* + vocab.txt w2vvoc.txt SavedLSTM_tag_20190806 @@ -17,7 +19,6 @@ SentencesInfo_all_label_final_ExactRecur_ExpandBound.txt SentencesInfo_all_label_final_ExactRecur_ExpandBound_WideSubjObjBound.txt SentencesInfo_all_label_final_ExactRecur_train.csv - # sbt specific dist/* target/ diff --git a/SVM_CONTEXT_README.md b/SVM_CONTEXT_README.md new file mode 100644 index 000000000..87b7c4ac1 --- /dev/null +++ b/SVM_CONTEXT_README.md @@ -0,0 +1,27 @@ +## Training + +The `SVMContextEngine` requires a previously trained model. To train the model from scratch, the following files are required. + +To run the training script, you will need the training dataset found at `main/src/main/resources/org/clulab/context/svmFeatures/grouped_features.csv.gz`, which is expected to be a gzipped CSV file. +You will also need the list of features to be used for training, stored as a comma-separated string in `main/src/main/resources/org/clulab/context/svmFeatures/specific_features.txt`. + +To train the model from scratch, use the following script: + +``` +sbt 'runMain org.clulab.reach.context.svm_scripts.TrainSVMContextClassifier main/src/main/resources/org/clulab/context/svmFeatures/grouped_features.csv.gz main/src/main/resources/org/clulab/context/svmFeatures/svm_model.dat main/src/main/resources/org/clulab/context/svmFeatures/specific_features.txt' +``` + +## Usage Configuration + +The context engine is governed by the `contextEngine` section of `application.conf`. To make REACH use the `SVMContextEngine`, set the engine `type` to `SVMPolicy` and specify the size of the sentence window in the `bound` parameter: + +``` +contextEngine{ + type = SVMPolicy + params = { + + bound = 3 + } +} +``` +The trained model *must* exist in the file `main/src/main/resources/org/clulab/context/svmFeatures/svm_model.dat`. This path is hardcoded and is not a parameter.
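+## Programmatic usage
+
+A minimal sketch (using the classes added in this PR; the paths below are placeholders) of training and loading the model from Scala code:
+
+```
+import org.clulab.context.classifiers.LinearSVMContextClassifier
+import org.clulab.reach.context.svm_scripts.TrainSVMContextClassifier
+
+// Constructing the trainer fits the SVM on the gzipped CSV and saves the model to the .dat path
+new TrainSVMContextClassifier(
+  "path/to/grouped_features.csv.gz",
+  "path/to/svm_model.dat",
+  "path/to/specific_features.txt")
+
+// Later, load the trained wrapper; the underlying classifier is an Option[LinearSVMClassifier]
+val trained = new LinearSVMContextClassifier().loadFrom("path/to/svm_model.dat")
+```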
diff --git a/build.sbt b/build.sbt index 802ef2238..ca3d9556b 100644 --- a/build.sbt +++ b/build.sbt @@ -115,4 +115,4 @@ site.includeScaladoc() ghpages.settings -git.remoteRepo := "git@github.com:clulab/reach.git" +git.remoteRepo := "git@github.com:clulab/reach.git" \ No newline at end of file diff --git a/main/build.sbt b/main/build.sbt index a5f9f2dff..1451a5d9d 100644 --- a/main/build.sbt +++ b/main/build.sbt @@ -1,10 +1,10 @@ name := "reach-main" - libraryDependencies ++= { val akkaV = "2.5.4" val luceVer = "5.3.1" val procVer = "7.5.3" + Seq( "ai.lum" %% "nxmlreader" % "0.0.9", "commons-io" % "commons-io" % "2.4", diff --git a/main/src/main/resources/application.conf b/main/src/main/resources/application.conf index 4d93ffee0..35906a099 100644 --- a/main/src/main/resources/application.conf +++ b/main/src/main/resources/application.conf @@ -14,6 +14,7 @@ bratDir = ${rootDir}/brat # if this directory does not exist it will be created contextDir = ${rootDir}/context + # this is where the output files containing the extracted mentions will be stored # if this directory does not exist it will be created outDir = ${rootDir}/output @@ -28,7 +29,7 @@ encoding = "utf-8" # this is a list of sections that we should ignore ignoreSections = ["references", "materials", "materials|methods", "methods", "supplementary-material"] - +//ignoreSections = [] # the output formats for mentions: # "arizona" (column-based, one file per paper) # "cmu" (column-based, one file per paper) @@ -38,7 +39,7 @@ ignoreSections = ["references", "materials", "materials|methods", "methods", "su outputTypes = ["fries", "arizona", "cmu"] # number of simultaneous threads to use for parallelization -threadLimit = 2 +threadLimit = 4 # verbose logging verbose = true @@ -49,12 +50,22 @@ withAssembly = false # context engine configuration contextEngine { - type = Policy4 + type = Policy4 params = { - bound = 3 + bound = 3 } } +//contextEngine{ +// type = SVMPolicy +// params = { +// +// bound = 3 +// } +//} + + +# polarity engine configuration polarity { engine = Hybrid //Hybrid//DeepLearning //Linguistic negCountThreshold = 1 // when lower than or equal to this value, use linguistic approach in hybrid method @@ -76,6 +87,7 @@ experimentalRegulation{ keywords = main/src/main/resources/experimental_regulation_type_keywords.csv } + # grounding configuration grounding: { # List of AdHoc grounding files to insert, in order, into the grounding search sequence.
@@ -93,7 +105,7 @@ grounding: { logging { # defines project-wide logging level loglevel = INFO - logfile = ${rootDir}/reach.log + logfile = ${outDir}/reach.log } # Processor Annotator choice and configuration @@ -105,7 +117,7 @@ processorAnnotator { # restart configuration restart { # restart allows batch jobs to skip over input files already successfully processed - useRestart = true + useRestart = false # restart log is one filename per line list of input files already successfully processed logfile = ${outDir}/restart.log } diff --git a/main/src/main/resources/org/clulab/context/svmFeatures/all_feature_names_file.txt b/main/src/main/resources/org/clulab/context/svmFeatures/all_feature_names_file.txt new file mode 100644 index 000000000..0c1749061 Binary files /dev/null and b/main/src/main/resources/org/clulab/context/svmFeatures/all_feature_names_file.txt differ diff --git a/main/src/main/resources/org/clulab/context/svmFeatures/grouped_features.csv.gz b/main/src/main/resources/org/clulab/context/svmFeatures/grouped_features.csv.gz new file mode 100644 index 000000000..5c15c0330 Binary files /dev/null and b/main/src/main/resources/org/clulab/context/svmFeatures/grouped_features.csv.gz differ diff --git a/main/src/main/resources/org/clulab/context/svmFeatures/specific_nondependency_featurenames.txt b/main/src/main/resources/org/clulab/context/svmFeatures/specific_nondependency_featurenames.txt new file mode 100644 index 000000000..bb6ddfb5b Binary files /dev/null and b/main/src/main/resources/org/clulab/context/svmFeatures/specific_nondependency_featurenames.txt differ diff --git a/main/src/main/resources/org/clulab/context/svmFeatures/svm_model.dat b/main/src/main/resources/org/clulab/context/svmFeatures/svm_model.dat new file mode 100644 index 000000000..c4066134e Binary files /dev/null and b/main/src/main/resources/org/clulab/context/svmFeatures/svm_model.dat differ diff --git a/main/src/main/scala/org/clulab/reach/FriesEntry.scala b/main/src/main/scala/org/clulab/reach/FriesEntry.scala index c8d184b59..ff5bfd541 100644 --- a/main/src/main/scala/org/clulab/reach/FriesEntry.scala +++ b/main/src/main/scala/org/clulab/reach/FriesEntry.scala @@ -2,7 +2,7 @@ package org.clulab.reach import ai.lum.nxmlreader.NxmlDocument - +// Another test comment, just in case case class FriesEntry( name: String, chunkId: String, diff --git a/main/src/main/scala/org/clulab/reach/RuleReader.scala b/main/src/main/scala/org/clulab/reach/RuleReader.scala index 2d0c85935..98b3b597e 100644 --- a/main/src/main/scala/org/clulab/reach/RuleReader.scala +++ b/main/src/main/scala/org/clulab/reach/RuleReader.scala @@ -5,6 +5,8 @@ import scala.io.Source /** * Utilities to read rule files */ + + object RuleReader { case class Rules(entities: String, modifications: String, events: String, context: String) diff --git a/main/src/main/scala/org/clulab/reach/context/ContextEngine.scala b/main/src/main/scala/org/clulab/reach/context/ContextEngine.scala index 50c66440e..f0a5102b9 100644 --- a/main/src/main/scala/org/clulab/reach/context/ContextEngine.scala +++ b/main/src/main/scala/org/clulab/reach/context/ContextEngine.scala @@ -32,6 +32,7 @@ object ContextEngine extends LazyLogging { val labels = mention.labels filter (contextMatching.contains(_)) + (labels.head, id) } diff --git a/main/src/main/scala/org/clulab/reach/context/ContextEngineFactory.scala b/main/src/main/scala/org/clulab/reach/context/ContextEngineFactory.scala index 1f6673a51..e8136614f 100644 --- 
a/main/src/main/scala/org/clulab/reach/context/ContextEngineFactory.scala +++ b/main/src/main/scala/org/clulab/reach/context/ContextEngineFactory.scala @@ -11,6 +11,7 @@ object ContextEngineFactory { val Policy2 = Value("Policy2") val Policy3 = Value("Policy3") val Policy4 = Value("Policy4") + val SVMPolicy = Value("SVMPolicy") } import Engine._ @@ -36,6 +37,10 @@ object ContextEngineFactory { case None => new BidirectionalPaddingContext } case Dummy => new DummyContextEngine + case SVMPolicy => bound match { + case w @ Some(b) => new SVMContextEngine(w) + case None => new SVMContextEngine + } case _ => new DummyContextEngine } } diff --git a/main/src/main/scala/org/clulab/reach/context/Policies.scala b/main/src/main/scala/org/clulab/reach/context/Policies.scala index 0ba87f228..bad0e9fbe 100644 --- a/main/src/main/scala/org/clulab/reach/context/Policies.scala +++ b/main/src/main/scala/org/clulab/reach/context/Policies.scala @@ -1,6 +1,8 @@ package org.clulab.reach.context +import org.clulab.reach.context.feature_utils.ContextFeatureUtils import org.clulab.reach.mentions._ + import collection.mutable @@ -36,7 +38,7 @@ class BoundedPaddingContext( // Assign the context map to the mention m.context = if(contextMap != Map.empty) Some(contextMap) else None } - + //ContextFeatureUtils.writeRowsToFile(mentions) mentions } diff --git a/main/src/main/scala/org/clulab/reach/context/RuleBasedEngine.scala b/main/src/main/scala/org/clulab/reach/context/RuleBasedEngine.scala index e8029744d..350c1f5e0 100644 --- a/main/src/main/scala/org/clulab/reach/context/RuleBasedEngine.scala +++ b/main/src/main/scala/org/clulab/reach/context/RuleBasedEngine.scala @@ -7,8 +7,8 @@ import org.clulab.reach.mentions._ abstract class RuleBasedContextEngine extends ContextEngine { // Fields - // To be overriden in the implementations. Returns a sequence of (Type, Val) features - // Feature order should be kept consisting for all return values + // To be overridden in the implementations. 
Returns a sequence of (Type, Val) features + // Feature order should be kept consistent for all return values var orderedContextMentions:Map[Int, Seq[BioTextBoundMention]] = _ // This is to keep the default species if necessary var defaultContexts:Option[Map[String, String]] = None @@ -26,7 +26,7 @@ abstract class RuleBasedContextEngine extends ContextEngine { // Compute default context classes // First count the context types val contextCounts:Map[(String, String), Int] = contextMentions map ContextEngine.getContextKey groupBy identity mapValues (_.size) - // Then gorup them by class + // Then group them by class val defaultContexts:Map[String, String] = contextCounts.toSeq.groupBy(_._1._1) // Sort them in decreasing order by frequency .mapValues(_.map(t => (t._1._2, t._2))) diff --git a/main/src/main/scala/org/clulab/reach/context/SVMContextEngine.scala b/main/src/main/scala/org/clulab/reach/context/SVMContextEngine.scala new file mode 100644 index 000000000..ce03b79fa --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/SVMContextEngine.scala @@ -0,0 +1,239 @@ +package org.clulab.reach.context + +import com.typesafe.scalalogging.LazyLogging +import org.clulab.context.classifiers.LinearSVMContextClassifier +import org.clulab.context.utils.ContextPairInstance +import org.clulab.reach.context.feature_utils.{ContextFeatureAggregator, ContextFeatureUtils, EventContextPairGenerator} +import org.clulab.reach.mentions.{BioEventMention, BioMention, BioTextBoundMention} + +import scala.collection.immutable + +class SVMContextEngine(sentenceWindow:Option[Int] = None) extends ContextEngine with LazyLogging { + + type Pair = (BioEventMention, BioTextBoundMention) + type EventID = String + type ContextID = (String, String) + + var paperMentions:Option[Seq[BioTextBoundMention]] = None + var orderedContextMentions:Map[Int, Seq[BioTextBoundMention]] = _ + var defaultContexts:Option[Map[String, String]] = None + + // The LinearSVMContextClassifier class is the container of our Linear SVM model. + // It contains the methods you will need for training your model and predicting on it. + // USAGE: val wrapper = new LinearSVMContextClassifier(LinearSVMClassifier[Int,String], String) + // Both parameters to the constructor are optional. You can either pass a linearSVMClassifier with customized hyper-parameters, or a string that is a path to a pre-saved model. + // If you do pass a string, the .fit function will load the model from the path and then train it. + // If you pass both, the customized LinearSVM model will be given higher precedence and will be used for training, rather than reading from file and training. + // Since both parameters are optional, you need not pass either. In that case, you can load a classifier from file, using the loadFrom(path) function + // The LinearSVMClassifier is in accordance with Scikit-learn's Liblinear classifier. Therefore, it uses .fit(dataset) to train and .predict(dataset) to test. + // Please refer to the class LinearSVMContextClassifier for further clarifications. + + val svmWrapper = new LinearSVMContextClassifier() + + // VERY IMPORTANT TO KEEP THE STARTING / for jar file path to resource dir + val resourcesPath = "/org/clulab/context/svmFeatures" + val resourcesPathToSVMModel = s"${resourcesPath}/svm_model.dat" + + + // this function call to getResource returns to us a URL that is the path to the file svm_model.dat + // the variable urlPathToSVMModel holds the value file:/home/.... 
+ // so we need to take the shorter version of it that starts from /home/... + val urlPathToSVMModel = getClass.getResource(resourcesPathToSVMModel) + + val truncatedPathToSVM = urlPathToSVMModel.toString.replace("file:","") + + + val trainedSVMInstance = svmWrapper.loadFrom(truncatedPathToSVM) + val classifierToUse = trainedSVMInstance.classifier match { + case Some(x) => x + case None => throw new NullPointerException("No classifier found on which I can predict. Please make sure the SVMContextEngine class receives a valid Linear SVM classifier.") + } + + logger.debug(s"The SVM model has been tuned to the following settings: C: ${classifierToUse.C}, Eps: ${classifierToUse.eps}, Bias: ${classifierToUse.bias}") + + override def assign(mentions: Seq[BioMention]): Seq[BioMention] = { + + paperMentions match { + // If we haven't run infer, don't modify the mentions + case None => mentions + // If we have already run infer + case Some(ctxMentions) => + + // Generate all the event/ctx mention pairs + val pairGenerator = new EventContextPairGenerator(mentions, ctxMentions) + val pairs = pairGenerator.yieldContextEventPairs() + val filteredPairs = sentenceWindow match { + case Some(bound) => + pairs.filter { + case (evt, ctx) => + Math.abs(evt.sentence - ctx.sentence) <= bound + } + case None => + pairs + } + + // The filteredPairs, as the name suggests, contains the subset of the context-event pairs, filtered based on the sentence distance window. + // A filteredPair is an instance of Pair as defined on line 15. Once we have the seq(filteredPair), we are ready to calculate the feature values. + + // To extract the feature values for a given pair, we will build an instance of ContextPairInstance, which has information about the event and context IDs, and the feature values associated with that pair. + // ContextPairInstance is basically an object that contains information about paperID, eventID, contextID, and the set of features for which we need values. + // To associate each ContextPairInstance to its correct set of feature values, we will use a Map, where the map has ContextPairInstance as a key, and as value, we have a map of features with values + // for each pair, our map looks like: ContextPairInstance -> (sentenceDistance -> 0.0, dependencyDistance -> 1.0), etc. + // We can call this map the lookUpTable. + + // the line below internally calls the feature extractor on each pair and constructs the map as described. + // For more information on feature extraction, please refer to the ContextFeatureExtractor class. + val lookUpTable = ContextFeatureUtils.getFeatValMapPerInput(filteredPairs.toSet, ctxMentions) + + // In order to associate the correct ContextPairInstance to its values, we use the map from above to extract the keyset + // With this set, we can simply look up the lookUpTable and extract the correct values. + val contextPairInput:Seq[ContextPairInstance] = ContextFeatureUtils.getCtxPairInstances(lookUpTable) + + // It is now time to introduce the FeatureAggregator. The basic idea behind the FeatureAggregator is that, + // for any given (eventID, contextID) pair, it is possible that reach detected many sentences that match the pair. + // Multiple sentences for the same pair means multiple feature values for the same feature name. + // In order to avoid bias of choosing one sentence over another, we will aggregate their feature values. + // We will take the arithmetic mean, minimum, and maximum of the values. 
+ // We will then have an instance of AggregatedContextInstance, which has 3 times the number of features of the original ContextPairInstance, + // since each feature has been aggregated to min, max and avg values. + // The SVM model will then predict on this aggregated instance. + val groupingsReadyToAggr = collection.mutable.ListBuffer[(Pair, ContextPairInstance)]() + for((eventID, contextID) <- pairs) { + val miniList = collection.mutable.ListBuffer[(Pair, ContextPairInstance)]() + val contextInstancesSubSet = contextPairInput.filter(x => ContextFeatureUtils.extractEvtId(eventID) == x.EvtID) + val contextFiltByCtxID = contextInstancesSubSet.filter(x => x.CtxID == contextID.nsId()) + for(i <- 0 until contextFiltByCtxID.size) { + val currentPair = (eventID,contextID) + val tupRow = contextFiltByCtxID(i) + val tupEntry = (currentPair, tupRow) + miniList += tupEntry + } + groupingsReadyToAggr ++= miniList + } + + val aggregatedFeatures = groupingsReadyToAggr.groupBy{ + case (pair, _) => ContextFeatureUtils.extractEvtId(pair._1) + }.mapValues{ + v => + v.groupBy(r => ContextEngine.getContextKey(r._1._2)).mapValues(s => { + val seqOfInputRowsToPass = s map (_._2) + val featureAggregatorInstance = new ContextFeatureAggregator(seqOfInputRowsToPass, lookUpTable) + val aggRow = featureAggregatorInstance.aggregateContextFeatures() + aggRow}).toSeq + } + + + val predictions:Map[EventID, Seq[(ContextID, Boolean)]] = { + val map = collection.mutable.HashMap[EventID, Seq[(ContextID, Boolean)]]() + for((k,a) <- aggregatedFeatures) { + + val x = a.map { + // this loop finds the prediction of the SVM on a given AggregatedFeature row. + // The prediction is based on LinearSVMClassifier provided by Clulab/Processors, + // which returns 1 for true and 0 for false predictions. We then convert this to the appropriate boolean values. + case (ctxId, aggregatedFeature) => + val predArrayIntForm = trainedSVMInstance.predict(Seq(aggregatedFeature)) + + + // We may need the aggregated instances for further analyses, like testing or cross-validation. + // Should such a need arise, you can write the aggregated instances to file by uncommenting the following line(s) + // there are multiple signatures to this function, please refer to the definition of ContextFeatureUtils.writeAggRowToFile for more details + // Example usages: + // ContextFeatureUtils.writeAggRowToFile(aggregatedFeature,k.toString, ctxId._2, parentDirToWriteAllRows) --> This signature writes the AggregatedInstance to file whose path is specified by parentDirToWriteAllRows + // ContextFeatureUtils.writeAggRowToFile(aggregatedFeature, k.toString, ctxId._2,sentWind, whereToWriteRowBySentDist) + // Please note that this function writes aggregated rows for each (eventID, contextID) pair. Therefore, you may have a large number of files written to your directory. 
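+          // For illustration: trainedSVMInstance.predict returns an Array[Int] with one entry per
+          // aggregated row passed in. Since we pass a single row here, predArrayIntForm(0) holds the
+          // only prediction, e.g. predict(Seq(aggregatedFeature)) == Array(1) means the classifier
+          // associates this context with this event.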
+ val prediction = { + predArrayIntForm(0) match { + case 1 => true + case 0 => false + case _ => false + } + } + + logger.debug(s"For the paper ${aggregatedFeature.PMCID}, event ID: ${k.toString} and context ID: ${ctxId._2}, we have prediction: ${predArrayIntForm(0)}") + + (ctxId, prediction) + } + + val entry = Map(k -> x) + map ++= entry + + } + map.toMap + } + + // Loop over all the mentions to generate the context dictionary + for(mention <- mentions) yield { + mention match { + // If is an event mention, it's subject to context + case evt: BioEventMention => + // Get its ID + val evtId = ContextFeatureUtils.extractEvtId(evt) + // fetch its predicted pairs + val contexts = predictions.getOrElse(evtId, Seq.empty) + + // collect only those context labels that have been predicted to be true and + // associate it to the global map of context mentions of the given paper. + val contextMap = + (contexts collect { + case (ctx, true) => ctx + } groupBy (_._1)).mapValues(x => x.map(_._2)) + + val appendedContextMap = collection.mutable.HashMap[String, Seq[String]]() + appendedContextMap ++= contextMap + + + // adding a check to see if our context map has a species map, and if not, + // adding the species values to the context map before we return it + if(!contextMap.keySet.contains("Species") + && defaultContexts.isDefined) { + val defaults = defaultContexts.get + if(defaults.keySet.contains("Species")){ + appendedContextMap += ("Species" -> Array(defaults("Species"))) + } + } + + + // Assign the context map to the mention + evt.context = if(appendedContextMap != Map.empty) Some(appendedContextMap.toMap) else None + + + + // Return the modified event + evt + // If it's not an event mention, leave it as is + case m: BioMention => + m + } + } + } + } + + // Pre-filter the context mentions + override def infer(mentions: Seq[BioMention]): Unit = { + val contextMentions = mentions filter ContextEngine.isContextMention map (_.asInstanceOf[BioTextBoundMention]) + paperMentions = Some(contextMentions) + + // code from rule based engine + val entries = contextMentions groupBy (m => m.sentence) + orderedContextMentions = immutable.TreeMap(entries.toArray:_*) + val contextCounts:Map[(String, String), Int] = contextMentions map ContextEngine.getContextKey groupBy identity mapValues (_.size) + val defaultContexts:Map[String, String] = contextCounts.toSeq.groupBy(_._1._1) + // Sort them in decreasing order by frequency + .mapValues(_.map(t => (t._1._2, t._2))) + // And pick the id with of the type with highest frequency + .mapValues(l => l.maxBy(_._2)._1) + this.defaultContexts = Some(defaultContexts) + } + + override def update(mentions: Seq[BioMention]): Unit = () + + + + +} diff --git a/main/src/main/scala/org/clulab/reach/context/base_classifiers/ContextClassifier.scala b/main/src/main/scala/org/clulab/reach/context/base_classifiers/ContextClassifier.scala new file mode 100644 index 000000000..21bc6a2bc --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/base_classifiers/ContextClassifier.scala @@ -0,0 +1,11 @@ +package org.clulab.context.classifiers + +import org.clulab.context.utils.AggregatedContextInstance + +trait ContextClassifier{ + def fit(xTrain: Seq[AggregatedContextInstance]): Unit + + def predict(xTest: Seq[AggregatedContextInstance]):Array[Int] + def saveModel(fileName: String): Unit + def loadFrom(fileName: String):LinearSVMContextClassifier +} diff --git a/main/src/main/scala/org/clulab/reach/context/base_classifiers/DummyClassifier.scala 
b/main/src/main/scala/org/clulab/reach/context/base_classifiers/DummyClassifier.scala new file mode 100644 index 000000000..a1b10b5e9 --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/base_classifiers/DummyClassifier.scala @@ -0,0 +1,14 @@ +package org.clulab.context.classifiers + +import org.clulab.context.utils.AggregatedContextInstance + +object DummyClassifier extends ContextClassifier { + override def fit(xTrain: Seq[AggregatedContextInstance]):Unit = () + + override def predict(xTest: Seq[AggregatedContextInstance]): Array[Int] = List.fill(xTest.size)(1).toArray + + override def saveModel(fileName: String): Unit = () + + override def loadFrom(fileName: String): LinearSVMContextClassifier = null +} diff --git a/main/src/main/scala/org/clulab/reach/context/base_classifiers/LinearSVMContextClassifier.scala b/main/src/main/scala/org/clulab/reach/context/base_classifiers/LinearSVMContextClassifier.scala new file mode 100644 index 000000000..9f08f8541 --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/base_classifiers/LinearSVMContextClassifier.scala @@ -0,0 +1,147 @@ +package org.clulab.context.classifiers + +import java.io._ + +import com.typesafe.config.ConfigFactory +import org.clulab.context.utils.AggregatedContextInstance +import org.clulab.struct.Counter +import org.clulab.learning._ +import org.clulab.reach.context.utils.svm_training_utils.DatatypeConversionUtils +case class LinearSVMContextClassifier(classifier: Option[LinearSVMClassifier[Int,String]] = None, pathToClassifier:Option[String] = None) extends ContextClassifier { + override def fit(xTrain: Seq[AggregatedContextInstance]): Unit = { + val trainingLabels = DatatypeConversionUtils.convertOptionalToBool(xTrain) + val labelsToInt = DatatypeConversionUtils.convertBooleansToInt(trainingLabels) + val tups = constructTupsForRVF(xTrain) + val (trainDataSet, _) = mkRVFDataSet(labelsToInt,tups) + fit(trainDataSet) + } + + + // This class provides the basic API for training and predicting of a LinearSVM model. + // It also provides functions for saving LinearSVM models to file and reading from file. + + + // The function checkForNullException checks for the event that we are trying to call predict or fit on an empty model + private def checkForNullException(classForFunct: Option[LinearSVMClassifier[Int,String]], pathForFunct:Option[String]): Option[LinearSVMClassifier[Int,String]] = { + classForFunct match { + case Some(c) => Some(c) + case None => { + pathForFunct match { + case Some(s) => { + val loadedWrapper = loadFrom(s) + loadedWrapper.classifier} + case None => None + } + } + } + } + + def fit(xTrain: RVFDataset[Int, String]):Unit = { + val classifierToTrain = checkForNullException(classifier, pathToClassifier) + classifierToTrain match { + case Some(c) => c.train(xTrain) + case None => println("ERROR: The Linear SVM model has not been trained yet, since default null parameters were detected in the constructor. However, you can fit the model by loading it from file, using the loadFrom function.") + } + + } + + override def predict(data: Seq[AggregatedContextInstance]): Array[Int] = { + val (_, individualRows) = dataConverter(data) + val classifierToPredict = checkForNullException(classifier, pathToClassifier) + classifierToPredict match { + case Some(c) => individualRows.map(c.classOf(_)) + case None => { + println("ERROR: No valid classifier was found on which I could predict. 
Please ensure you are passing a valid LinearSVM classifier, or a path to a classifier. I am now returning a default array of 0s") + Array.fill(individualRows.size)(0) + } + } + } + + def predict(testDatum:RVFDatum[Int, String]):Int = { + val classifierToPredict = checkForNullException(classifier, pathToClassifier) + classifierToPredict match { + case Some(c) => c.classOf(testDatum) + case None => { + println("I cannot predict the current datapoint on an empty classifier. Returning a default value of 0") + 0 + } + } + } + + // writes given model wrapper instance (LinearSVMContextClassifier) to file + override def saveModel(fileName: String): Unit = { + val os = new ObjectOutputStream(new FileOutputStream(fileName)) + os.writeObject(this) + os.close() + } + + + // reads the given file path and returns an instance of LinearSVMContextClassifier + // the classifier in the LinearSVMContextClassifier instance can be accessed by the classifier field, like instance.classifier + // please note that the classifier is an Option[LinearSVMClassifier]. This can be easily unwrapped using basic scala pattern matching. + override def loadFrom(fileName: String): LinearSVMContextClassifier = { + val is = new ObjectInputStream(new FileInputStream(fileName)) + val c = is.readObject().asInstanceOf[LinearSVMContextClassifier] + is.close() + c + } + + + // ******** Starting functions to convert data from AggregatedContextInstance to RVFDataSet, a useful format for using the in-house LinearSVMModel designed by Mihai's team. + // Consider features as pairs of (feature name, feature value) + private def mkRVFDatum[L](label:L, features:Array[(String, Double)]):RVFDatum[L, String] = { + // In here, Counter[T] basically works as a dictionary, and String should be the simplest way to implement it + // when you call c.incrementCount, you basically assign the feature called "featureName", the value in the second parameter ("inc") + val c = new Counter[String] + // In this loop we go through all the elements in features and initialize the counter with the values. It's weird but that's the way it was written + for((featureName, featureValue) <- features) c.incrementCount(featureName, inc = featureValue) + // Just changed the second type argument to string here. Label is the class, so, L can be Int to reflect 1 or 0 + new RVFDatum[L, String](label, c) + } + + // Here I made the changes to reflect my comments above. 
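+  // For example (hypothetical values): mkRVFDatum(1, Array(("sentenceDistance_avg", 3.0), ("context_frequency_max", 2.0)))
+  // yields an RVFDatum labeled 1 whose counter maps each feature name to its value.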
+ def mkRVFDataSet(labels: Array[Int], dataSet:Array[Array[(String, Double)]]):(RVFDataset[Int, String], Array[RVFDatum[Int, String]]) = { + + val dataSetToReturn = new RVFDataset[Int, String]() + val datumCollect = collection.mutable.ListBuffer[RVFDatum[Int, String]]() + val tupIter = dataSet zip labels + for((d,l) <- tupIter) { + val currentDatum = mkRVFDatum(l,d) + dataSetToReturn += currentDatum + datumCollect += currentDatum + + } + (dataSetToReturn, datumCollect.toArray) + } + + def constructTupsForRVF(rows: Seq[AggregatedContextInstance]):Array[Array[(String, Double)]] = { + val toReturn = collection.mutable.ListBuffer[Array[(String,Double)]]() + rows.map(r => { + val featureVals = r.featureGroups + val featureName = r.featureGroupNames + val zipped = featureName zip featureVals + toReturn += zipped + }) + toReturn.toArray + } + + def dataConverter(data:Seq[AggregatedContextInstance], existingLabels: Option[Array[Int]] = None):(RVFDataset[Int, String], Array[RVFDatum[Int, String]]) = { + val tups = constructTupsForRVF(data) + val labels = existingLabels match { + case None => createLabels(data) + case Some(x) => x } + val result = mkRVFDataSet(labels, tups) + result + } + + + // This function is useful for converting your boolean labels to equivalent integer labels. 1 for true and 0 for false. + def createLabels(data:Seq[AggregatedContextInstance]):Array[Int] = { + val currentTruthTest = DatatypeConversionUtils.convertOptionalToBool(data) + val currentTruthTestInt = DatatypeConversionUtils.convertBooleansToInt(currentTruthTest) + currentTruthTestInt + } + // ************* Ending functions to convert AggregatedInstance to RVFDataset + + +} diff --git a/main/src/main/scala/org/clulab/reach/context/svm_scripts/TrainSVMContextClassifier.scala b/main/src/main/scala/org/clulab/reach/context/svm_scripts/TrainSVMContextClassifier.scala new file mode 100644 index 000000000..735b31832 --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/svm_scripts/TrainSVMContextClassifier.scala @@ -0,0 +1,88 @@ +package org.clulab.reach.context.svm_scripts +import org.clulab.context.utils.AggregatedContextInstance +import org.clulab.context.classifiers.{DummyClassifier, LinearSVMContextClassifier} +import org.clulab.learning.LinearSVMClassifier +import org.clulab.reach.context.utils.svm_training_utils.{DatatypeConversionUtils, IOUtilsForFeatureName} +import org.clulab.reach.context.utils.feature_utils.FeatureNameProcessor +class TrainSVMContextClassifier(pathToDataframe:String, pathToFileToSaveSVMModel: String, pathToSpecificFeaturenames:String){ + + // The SVM instance is set up for these hyper-parameters, the values for which were found by the Reach team through trial and error, to yield the best scores without losing generalizability. 
+ // You can change the values of the hyper-parameters to this classifier instance before you run the script + val SVMClassifier = new LinearSVMClassifier[Int, String](C = 0.001, eps = 0.001, bias = false) + val svmInstance = new LinearSVMContextClassifier(Some(SVMClassifier)) + + + val (allFeatures, dataPoints) = IOUtilsForFeatureName.loadAggregatedRowsFromDataFrame(pathToDataframe, pathToSpecificFeaturenames) + val nonNumericFeatures = Seq("PMCID", "label", "EvtID", "CtxID", "") + val numericFeatures = allFeatures.toSet -- nonNumericFeatures.toSet + val featureDict = FeatureNameProcessor.createFeatureTypeDictionary(numericFeatures.toSeq) + // the best feature set was found through an ablation study, wherein we ran all possible combinations of features and examined which combination yielded the best scores. + // The resulting best feature set was found to be the collection of non-dependency and context-dependency features. + val bestFeatureSet = featureDict("NonDep_Context") + + + // It was found that the paper PMC4204162 had no "root" in its dependency tree as per the stanford.nlp package, and hence its rows were filtered out. + // if you choose to remove this filter, please note that your SVM might perform differently than observed by the team. + val trainingDataPrior = dataPoints.filter(_.PMCID != "b'PMC4204162'") + val trainingData = extractDataByRelevantFeatures(bestFeatureSet, trainingDataPrior) + svmInstance.fit(trainingData) + svmInstance.saveModel(pathToFileToSaveSVMModel) + + + def extractDataByRelevantFeatures(featureSet:Seq[String], data:Seq[AggregatedContextInstance]):Seq[AggregatedContextInstance] = { + val result = data.map(d => { + val currentSent = d.sentenceIndex + val currentPMCID = d.PMCID + val currentEvtId = d.EvtID + val currentContextID = d.CtxID + val currentLabel = d.label + val currentFeatureName = d.featureGroupNames + val currentFeatureValues = d.featureGroups + val indexList = collection.mutable.ListBuffer[Int]() + featureSet.map(f => { + if(currentFeatureName.contains(f)) { + val tempIndex = currentFeatureName.indexOf(f) + indexList += tempIndex + } + }) + val valueList = indexList.map(i => currentFeatureValues(i)) + AggregatedContextInstance(currentSent, currentPMCID, currentEvtId, currentContextID, currentLabel, valueList.toArray, featureSet.toArray) + }) + result + } +} + + +object TrainSVMContextClassifier extends App { + if(args.length == 0) + throw new IllegalArgumentException("This script takes arguments from the command line to run, but none were received. Please refer to the usage examples for this script.") + + // The purpose of this script is to train an SVM instance on a dataset chosen by the user, and write to file the trained version of the SVM instance. + + // The SVM instance has been tuned to LinearSVMClassifier[Int, String](C = 0.001, eps = 0.001, bias = false) + // If you wish to change it, you can change it in the script before running it. + + // To run this script, you will need three files: + // 1) the data set you want to use for training (please note that the code is set up to take .csv.gz file format for the dataset) + // 2) The path to the file where you want the trained SVM instance to be written. The file will be written in the .dat format + // 3) The list of specific feature names, i.e., the identifiers of your datapoint and any other non-numeric features, such as the classification label of a given datapoint, illustrated below. 
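+  // As a hypothetical illustration, such a file could hold a single comma-separated line like:
+  //   PMCID,label,EvtID,CtxID,sentenceDistance,dependencyDistance,context_frequency
+  // where the first four entries are data-point identifiers and the rest name the specific numeric features.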
+ + + // usage of script: + // You can run the following line as an sbt command directly in your terminal, while you're in the root directory of the project + // sbt 'runMain org.clulab.reach.context.svm_scripts.TrainSVMContextClassifier /../..path/to/dataset.csv.gz ../../path/to/output_model.dat ../path/to/specific_features.txt' + // The dataset.csv.gz file needs to exist. dataset is a custom name for the file; you can name your dataset as you choose. + // The output_model.dat file does *not* need to exist; this script will create the file for you with that name. You can name the file to your liking, but the output format will always be a .dat file + // The specific_features.txt *needs* to exist. The typical entries that go into this file are any identifiers of your datapoint, or any other features that are *non-numeric*. + // This 3rd file is a .txt file that has the names of the "specific" features, separated by commas + + val cmndLinePathToDataFrame = args(0) + val cmndLinePathToWriteSVMTo = args(1) + val cmndLinePathToSpecificFeatures = args(2) + val trainSVMContextClassifier: TrainSVMContextClassifier = new TrainSVMContextClassifier(cmndLinePathToDataFrame, cmndLinePathToWriteSVMTo, cmndLinePathToSpecificFeatures) + + + // If you choose to use this script programmatically, i.e., manually create the instance of the training code, you can do so using the following line: + // val svmTrainingInstance = new TrainSVMContextClassifier(dataset.csv.gz, output_model.dat, specific_features.txt) + // Please refer to the corresponding test script for a more detailed example on the programmatic usage of the trainer. +} diff --git a/main/src/main/scala/org/clulab/reach/context/utils/datatype_utils/AggregatedContextInstance.scala b/main/src/main/scala/org/clulab/reach/context/utils/datatype_utils/AggregatedContextInstance.scala new file mode 100644 index 000000000..d95939915 --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/utils/datatype_utils/AggregatedContextInstance.scala @@ -0,0 +1,60 @@ +package org.clulab.context.utils + +import scala.collection.mutable +case class AggregatedContextInstance( + sentenceIndex:Int, + PMCID:String, + EvtID: String, + CtxID: String, + label: Option[Boolean], + featureGroups: Array[Double], + featureGroupNames:Array[String]) + + +object AggregatedContextInstance { + def apply(str: String, headers: Seq[String], allOtherFeatures: Set[String], indices: Map[String, Int], listOfSpecificFeatures: Array[String]):AggregatedContextInstance = { + val rowData = str.split(",") + val sentencePos = rowData(0).toInt + var evt_dependencyTails = new mutable.ListBuffer[Double] + var ctx_dependencyTails = new mutable.ListBuffer[Double] + var evt_dependencyFeatures = new mutable.ListBuffer[String] + var ctx_dependencyFeatures = new mutable.ListBuffer[String] + val featureGroups = new mutable.ListBuffer[Double] + val featureNames = new mutable.ListBuffer[String] + allOtherFeatures foreach { + case evt:String if evt.startsWith("evtDepTail") => + if(rowData(indices(evt)) != "0.0") + {evt_dependencyTails += (rowData(indices(evt))).toDouble + evt_dependencyFeatures += evt + } + case ctx:String if ctx.startsWith("ctxDepTail") => + if(rowData(indices(ctx)) != "0.0") + { + ctx_dependencyTails += (rowData(indices(ctx))).toDouble + ctx_dependencyFeatures += ctx + } + case _ => 0.0 + } + + + + val pmcid = rowData(indices("PMCID")) + + val evt = rowData(indices("EvtID")) + val ctx = rowData(indices("CtxID")) + val label = rowData(indices("label")) + + val
listOfNumericFeatures = listOfSpecificFeatures.drop(4) + featureNames ++= listOfNumericFeatures + listOfNumericFeatures.map(l => { + val tempVal = rowData(indices(l)) + featureGroups += tempVal.toDouble + }) + + featureGroups ++= evt_dependencyTails + featureGroups ++= ctx_dependencyTails + featureNames ++= evt_dependencyFeatures + featureNames ++= ctx_dependencyFeatures + AggregatedContextInstance(sentencePos, pmcid, evt, ctx, Some(label.toBoolean), featureGroups.toArray, featureNames.toArray) + } +} diff --git a/main/src/main/scala/org/clulab/reach/context/utils/datatype_utils/ContextPairInstance.scala b/main/src/main/scala/org/clulab/reach/context/utils/datatype_utils/ContextPairInstance.scala new file mode 100644 index 000000000..2dd5bd8b8 --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/utils/datatype_utils/ContextPairInstance.scala @@ -0,0 +1,90 @@ +package org.clulab.context.utils + +import java.io.InputStream + +import org.clulab.reach.context.utils.svm_training_utils.IOUtilsForFeatureName + +import scala.collection.mutable +import scala.io.Source +case class ContextPairInstance( + sentenceIndex:Int, + PMCID:String, + + label: Option[Boolean], + EvtID: String, + CtxID: String, + specificFeatureNames:Array[String], + ctx_dependencyTails:Set[String], + evt_dependencyTails:Set[String] + ) + +object ContextPairInstance{ + + val resourcesPath = "/org/clulab/context/svmFeatures" + + + val pathToSpecificNonDepFeatures = s"${resourcesPath}/specific_nondependency_featurenames.txt" + val urlToSpecificNonDependFeaturesFile = getClass.getResource(pathToSpecificNonDepFeatures) + // this function call to getResource returns to us a URL that is the path to the file specific_nondependency_featurenames.txt + // the variable urlToSpecificNonDependFeaturesFile holds the value file:/home/.... + // so we need to take the shorter version of it that starts from /home/... 
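+  // e.g. (hypothetical path) "file:/home/user/reach/.../specific_nondependency_featurenames.txt"
+  // becomes "/home/user/reach/.../specific_nondependency_featurenames.txt"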
+ val truncatedPathToSpecificNonDep = urlToSpecificNonDependFeaturesFile.toString.replace("file:","") + val listOfSpecificNonDependFeatures = IOUtilsForFeatureName.readSpecificNonDependencyFeatureNames(truncatedPathToSpecificNonDep) + private def allOtherFeatures(headers:Seq[String]): Set[String] = headers.toSet -- (listOfSpecificNonDependFeatures ++ Seq("")) + + private def indices(headers:Seq[String]): Map[String, Int] = headers.zipWithIndex.toMap + + def apply(str:String, headers: Seq[String], allOtherFeatures:Set[String], indices:Map[String, Int]):ContextPairInstance = { + // Parse the commas into tokens + val rowData = str.split(",") + val sentencePos = rowData(0).toInt + + + var evt_dependencyTails = new mutable.HashSet[String] + var ctx_dependencyTails = new mutable.HashSet[String] + + allOtherFeatures foreach { + case evt:String if evt.startsWith("evtDepTail") => + if(rowData(indices(evt)) != "0.0") + evt_dependencyTails += evt.substring(11) + case ctx:String if ctx.startsWith("ctxDepTail") => + if(rowData(indices(ctx)) != "0.0") + ctx_dependencyTails += ctx.substring(11) + case _ => () + } + + val pmcid = rowData(indices("PMCID")) + val label = rowData(indices("label")) + val evt = rowData(indices("EvtID")) + val ctx = rowData(indices("CtxID")) + + val specificFeatureNames = collection.mutable.ListBuffer[String]() + val listOfNumericFeatures = listOfSpecificNonDependFeatures.drop(4) + listOfNumericFeatures.map(l => { + specificFeatureNames += l + }) + ContextPairInstance(sentencePos, + pmcid, + Some(label.toBoolean), + evt, + ctx, + specificFeatureNames.toArray, + ctx_dependencyTails.toSet, + evt_dependencyTails.toSet) + } + + + // This fromStream function accepts the dataframe as an input stream, + // and returns a Seq[ContextPairInstance], wherein each data point is represented as a *ContextPairInstance* + def fromStream(stream:InputStream):Seq[ContextPairInstance] = { + val source = Source.fromInputStream(stream) + val lines = source.getLines() + val headers = lines.next() split "," + val features = allOtherFeatures(headers) + val ixs = indices(headers) + val ret = (lines map (l => ContextPairInstance(l, headers.toSeq, features, ixs))).toList + source.close() + ret + } + +} diff --git a/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/BinnedDistance.scala b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/BinnedDistance.scala new file mode 100644 index 000000000..cffc9355a --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/BinnedDistance.scala @@ -0,0 +1,6 @@ +package org.clulab.reach.context.feature_utils + +// this object is used in the feature extractor +object BinnedDistance extends Enumeration{ + val SAME, CLOSE, FAR = Value +} diff --git a/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureAggregator.scala b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureAggregator.scala new file mode 100644 index 000000000..360cec87c --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureAggregator.scala @@ -0,0 +1,141 @@ +package org.clulab.reach.context.feature_utils + +import org.clulab.context.utils.{AggregatedContextInstance, ContextPairInstance} +import org.clulab.reach.context.utils.feature_utils.FeatureNameProcessor +import org.clulab.reach.context.utils.svm_performance_utils.ScoresUtils + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +class
ContextFeatureAggregator(instances:Seq[ContextPairInstance], featValLookUp:Map[ContextPairInstance, (Map[String,Double],Map[String,Double],Map[String,Double])]) { + val featureSetNames = collection.mutable.ListBuffer[String]() + val featureSetValues = collection.mutable.ListBuffer[Double]() + def aggregateContextFeatures():AggregatedContextInstance = { + + val label = None + + //val inputRows = instances + val featNameToVals = collection.mutable.Map[String,mutable.ListBuffer[Double]]() + // we are using the same set of features over all the ContextPairInstance instances, hence using the first ContextPairInstance in the sequence to get the names of the features is a safe step. + val specfeatureNamesToUse = instances(0).specificFeatureNames + val ctxFeatureNamesToUse = instances(0).ctx_dependencyTails + val evtFeatureNamesToUse = instances(0).evt_dependencyTails + + // we read through the ContextPairInstance values and add them to a name -> list of features map. + // So for a given feature name as key, we will have a list of double as values, where each double is the value to the feature in a given ContextPairInstance. + + for(in <- instances) { + val (specificVals, evtVals, ctxVals) = featValLookUp(in) + for(valueMap <- Seq(specificVals, evtVals, ctxVals); (name, value) <- valueMap) { + val currentList = featNameToVals.getOrElseUpdate(name, collection.mutable.ListBuffer[Double]()) + currentList += value + } + } + val aggregatedSpecVals = aggregateInputRowFeatValues(specfeatureNamesToUse, featNameToVals.toMap) + val aggregatedctxDepVals = aggregateInputRowFeatValues(ctxFeatureNamesToUse.toSeq, featNameToVals.toMap) + val aggregatedevtDepVals = aggregateInputRowFeatValues(evtFeatureNamesToUse.toSeq, featNameToVals.toMap) + val specFeatVal = featureValuePairing(aggregatedSpecVals) + val ctxFeatVal = featureValuePairing(aggregatedctxDepVals) + val evtFeatVal = featureValuePairing(aggregatedevtDepVals) + + addAggregatedOnce(specFeatVal) + addAggregatedOnce(ctxFeatVal) + addAggregatedOnce(evtFeatVal) + + // In the ContextPairInstance, we had filtered the feature names into three separate feature sets: non-dependency, ctx-dependency and evt-dependency. + // But in the AggregatedContextInstance, we will add them all into one list, called featureSetNames. + // featureSetNames and featureSetValues are of the same size. Both lists are created such that the feature name at index i in featureSetNames has its corresponding value at index i in featureSetValues + val newAggRow = AggregatedContextInstance(0, instances(0).PMCID, "", "", label, featureSetValues.toArray,featureSetNames.toArray) + newAggRow + } + + + + // this function takes as parameters the names of features that need to be aggregated, along with the value of the feature observed in each ContextPairInstance. 
+ // it then aggregates each feature to find the _min, _max and _mean (arithmetic mean) of each feature. + private def aggregateInputRowFeatValues(features:Seq[String], valuesPerGivenFeature: Map[String,mutable.ListBuffer[Double]]):Map[String,(Double,Double, Double, Int)] = { + val resultingMap = collection.mutable.Map[String,(Double,Double, Double, Int)]() + for(r <- features) { + if(valuesPerGivenFeature.contains(r)) { + val valueList = valuesPerGivenFeature(r) + val min = valueList.foldLeft(Double.MaxValue)(Math.min(_,_)) + val max = valueList.foldLeft(Double.MinValue)(Math.max(_,_)) + val sum = valueList.foldLeft(0.0)(_ + _) + val tup = (min,max,sum,valueList.size) + resultingMap ++= Map(r -> tup) + } + + else { + val tup = (0.0, 0.0, 0.0, 1) + resultingMap ++= Map(r -> tup) + } + } + resultingMap.toMap + } + + + // this function just matches the feature name with its correct value, and returns them as a list of tuples, + // where each tuple is in the form of (feature_name,feature_value) + private def featureValuePairing(aggr:Map[String,(Double,Double, Double, Int)]): Seq[(String,Double)] = { + val pairings = collection.mutable.ListBuffer[(String,Double)]() + for((key,value) <- aggr) { + val extendedName = FeatureNameProcessor.extendFeatureName(key) + val minTup = (extendedName._1, value._1) + val maxTup = (extendedName._2, value._2) + val avgTup = (extendedName._3, value._3/value._4) + + val list = ListBuffer(minTup, maxTup, avgTup) + pairings ++= list + } + pairings + } + + + + // this function adds the feature name and value to the global list of feature name and value in the same order. + // It takes as input an indexable sequence of tuples of the form (feature_name, feature_value) and + // adds the feature name and corresponding value *in the same order* to the global list of feature names and values. + // the order in which they are added is crucial to the prediction of the SVM. + private def addAggregatedOnce(input: Seq[(String, Double)]):Unit = { + for((name,value) <- input) { + featureSetNames += name + featureSetValues += value + } + } + + +} diff --git a/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureExtractor.scala b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureExtractor.scala new file mode 100644 index 000000000..2ed8a8da7 --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureExtractor.scala @@ -0,0 +1,450 @@ +package org.clulab.reach.context.feature_utils + +import com.typesafe.config.ConfigFactory +import org.clulab.context.utils.ContextPairInstance +import org.clulab.processors.Document +import org.clulab.reach.context.ContextEngine +import org.clulab.reach.context.utils.svm_training_utils.IOUtilsForFeatureName +import org.clulab.reach.mentions.{BioEventMention, BioTextBoundMention} +import org.clulab.struct.Interval + +import scala.util.{Failure, Success, Try} + + +// This class calculates the values of pre-set linguistic features for (context-event) pairs detected by reach in previously unseen papers. The names of features are read from file, and the given pair i.e. (BioEventMention, BioTextBoundMention) is used to calculate the values of the features +// Please contact Dr. Clayton Morrison's team for further information on the selection of the feature names. 
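+// A minimal usage sketch (hypothetical mention variables; the names below come from this file): for a candidate pair,
+//   val extractor = new ContextFeatureExtractor((evtMention, ctxMention), allContextMentions)
+//   val lookUpTable = extractor.extractFeaturesToCalcByBestFeatSet()
+// returns a one-entry Map from the ContextPairInstance to its (specific, event-dependency, context-dependency) feature-value maps.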
+class ContextFeatureExtractor(datum:(BioEventMention, BioTextBoundMention), contextMentions:Seq[BioTextBoundMention]){ + type Pair = (BioEventMention, BioTextBoundMention) + type EventID = String + type ContextID = (String, String) + def extractFeaturesToCalcByBestFeatSet():Map[ContextPairInstance, (Map[String,Double],Map[String,Double],Map[String,Double])] = { + val config = ConfigFactory.load() + // we need the contextSpecificDependencyFeatures file to get the names of the dependency features for which we need to calculate values. + // The same holds for specificNonDependencyFeatureNames. These are specific and much smaller in number. + // It takes a different procedure to calculate the values of these features, the details of which can be obtained below. + + val resourcesPath = "/org/clulab/context/svmFeatures" + + + val pathToSpecificNonDepFeatures = s"${resourcesPath}/specific_nondependency_featurenames.txt" + val urlToSpecificNonDep = getClass.getResource(pathToSpecificNonDepFeatures) + // this function call to getResource returns to us a URL that is the path to the file specific_nondependency_featurenames.txt + // the variable urlToSpecificNonDep holds the value file:/home/.... + // so we need to take the shorter version of it that starts from /home/... + val truncatedPathToSpecificNonDep = urlToSpecificNonDep.toString.replace("file:","") + val specificNonDepFeatureList = IOUtilsForFeatureName.readSpecificNonDependencyFeatureNames(truncatedPathToSpecificNonDep) + + val pathToAllFeatures = s"${resourcesPath}/all_feature_names_file.txt" + val urlToAllFeatures = getClass.getResource(pathToAllFeatures) + val truncatedPathToAllFeatures = urlToAllFeatures.toString.replace("file:","") + + + val numericFeaturesToCalculateValuesFor = specificNonDepFeatureList.drop(4) + val bestFeatureDict = ContextFeatureUtils.featureConstructor(truncatedPathToAllFeatures) + + // Over all the feature names that were used, an exhaustive ablation study was performed to identify the best-performing subset of features, + // and this was found to be the union of non-dependency features and context-dependency features. + // We will only calculate the values of this subset of features. + val bestFeatureSequence = bestFeatureDict("NonDep_Context") + val allFeaturesSequence = bestFeatureDict("All_features") + val contextFrequencyMap = calculateContextFreq(contextMentions) + val PMCID = datum._1.document.id match { + case Some(c) => c + case None => "Unknown" + } + val label = None + val sentencePos = datum._1.sentence + val evntId = ContextFeatureUtils.extractEvtId(datum._1) + val ctxId = ContextEngine.getContextKey(datum._2) + + val specificNonDepFeatureNames = collection.mutable.ListBuffer[String]() + val ctxDepFeatures = collection.mutable.ListBuffer[String]() + val evtDepFeatures = collection.mutable.ListBuffer[String]() + + + // the names of the features read from file already contained _min, _max, etc., implying that feature values had been pre-aggregated when the Linear SVM model was trained. 
+ // In order to maintain parity, we will "unaggregate" the feature names from file, so that feature names will be the same for the pre-trained SVM model and the fresh test dataset + + def unAggregateFeatureName(features: Seq[String]): Array[String] = { + val fixedFeatureNames = collection.mutable.ListBuffer[String]() + for(f <- features) { + if(f.contains("_min") || f.contains("_max") || f.contains("_avg")) { + val subst = f.slice(0,f.length-4) + fixedFeatureNames += subst + } + else fixedFeatureNames += f + } + fixedFeatureNames.toArray + } + + + + + // we compute the dependency feature names by removing the specific, non-dependency features (like sentenceDistance, dependencyDistance, etc.) from the full feature list. + val dependencyFeatures = unAggregateFeatureName(allFeaturesSequence).toSet -- (unAggregateFeatureName(specificNonDepFeatureList).toSet ++ Seq("")) + + + // for each specific non-dependency feature name, we check whether the best feature set contains it: + // if so, we calculate the value for that feature, else we ignore that feature. + unAggregateFeatureName(numericFeaturesToCalculateValuesFor).map(h => { + if(unAggregateFeatureName(bestFeatureSequence).contains(h)) + specificNonDepFeatureNames += h + }) + + + // for each dependency feature name, we check whether the best feature set contains it: + // if so, we calculate the value for that feature, else we ignore that feature. + // here we classify the feature names into two separate lists, one that has event dependency feature values, + // and one that has context dependency values. + dependencyFeatures foreach { + case evt:String if evt.startsWith("evtDepTail") => { + if(unAggregateFeatureName(bestFeatureSequence).contains(evt)) evtDepFeatures += evt + } + case ctx:String if ctx.startsWith("ctxDepTail")=> { + if(unAggregateFeatureName(bestFeatureSequence).contains(ctx)) ctxDepFeatures += ctx + } + } + + + + // call feature value extractor here + // we will filter the feature names into "specific features", i.e. non-dependency features, context-dependency features, and event dependency features + val specFeatVal = calculateSpecificFeatValues(datum, contextMentions, contextFrequencyMap.toMap) + val evtDepFeatVal = calculateEvtDepFeatureVals(datum) + val ctxDepFeatVal = calculateCtxDepFeatureVals(datum) + + val row = ContextPairInstance(sentencePos, + PMCID, + label, + evntId, + ctxId._2, + specificNonDepFeatureNames.toSet.toArray, + ctxDepFeatures.toSet, + evtDepFeatures.toSet) + + val entry = Map(row -> (specFeatVal, evtDepFeatVal, ctxDepFeatVal)) + + entry + } + + + + + // this function extracts the values of the specific, non-dependency features mentioned above. + // It takes as input: + // a) the tuple of event and context mention, + // b) all the context mentions detected by Reach for the given paper, + // c) the frequency of occurrence of each context type, which is necessary for features such as context_frequency. + // It returns a map from feature names to their corresponding values. 
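+  // For example (hypothetical values), the returned map might look like:
+  //   Map("sentenceDistance" -> 2.0, "dependencyDistance" -> 5.0, "context_frequency" -> 3.0, ...)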
+  private def calculateSpecificFeatValues(datum:(BioEventMention, BioTextBoundMention), contextMentions:Seq[BioTextBoundMention], ctxTypeFreq:Map[String,Double]):Map[String,Double] = {
+    val event = datum._1
+    val context = datum._2
+    val doc = event.document
+    val result = collection.mutable.Map[String,Double]()
+    // ****************INTEGER VALUE FEATURES BEGIN****************
+    val evntId = ContextFeatureUtils.extractEvtId(datum._1)
+    val ctxId = ContextEngine.getContextKey(datum._2)
+
+    val sentenceDistance = Math.abs(datum._1.sentence - datum._2.sentence)
+    result ++= Map("sentenceDistance" -> sentenceDistance.toDouble)
+
+    val dependencyPath = constructDependencyPath(datum)
+    val dependencyDistance = dependencyPath match {
+      case Some(path) => path.size.toDouble
+      case None => 0.0
+    }
+    result ++= Map("dependencyDistance" -> dependencyDistance)
+
+    val context_frequency = ctxTypeFreq(context.nsId())
+    result ++= Map("context_frequency" -> context_frequency)
+
+    // Dependency tails
+    val evtDependencyTails = dependencyTails(event.sentence, event.tokenInterval, doc)
+    val ctxDependencyTails = dependencyTails(context.sentence, context.tokenInterval, doc)
+    // ****************INTEGER VALUE FEATURES END****************
+
+    // ****************BOOLEAN VALUE FEATURES BEGIN****************
+    val evtSentenceFirstPerson = if(eventSentenceContainsPRP(doc, event)) 1.0 else 0.0
+    result ++= Map("evtSentenceFirstPerson" -> evtSentenceFirstPerson)
+
+    val ctxSentenceFirstPerson = if(contextSentenceContainsPRP(doc, context)) 1.0 else 0.0
+    result ++= Map("ctxSentenceFirstPerson" -> ctxSentenceFirstPerson)
+
+    val evtSentencePastTense = if(eventSentenceContainsPastTense(doc, event)) 1.0 else 0.0
+    result ++= Map("evtSentencePastTense" -> evtSentencePastTense)
+
+    val ctxSentencePastTense = if(contextSentenceContainsPastTense(doc, context)) 1.0 else 0.0
+    result ++= Map("ctxSentencePastTense" -> ctxSentencePastTense)
+
+    val evtSentencePresentTense = if(eventSentenceContainsPresentTense(doc, event)) 1.0 else 0.0
+    result ++= Map("evtSentencePresentTense" -> evtSentencePresentTense)
+
+    val ctxSentencePresentTense = if(contextSentenceContainsPresentTense(doc, context)) 1.0 else 0.0
+    result ++= Map("ctxSentencePresentTense" -> ctxSentencePresentTense)
+
+    val closesCtxOfClass = if(isItClosestContextOfSameCategory(event, context, contextMentions)) 1.0 else 0.0
+    result ++= Map("closesCtxOfClass" -> closesCtxOfClass)
+
+    // Negation in the context and event dependency tails.
+    // Note: the key "ctxNegationIntTail" is spelled this way on purpose; it must match the feature
+    // name used when the SVM model was trained.
+    val ctxNegationInTail = if(ctxDependencyTails.exists(tail => tail.contains("neg"))) 1.0 else 0.0
+    result ++= Map("ctxNegationIntTail" -> ctxNegationInTail)
+
+    val evtNegationInTail = if(evtDependencyTails.exists(tail => tail.contains("neg"))) 1.0 else 0.0
+    result ++= Map("evtNegationInTail" -> evtNegationInTail)
+    // ****************BOOLEAN VALUE FEATURES END****************
+
+    result.toMap
+  }
+
+  // this function calculates the value of all the event dependency features.
+  // it takes as input the (event mention, context mention) pair,
+  // and returns a map of feature_name -> feature_value, where each value counts how often
+  // the corresponding dependency tail occurs, similar to the function(s) above
+  private def calculateEvtDepFeatureVals(datum:(BioEventMention, BioTextBoundMention)):Map[String,Double] = {
+    val event = datum._1
+    val doc = event.document
+    val evtDependencyTails = dependencyTails(event.sentence, event.tokenInterval, doc)
+    val evtDepStrings = evtDependencyTails.map(e => e.mkString("_"))
+    evtDepStrings.map(t => s"evtDepTail_$t").groupBy(identity).mapValues(_.length.toDouble)
+  }
+
+  // this function calculates the value of all the context dependency features.
+  // it takes as input the (event mention, context mention) pair,
+  // and returns a map of feature_name -> feature_value, similar to the function(s) above
+  private def calculateCtxDepFeatureVals(datum:(BioEventMention, BioTextBoundMention)):Map[String,Double] = {
+    val context = datum._2
+    val doc = context.document
+
+    val ctxDependencyTails = dependencyTails(context.sentence, context.tokenInterval, doc)
+    val ctxDepStrings = ctxDependencyTails.map(c => c.mkString("_"))
+
+    ctxDepStrings.map(t => s"ctxDepTail_$t").groupBy(identity).mapValues(_.length.toDouble)
+  }
+
+  // ****** starting utility functions to calculate values of dependency features **********
+  private def intersentenceDependencyPath(datum:(BioEventMention, BioTextBoundMention)): Option[Seq[String]] = {
+    def pathToRoot(currentNodeIndx:Int, currentSentInd:Int, currentDoc:Document): Seq[String] = {
+      val dependencies = currentDoc.sentences(currentSentInd).dependencies.get
+      val allRoots = dependencies.roots.toSeq
+      val paths = allRoots flatMap {
+        r =>
+          val ps = dependencies.shortestPathEdges(r, currentNodeIndx)
+          ps map (sequence => sequence map (_._3))
+      }
+
+      paths.sortBy(p => p.size).head
+    }
+
+    // take the shortest path to the root over all the tokens of each mention
+    val evtShortestPath = datum._1.tokenInterval.map(ix => pathToRoot(ix, datum._1.sentence, datum._1.document)).sortBy(_.size).head
+    val ctxShortestPath = datum._2.tokenInterval.map(ix => pathToRoot(ix, datum._2.sentence, datum._2.document)).sortBy(_.size).head
+    val numOfJumps = Seq.fill(Math.abs(datum._1.sentence - datum._2.sentence))("sentenceJump")
+
+    // order the two paths by sentence position, join them with the sentence jumps, and cluster the labels
+    val first = if(datum._1.sentence < datum._2.sentence) evtShortestPath else ctxShortestPath
+    val second = if(datum._2.sentence < datum._1.sentence) ctxShortestPath else evtShortestPath
+    val selectedPath = (first.reverse ++ numOfJumps ++ second).map(POSMaker.clusterDependency)
+
+    val bigrams = (selectedPath zip selectedPath.drop(1)).map{ case (a, b) => s"${a}_${b}" }
+
+    Some(bigrams)
+  }
+
+  private def constructDependencyPath(datum:(BioEventMention, BioTextBoundMention)): Option[Seq[String]] = {
+
+    if(datum._1.sentence == datum._2.sentence) {
+      val currentSentContents = datum._1.document.sentences(datum._1.sentence)
+      val dependencies = currentSentContents.dependencies.get
+      val (first, second) = if(datum._1.tokenInterval.start <= datum._2.tokenInterval.start) (datum._1.tokenInterval, datum._2.tokenInterval) else (datum._2.tokenInterval, datum._1.tokenInterval)
+
+      val paths = first flatMap {
+        i:Int =>
+          second flatMap {
+            j:Int =>
+              val localPaths:Seq[Seq[String]] = dependencies.shortestPathEdges(i, j, ignoreDirection = true) map (s => s map (_._3))
+              localPaths
+          }
+      }
+      val sequence = Try(paths.filter(_.size > 0).sortBy(_.size).head.map(POSMaker.clusterDependency))
+      sequence match {
+        case Success(s) =>
+          // make bigrams
+          val bigrams:Seq[String] = {
+            if(s.size == 1)
+              s
+            else{
+              val shifted = s.drop(1)
+              s.zip(shifted).map{ case (a, b) => s"${a}_${b}" }
+            }
+          }
+
+          Some(bigrams)
+        case Failure(_) =>
+          println("DEBUG: Problem when extracting dependency path for features")
+          None
+      }
+    }
+    else intersentenceDependencyPath(datum)
+  }
+
+  def sentenceContainsPRP(doc:Document, ix:Int):Boolean = {
+    val targetWords = Set("we", "us", "our", "ours", "ourselves", "i", "me", "my", "mine", "myself")
+    val sentence = doc.sentences(ix)
+    val tags = sentence.tags.get
+    val lemmas = sentence.lemmas.get
+
+    val firstPersonPronouns = (tags zip lemmas) filter {
+      case (tag, lemma) =>
+        tag == "PRP" && targetWords.contains(lemma)
+    }
+
+    firstPersonPronouns.nonEmpty
+  }
+
+  def dependencyTails(sentence:Int, interval:Interval, doc:Document):Seq[Seq[String]] = {
+
+    val deps = doc.sentences(sentence).dependencies.get
+
+    def helper(nodeIx:Int, depth:Int, maxDepth:Int):List[List[String]] = {
+
+      // Get all the edges connected to the current node, as long as they don't land on another
+      // token of the mention's interval
+      val incoming = Try(deps.getIncomingEdges(nodeIx)) match {
+        case Success(edges) => edges.filter(e => !interval.contains(e._1)).toList
+        case Failure(_) => Nil
+      }
+
+      val outgoing = Try(deps.getOutgoingEdges(nodeIx)) match {
+        case Success(edges) => edges.filter(e => !interval.contains(e._1)).toList
+        case Failure(_) => Nil
+      }
+
+      val edges = incoming ++ outgoing
+
+      if(depth == maxDepth)
+        Nil
+      else{
+        edges flatMap {
+          e =>
+            val label = POSMaker.clusterDependency(e._2)
+            val further = helper(e._1, depth+1, maxDepth)
+
+            further match {
+              case Nil => List(List(label))
+              case list:List[List[String]] => list map (l => label::l)
+            }
+        }
+      }
+    }
+
+    helper(interval.start, 0, 2) ++ helper(interval.end, 0, 2)
+  }
+
+  def sentenceContainsSimplePRP(doc:Document, ix:Int):Boolean = sentenceContainsSimpleTags(doc, ix, Set("PRP"))
+
+  def sentenceContainsSimplePastTense(doc:Document, ix:Int):Boolean = sentenceContainsSimpleTags(doc, ix, Set("VBD", "VBN"))
+
+  def sentenceContainsSimplePresentTense(doc:Document, ix:Int):Boolean = sentenceContainsSimpleTags(doc, ix, Set("VBG", "VBP", "VBZ"))
+
+  def sentenceContainsSimpleTags(doc:Document, ix:Int, tags:Set[String]):Boolean = {
+    val sentenceTags = doc.sentences(ix).tags.get.toSet
+    // true iff at least one of the sentence's POS tags belongs to the given tag set
+    sentenceTags.exists(tags.contains)
+  }
+
+  def eventSentenceContainsPRP(doc:Document, event:BioEventMention):Boolean = sentenceContainsPRP(doc, event.sentence)
+  def contextSentenceContainsPRP(doc:Document, context:BioTextBoundMention):Boolean = sentenceContainsPRP(doc, context.sentence)
+  def eventSentenceContainsPastTense(doc:Document, event:BioEventMention):Boolean = sentenceContainsSimplePastTense(doc, event.sentence)
+  def contextSentenceContainsPastTense(doc:Document, context:BioTextBoundMention):Boolean = sentenceContainsSimplePastTense(doc, context.sentence)
+  def eventSentenceContainsPresentTense(doc:Document, event:BioEventMention):Boolean = sentenceContainsSimplePresentTense(doc, event.sentence)
+  def contextSentenceContainsPresentTense(doc:Document, context:BioTextBoundMention):Boolean = sentenceContainsSimplePresentTense(doc, context.sentence)
+
+  def isItClosestContextOfSameCategory(event:BioEventMention,
+                                       context:BioTextBoundMention,
+                                       otherContexts:Iterable[BioTextBoundMention]):Boolean = {
+    val bounds = Seq(event.sentence, context.sentence)
+    val (start, end) = (bounds.min, bounds.max)
+
+    val filteredContexts = otherContexts.filter{
+      c =>
+        !(c.sentence == context.sentence && c.tokenInterval == context.tokenInterval && c.nsId() == context.nsId())
+    }
+    assert(filteredContexts.size == otherContexts.size - 1)
+    val interval = start to end
+
+    if(interval.length >= 3){
+      val contextCategory = context.nsId().split(":")(0)
+
+      val contextClasses = filteredContexts.collect{
+        case c if interval.contains(c.sentence) =>
+          c.nsId().split(":")(0)
+      }.toList.toSet
+
+      !contextClasses.contains(contextCategory)
+    }
+    else
+      true
+  }
+
+  // ****** ending utility functions to calculate values of dependency features **********
+
+  // this function calculates the value of the context_frequency feature.
+  // it takes as input the context mentions in the paper and counts the frequency of occurrence
+  // of each context label. This is then returned as a map of (context_label -> frequency).
+  def calculateContextFreq(ctxMentions: Seq[BioTextBoundMention]):collection.mutable.Map[String, Double] = {
+    val contextFrequencyMap = collection.mutable.Map[String, Double]()
+    ctxMentions.foreach { f =>
+      val id = f.nsId()
+      contextFrequencyMap(id) = contextFrequencyMap.getOrElse(id, 0.0) + 1.0
+    }
+    contextFrequencyMap
+  }
+}
+
diff --git a/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureUtils.scala b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureUtils.scala
new file mode 100644
index 000000000..873028fa5
--- /dev/null
+++ b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/ContextFeatureUtils.scala
@@ -0,0 +1,172 @@
+package org.clulab.reach.context.feature_utils
+
+import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
+import org.apache.commons.io.FilenameUtils
+import com.typesafe.config.ConfigFactory
+import org.clulab.context.utils.{AggregatedContextInstance, ContextPairInstance}
+import org.clulab.reach.mentions.{BioEventMention, BioTextBoundMention}
+
+object ContextFeatureUtils {
+  type Pair = (BioEventMention, BioTextBoundMention)
+  type EventID = String
+  type ContextID = (String, String)
+  val config = ConfigFactory.load()
+
+  // input: a set of filtered (event, context) pairs and the context mentions of the paper.
+  // This function calls the feature extractor for each pair and receives a sequence of maps;
+  // to simplify that data structure, we flatten the output into a single map.
+  // output: a map of ContextPairInstance -> (feature_name -> feature_value)
+  def getFeatValMapPerInput(filteredPairs: Set[Pair], ctxMentions: Seq[BioTextBoundMention]):Map[ContextPairInstance, (Map[String,Double],Map[String,Double],Map[String,Double])] = {
+
+    val maps = filteredPairs.map{p =>
+      val featureExtractor = new ContextFeatureExtractor(p, ctxMentions)
+      featureExtractor.extractFeaturesToCalcByBestFeatSet()
+    }
+    maps.flatten.toMap
+  }
+
+  // getCtxPairInstances takes a map from ContextPairInstance to its feature values and returns
+  // the key set, i.e. Set[ContextPairInstance], as a sequence.
+  def getCtxPairInstances(ctxPairFeatValMap: Map[ContextPairInstance, (Map[String,Double],Map[String,Double],Map[String,Double])]): Seq[ContextPairInstance] = {
+    ctxPairFeatValMap.keySet.toSeq
+  }
+
+  // The following overloads of writeAggRowToFile write aggregated rows to file for further analyses.
+
+  // This signature of writeAggRowToFile writes the AggregatedRow object to the path specified by parentDir.
+  // It first creates a directory named after the paper, and then creates a file named after the
+  // paperID, eventID and contextID. The aggregated row is written as a serialized object.
+  def writeAggRowToFile(row:AggregatedContextInstance, evtID: String, ctxString:String, parentDir:String):Unit = {
+    val pmcid = s"PMC${row.PMCID.split("_")(0)}"
+    val whichDirToWriteRow = parentDir.concat(s"${pmcid}")
+    val paperDir = new File(whichDirToWriteRow)
+    if(!paperDir.exists())
+      paperDir.mkdirs()
+    // FileOutputStream creates the output file if it does not already exist
+    val whereToWriteRow = whichDirToWriteRow.concat(s"/AggregatedRow_${pmcid}_${evtID}_${ctxString}.txt")
+    val os = new ObjectOutputStream(new FileOutputStream(whereToWriteRow))
+    os.writeObject(row)
+    os.close()
+  }
+
+  // This signature of writeAggRowToFile takes the aggregated row, event ID, context ID, sentence window,
+  // and the directory into which the file is to be written.
+  // In this directory, a sub-directory called "sentenceWindows/$sentence_window" is created,
+  // and rows are written to that sub-directory.
+  def writeAggRowToFile(row: AggregatedContextInstance, evtID: String, ctxID: String, sentenceWindow: Int, parentDir:String):Unit = {
+
+    val outDir = parentDir.concat(s"/sentenceWindows/${sentenceWindow}")
+    val outdirFile = new File(outDir)
+    if(!outdirFile.exists())
+      outdirFile.mkdirs()
+    val currentPMCID = s"PMC${row.PMCID.split("_")(0)}"
+    val aggrRowFilePath = outDir.concat(s"/AggregatedRow_${currentPMCID}_${evtID}_${ctxID}.txt")
+    val os = new ObjectOutputStream(new FileOutputStream(aggrRowFilePath))
+    os.writeObject(row)
+    os.close()
+  }
+
+  // This function can be used to read an AggregatedRow from file.
+  // It takes as parameter the path to the file the row needs to be read from,
+  // and returns an instance of AggregatedContextInstance.
+  // Please note that this function DOES NOT return the specifications of the AggregatedRow,
+  // i.e. it will not give you the paperID, eventID and contextID;
+  // please refer to the function *createAggRowSpecsFromFile*, which returns these specifications.
+  def readAggRowFromFile(file: String):AggregatedContextInstance = {
+    val is = new ObjectInputStream(new FileInputStream(file))
+    val c = is.readObject().asInstanceOf[AggregatedContextInstance]
+    is.close()
+    c
+  }
+
+  // This function takes as parameter the file in which an AggregatedContextInstance row is saved,
+  // and returns as a tuple the paperID, eventID and contextID.
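+  // For example, the test resource file
+  // "AggregatedRow_PMC3411611_in233from9to11_tissuelist:TS-1224.txt" yields the tuple
+  // ("PMC3411611", "in233from9to11", "tissuelist:TS-1224").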
+  def createAggRowSpecsFromFile(file: File):(String, String, String) = {
+    val strOnly = FilenameUtils.removeExtension(file.getName)
+    val parts = strOnly.split("_")
+    val pmcid = parts(1)
+    val evtID = parts(2)
+    // context IDs may themselves contain an underscore, so re-attach anything after the fourth separator
+    val rem = if(parts.size > 4) "_".concat(parts(4)) else ""
+    val ctxID = parts(3).concat(rem)
+    (pmcid, evtID, ctxID)
+  }
+
+  // The function extractEvtId takes an event mention as its sole parameter,
+  // and returns a string encoding the sentence index and the token start and end of the event interval.
+  // The eventId is returned in the format "in${sentenceIndex}from${event_start_token}to${event_end_token}".
+  // For example, if a given BioEventMention has sentence index 3, and token start and end values of 7 and 8,
+  // the function returns the string "in3from7to8".
+  // This format is used across the SVM engine.
+  def extractEvtId(evt:BioEventMention):EventID = {
+    val sentIndex = evt.sentence
+    val tokenIntervalStart = evt.tokenInterval.start.toString
+    val tokenIntervalEnd = evt.tokenInterval.end.toString
+    "in"+sentIndex+"from"+tokenIntervalStart+"to"+tokenIntervalEnd
+  }
+
+  def featureConstructor(file:String):Map[String, Seq[String]] = {
+    val is = new ObjectInputStream(new FileInputStream(file))
+    val headers = is.readObject().asInstanceOf[Array[String]]
+    val rectifiedHeaders = rectifyWrongFeatures(headers)
+    is.close()
+    createBestFeatureSet(rectifiedHeaders)
+  }
+
+  private def rectifyWrongFeatures(headers:Seq[String]): Seq[String] = {
+    // the raw header at index 1 is wrong, so we rename it to "PMCID"
+    headers.zipWithIndex.map { case (h, ix) => if (ix == 1) "PMCID" else h }
+  }
+
+  def createBestFeatureSet(allFeatures:Seq[String]):Map[String, Seq[String]] = {
+    val nonNumericFeatures = Seq("PMCID", "label", "EvtID", "CtxID", "")
+    val numericFeatures = allFeatures.toSet -- nonNumericFeatures.toSet
+    createFeatureDictionary(numericFeatures.toSeq)
+  }
+
+  def createFeatureDictionary(numericFeatures: Seq[String]):Map[String, Seq[String]] = {
+    val contextDepFeatures = numericFeatures.filter(_.startsWith("ctxDepTail"))
+    val eventDepFeatures = numericFeatures.filter(_.startsWith("evtDepTail"))
+    val nonDepFeatures = numericFeatures.toSet -- (contextDepFeatures.toSet ++ eventDepFeatures.toSet)
+    val map = collection.mutable.Map[String, Seq[String]]()
+    map += ("All_features" -> numericFeatures)
+    map += ("Non_Dependency_Features" -> nonDepFeatures.toSeq)
+    map += ("NonDep_Context" -> (nonDepFeatures ++ contextDepFeatures.toSet).toSeq)
+    map += ("NonDep_Event" -> (nonDepFeatures ++ eventDepFeatures.toSet).toSeq)
+    map += ("Context_Event" -> (contextDepFeatures.toSet ++ eventDepFeatures.toSet).toSeq)
+    map.toMap
+  }
+
+}
diff --git a/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/EventContextPairGenerator.scala b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/EventContextPairGenerator.scala
new file mode 100644
index 000000000..e479bc0a4
--- /dev/null
+++ b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/EventContextPairGenerator.scala
@@ -0,0 +1,26 @@
+package org.clulab.reach.context.feature_utils
+
+import org.clulab.reach.mentions.{BioEventMention, BioMention, BioTextBoundMention}
+
+class EventContextPairGenerator(mentions:Seq[BioMention], ctxMentions:Seq[BioTextBoundMention]) {
+
+  type Pair = (BioEventMention, BioTextBoundMention)
+  type EventID = String
+  type ContextID = (String, String)
+
+  // Collect the event mentions
+  val evtMentions = mentions collect {
+    case evt:BioEventMention => evt
+  }
+
+  // The constructor of this class is supplied with all the mentions and with the BioTextBoundMentions
+  // (i.e. the context mentions). Let the given Seq[BioTextBoundMention] be of size m.
+  // We filter the event mentions out of all the mentions (let them be of size n).
+  // A cross product is then generated over the context-event pairs, so the resulting Seq of pairs has size m*n.
+  def yieldContextEventPairs():Seq[Pair] = {
+    for(evt <- evtMentions; ctx <- ctxMentions) yield (evt, ctx)
+  }
+
+}
diff --git a/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/FeatureNameProcessor.scala b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/FeatureNameProcessor.scala
new file mode 100644
index 000000000..a700d6ab3
--- /dev/null
+++ b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/FeatureNameProcessor.scala
@@ -0,0 +1,67 @@
+package org.clulab.reach.context.utils.feature_utils
+
+object FeatureNameProcessor {
+  def fixFeatureNameInInputStream(headers:Seq[String]): Seq[String] = {
+    // the raw header at index 1 is wrong, so we rename it to "PMCID"
+    headers.zipWithIndex.map { case (h, ix) => if (ix == 1) "PMCID" else h }
+  }
+
+  def createBestFeatureSetForTraining(allFeatures:Seq[String]):Map[String, Seq[String]] = {
+    val nonNumericFeatures = Seq("PMCID", "label", "EvtID", "CtxID", "")
+    val numericFeatures = allFeatures.toSet -- nonNumericFeatures.toSet
+    createFeatureTypeDictionary(numericFeatures.toSeq)
+  }
+
+  def createFeatureTypeDictionary(numericFeatures: Seq[String]):Map[String, Seq[String]] = {
+    val contextDepFeatures = numericFeatures.filter(_.startsWith("ctxDepTail"))
+    val eventDepFeatures = numericFeatures.filter(_.startsWith("evtDepTail"))
+    val nonDepFeatures = numericFeatures.toSet -- (contextDepFeatures.toSet ++ eventDepFeatures.toSet)
+    val map = collection.mutable.Map[String, Seq[String]]()
+    map += ("All_features" -> numericFeatures)
+    map += ("Non_Dependency_Features" -> nonDepFeatures.toSeq)
+    map += ("NonDep_Context" -> (nonDepFeatures ++ contextDepFeatures.toSet).toSeq)
+    map += ("NonDep_Event" -> (nonDepFeatures ++ eventDepFeatures.toSet).toSeq)
+    map += ("Context_Event" -> (contextDepFeatures.toSet ++ eventDepFeatures.toSet).toSeq)
+    map.toMap
+  }
+
+  // we want to "unaggregate" the name of each feature:
+  // for example, if a feature name is sentenceDistance_max, we want to take only sentenceDistance,
+  // because otherwise we would produce feature names like sentenceDistance_max_min or sentenceDistance_max_max,
+  // which are not meaningful to our SVM.
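+  // For example (hypothetical call):
+  //   resolveToUnaggregatedFeatureName(Seq("PMCID", "label", "sentenceDistance_min", "sentenceDistance_max"), 2)
+  // should keep the two leading identifier columns and collapse both aggregated names
+  // into a single "sentenceDistance".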
+  def resolveToUnaggregatedFeatureName(seq: Seq[String], take: Int):Seq[String] = {
+    val result = collection.mutable.ListBuffer[String]()
+    val ids = seq.take(take)
+    val numericalFeatureNames = seq.drop(take)
+    result ++= ids
+    val miniList = collection.mutable.ListBuffer[String]()
+    numericalFeatureNames.foreach { m =>
+      // strip the aggregation suffix to recover the base feature name
+      val baseName =
+        if(m.contains("_max")) m.replace("_max","")
+        else if(m.contains("_min")) m.replace("_min","")
+        else m.replace("_avg","")
+      miniList += baseName
+    }
+    result ++= miniList.toSet.toSeq
+    result
+  }
+
+  def extendFeatureName(f:String):(String, String, String) = {
+    val feat_min = s"${f}_min"
+    val feat_max = s"${f}_max"
+    val feat_avg = s"${f}_avg"
+    (feat_min, feat_max, feat_avg)
+  }
+}
diff --git a/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/POSMaker.scala b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/POSMaker.scala
new file mode 100644
index 000000000..f00848bcd
--- /dev/null
+++ b/main/src/main/scala/org/clulab/reach/context/utils/feature_utils/POSMaker.scala
@@ -0,0 +1,50 @@
+package org.clulab.reach.context.feature_utils
+
+// this object is used by the feature extractor to bin sentence and dependency distances,
+// and to cluster POS tags and dependency labels into coarser classes (e.g. NN, VB, prep, subj)
+object POSMaker{
+  def binSentenceDistance(d:Int):BinnedDistance.Value = {
+    if(d == 0)
+      BinnedDistance.SAME
+    else if(d <= 13)
+      BinnedDistance.CLOSE
+    else
+      BinnedDistance.FAR
+  }
+
+  def binDependencyDistance(d:Int):BinnedDistance.Value = {
+    if(d <= 0)
+      BinnedDistance.CLOSE
+    else
+      BinnedDistance.FAR
+  }
+
+  def clusterPOSTag(tag:String):String = {
+    if(tag.startsWith("NN")) "NN"
+    else if(tag.startsWith("VB")) "VB"
+    else if(Set(",", "-RRB-", ".", ":", ";", "-LRB-").contains(tag)) "BOGUS"
+    else tag
+  }
+
+  def clusterDependency(d:String):String = {
+    if(d.startsWith("prep")) "prep"
+    else if(d.startsWith("conj")) "conj"
+    else if(d.endsWith("obj")) "obj"
+    else if(d.endsWith("mod")) "mod"
+    else if(d.contains("subj")) "subj"
+    else d
+  }
+
+}
diff --git a/main/src/main/scala/org/clulab/reach/context/utils/svm_performance_utils/ScoresUtils.scala b/main/src/main/scala/org/clulab/reach/context/utils/svm_performance_utils/ScoresUtils.scala
new file mode 100644
index 000000000..686a2d6b1
--- /dev/null
+++ b/main/src/main/scala/org/clulab/reach/context/utils/svm_performance_utils/ScoresUtils.scala
@@ -0,0 +1,70 @@
+package org.clulab.reach.context.utils.svm_performance_utils
+
+object ScoresUtils {
+  // returns the key whose associated score is highest
+  def argMax(values:Map[Int, Double]):Int = {
+    var bestK = Integer.MIN_VALUE
+    var bestF1 = Double.MinValue
+    values.foreach { case (k, score) =>
+      if (score > bestF1) { bestK = k; bestF1 = score }
+    }
+    bestK
+  }
+
+  def f1(preds: Map[String, Int]): Double = {
+    val p = precision(preds)
+    val r = recall(preds)
+    if (p + r == 0) 0.0
+    else (2 * p * r) / (p + r)
+  }
+
+  def precision(preds: Map[String, Int]): Double = {
+    val denominator = (preds("TP") + preds("FP")).toDouble
+    if (denominator != 0.0) preds("TP").toDouble / denominator else 0.0
+  }
+
+  def recall(preds: Map[String, Int]): Double = {
+    val denominator = (preds("TP") + preds("FN")).toDouble
+    if (denominator != 0.0) preds("TP").toDouble / denominator else 0.0
+  }
+
+  def accuracy(preds:Map[String, Int]): Double = {
+    val total = (preds("TP") + preds("TN") + preds("FP") + preds("FN")).toDouble
+    if (total != 0.0) (preds("TP") + preds("TN")).toDouble / total
+    else 0.0
+  }
preds("FN")).toDouble + else 0.0 + } + + def arithmeticMeanScore(scores:Seq[Double]):Double = { + val sum = scores.foldLeft(0.0)(_ + _) + sum/scores.size + } + + def predictCounts(yTrue: Array[Int], yPred: Array[Int]): Map[String, Int] = { + val indexValuePair = yTrue zip yPred + var TP = 0; var FP = 0; var TN = 0; var FN = 0 + for((gt,pr) <- indexValuePair) { + if (gt == 1 && pr == 1) TP+=1 + if (gt == 1 && pr == 0) FN +=1 + if (gt == 0 && pr == 0) TN +=1 + if (gt == 0 && pr == 1) FP +=1 + } + Map(("TP" -> TP), ("FP" -> FP), ("TN" -> TN), ("FN" -> FN)) + } + + + def createStats(nums: Iterable[Double]): (Double, Double, Double) = { + val min = nums.min + val max = nums.max + val avg = nums.sum / nums.size + (min, max, avg) + } + + + def findAggrMetrics(seq:Seq[Double]): (Double,Double,Double) = { + val min = seq.foldLeft(Double.MaxValue)(Math.min(_,_)) + val max = seq.foldLeft(Double.MinValue)(Math.max(_,_)) + val sum = seq.foldLeft(0.0)(_+_) + val avg = sum.toDouble/seq.size.toDouble + (min,max,avg) + } + + +} diff --git a/main/src/main/scala/org/clulab/reach/context/utils/svm_training_utils/DatatypeConversionUtils.scala b/main/src/main/scala/org/clulab/reach/context/utils/svm_training_utils/DatatypeConversionUtils.scala new file mode 100644 index 000000000..39dac91e6 --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/utils/svm_training_utils/DatatypeConversionUtils.scala @@ -0,0 +1,21 @@ +package org.clulab.reach.context.utils.svm_training_utils + +import org.clulab.context.utils.AggregatedContextInstance + +object DatatypeConversionUtils { + def convertBooleansToInt(labels: Seq[Boolean]):Array[Int] = { + + val toReturn = labels.map(l => l match { + case true => 1 + case false => 0 + }) + toReturn.toArray + } + + def convertOptionalToBool(rows: Seq[AggregatedContextInstance]): Seq[Boolean] = { + rows.map(x => x.label match { + case Some(x) => x + case _ => false + }) + } +} diff --git a/main/src/main/scala/org/clulab/reach/context/utils/svm_training_utils/IOUtilsForFeatureName.scala b/main/src/main/scala/org/clulab/reach/context/utils/svm_training_utils/IOUtilsForFeatureName.scala new file mode 100644 index 000000000..9ca0ed89d --- /dev/null +++ b/main/src/main/scala/org/clulab/reach/context/utils/svm_training_utils/IOUtilsForFeatureName.scala @@ -0,0 +1,45 @@ +package org.clulab.reach.context.utils.svm_training_utils + +import java.io.{BufferedInputStream, FileInputStream, ObjectInputStream} +import java.util.zip.GZIPInputStream + +import org.clulab.context.utils.AggregatedContextInstance + +import org.clulab.reach.context.utils.feature_utils.FeatureNameProcessor + +import scala.io.Source + +object IOUtilsForFeatureName { + def loadAggregatedRowsFromDataFrame(groupedFeaturesFileName: String, pathToSpecificNonDepFeatures: String):(Seq[String], Seq[AggregatedContextInstance]) = { + val listOfSpecificFeatures = readSpecificNonDependencyFeatureNames(pathToSpecificNonDepFeatures) + def allOtherFeatures(headers:Seq[String]): Set[String] = headers.toSet -- (listOfSpecificFeatures ++ Seq("")) + def indices(headers:Seq[String]): Map[String, Int] = headers.zipWithIndex.toMap + val fileInputStream = new FileInputStream(groupedFeaturesFileName) + val bufferedStream = new BufferedInputStream(new GZIPInputStream(fileInputStream)) + val source = Source.fromInputStream(bufferedStream) + val lines = source.getLines() + val headers = lines.next() split "," + val rectifiedHeaders = FeatureNameProcessor.fixFeatureNameInInputStream(headers) + val features = allOtherFeatures(rectifiedHeaders) + 
val ixs = indices(rectifiedHeaders) + val ret = lines.map(l => AggregatedContextInstance(l, rectifiedHeaders, features, ixs, listOfSpecificFeatures)).toList + source.close() + (rectifiedHeaders, ret) + } + + + + def getTrainingFeatures(file:String):Map[String, Seq[String]] = { + val is = new ObjectInputStream(new FileInputStream(file)) + val headers = is.readObject().asInstanceOf[Array[String]] + val rectifiedHeaders = FeatureNameProcessor.fixFeatureNameInInputStream(headers) + is.close() + FeatureNameProcessor.createBestFeatureSetForTraining(rectifiedHeaders) + } + + def readSpecificNonDependencyFeatureNames(fileName: String):Array[String] = { + val is = new ObjectInputStream(new FileInputStream(fileName)) + val headers = is.readObject().asInstanceOf[Array[String]] + headers + } +} diff --git a/main/src/test/resources/inputs/aggregated-context-instance/PMC3411611/AggregatedRow_PMC3411611_in15from22to23_tissuelist:TS-0649.txt b/main/src/test/resources/inputs/aggregated-context-instance/PMC3411611/AggregatedRow_PMC3411611_in15from22to23_tissuelist:TS-0649.txt new file mode 100644 index 000000000..7b3069d5d Binary files /dev/null and b/main/src/test/resources/inputs/aggregated-context-instance/PMC3411611/AggregatedRow_PMC3411611_in15from22to23_tissuelist:TS-0649.txt differ diff --git a/main/src/test/resources/inputs/aggregated-context-instance/PMC3411611/AggregatedRow_PMC3411611_in233from9to11_tissuelist:TS-1224.txt b/main/src/test/resources/inputs/aggregated-context-instance/PMC3411611/AggregatedRow_PMC3411611_in233from9to11_tissuelist:TS-1224.txt new file mode 100644 index 000000000..71625a347 Binary files /dev/null and b/main/src/test/resources/inputs/aggregated-context-instance/PMC3411611/AggregatedRow_PMC3411611_in233from9to11_tissuelist:TS-1224.txt differ diff --git a/main/src/test/resources/inputs/aggregated-context-instance/PMC3608085/AggregatedRow_PMC3608085_in195from6to12_tissuelist:TS-0500.txt b/main/src/test/resources/inputs/aggregated-context-instance/PMC3608085/AggregatedRow_PMC3608085_in195from6to12_tissuelist:TS-0500.txt new file mode 100644 index 000000000..ad64d63e4 Binary files /dev/null and b/main/src/test/resources/inputs/aggregated-context-instance/PMC3608085/AggregatedRow_PMC3608085_in195from6to12_tissuelist:TS-0500.txt differ diff --git a/main/src/test/resources/inputs/aggregated-context-instance/PMC3608085/AggregatedRow_PMC3608085_in195from6to12_uberon:UBERON:0000105.txt b/main/src/test/resources/inputs/aggregated-context-instance/PMC3608085/AggregatedRow_PMC3608085_in195from6to12_uberon:UBERON:0000105.txt new file mode 100644 index 000000000..4d26a4ddb Binary files /dev/null and b/main/src/test/resources/inputs/aggregated-context-instance/PMC3608085/AggregatedRow_PMC3608085_in195from6to12_uberon:UBERON:0000105.txt differ diff --git a/main/src/test/resources/inputs/aggregated-context-instance/grouped_features.csv.gz b/main/src/test/resources/inputs/aggregated-context-instance/grouped_features.csv.gz new file mode 100644 index 000000000..5c15c0330 Binary files /dev/null and b/main/src/test/resources/inputs/aggregated-context-instance/grouped_features.csv.gz differ diff --git a/main/src/test/resources/inputs/aggregated-context-instance/specific_nondependency_featurenames.txt b/main/src/test/resources/inputs/aggregated-context-instance/specific_nondependency_featurenames.txt new file mode 100644 index 000000000..bb6ddfb5b Binary files /dev/null and b/main/src/test/resources/inputs/aggregated-context-instance/specific_nondependency_featurenames.txt differ diff --git 
a/main/src/test/resources/inputs/aggregated-context-instance/svm_model.dat b/main/src/test/resources/inputs/aggregated-context-instance/svm_model.dat new file mode 100644 index 000000000..c4066134e Binary files /dev/null and b/main/src/test/resources/inputs/aggregated-context-instance/svm_model.dat differ diff --git a/main/src/test/scala/org/clulab/reach/context/TestSVMContext.scala b/main/src/test/scala/org/clulab/reach/context/TestSVMContext.scala new file mode 100644 index 000000000..3fd65ada5 --- /dev/null +++ b/main/src/test/scala/org/clulab/reach/context/TestSVMContext.scala @@ -0,0 +1,184 @@ +package org.clulab.reach + + +import org.clulab.context.utils.AggregatedContextInstance +import org.scalatest.{FlatSpec, Matchers} +import java.io.{FileInputStream, ObjectInputStream} + +import org.clulab.context.classifiers.LinearSVMContextClassifier +import org.clulab.reach.PaperReader.procAnnotator +import org.clulab.reach.context.ContextEngineFactory.Engine + +class TestSVMContext extends FlatSpec with Matchers { + + lazy val reachSystemWithSVMContext = new ReachSystem(processorAnnotator = Some(procAnnotator), + contextEngineType = Engine.withName("SVMPolicy"), + contextParams = Map("bound" -> "7")) + + val resourcesPath = "/inputs/aggregated-context-instance" + + + val svmWrapper = new LinearSVMContextClassifier() + val svmInstancePath = s"${resourcesPath}/svm_model.dat" + val urlPathToSVMModel = readFileNameFromResource(svmInstancePath) + val trainedSVMInstance = svmWrapper.loadFrom(urlPathToSVMModel) + val pair1 = "PMC3411611,in233from9to11,tissuelist:TS-1224" //prediction is 1 + val resourcesPathToPair1 = s"${resourcesPath}/PMC3411611/AggregatedRow_PMC3411611_in233from9to11_tissuelist:TS-1224.txt" + val urlPathToPair1 = readFileNameFromResource(resourcesPathToPair1) + val rowForPair1 = readAggRowFromFile(urlPathToPair1) + + + "Reach System with SVM context engine" should "run correctly" in { + val text = "S6K1 phosphorylates the RPTOR protein and promotes the hydroxylation of the Pkh1 protein." 
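+    // annotate the text and run the full Reach extraction pipeline with the SVM context engine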
+    val doc = reachSystemWithSVMContext.mkDoc(text, "testdoc")
+    val mentions = reachSystemWithSVMContext.extractFrom(doc)
+    mentions should not be empty
+  }
+
+  pair1 should "have prediction 1" in {
+    val pred = trainedSVMInstance.predict(Seq(rowForPair1))(0)
+    pred should be (1)
+  }
+
+  pair1 should "have sentenceDistance_min of 2" in {
+    val sentenceDistance_min = rowForPair1.featureGroups(rowForPair1.featureGroupNames.indexOf("sentenceDistance_min"))
+    sentenceDistance_min should be (2.0)
+  }
+
+  pair1 should "have dependencyDistance_min of 11" in {
+    val dependencyDistance_min = rowForPair1.featureGroups(rowForPair1.featureGroupNames.indexOf("dependencyDistance_min"))
+    dependencyDistance_min should be (11.0)
+  }
+
+  pair1 should "have context_frequency_min of 33" in {
+    val contextFrequency_min = rowForPair1.featureGroups(rowForPair1.featureGroupNames.indexOf("context_frequency_min"))
+    contextFrequency_min should be (33.0)
+  }
+
+  pair1 should "have closesCtxOfClass_min of 0" in {
+    val closesCtxOfClass_min = rowForPair1.featureGroups(rowForPair1.featureGroupNames.indexOf("closesCtxOfClass_min"))
+    closesCtxOfClass_min should be (0.0)
+  }
+
+  val pair2 = "PMC3411611,in15from22to23,tissuelist:TS-0649"
+  val resourcesPathToPair2 = s"${resourcesPath}/PMC3411611/AggregatedRow_PMC3411611_in15from22to23_tissuelist:TS-0649.txt"
+  val urlPathToPair2 = readFileNameFromResource(resourcesPathToPair2)
+  val rowForPair2 = readAggRowFromFile(urlPathToPair2)
+
+  pair2 should "have prediction 0" in {
+    val pred = trainedSVMInstance.predict(Seq(rowForPair2))(0)
+    pred should be (0)
+  }
+
+  pair2 should "have sentenceDistance_min of 1" in {
+    val sentenceDistance_min = rowForPair2.featureGroups(rowForPair2.featureGroupNames.indexOf("sentenceDistance_min"))
+    sentenceDistance_min should be (1.0)
+  }
+
+  pair2 should "have dependencyDistance_min of 2" in {
+    val dependencyDistance_min = rowForPair2.featureGroups(rowForPair2.featureGroupNames.indexOf("dependencyDistance_min"))
+    dependencyDistance_min should be (2.0)
+  }
+
+  pair2 should "have context_frequency_min of 1" in {
+    val contextFrequency_min = rowForPair2.featureGroups(rowForPair2.featureGroupNames.indexOf("context_frequency_min"))
+    contextFrequency_min should be (1.0)
+  }
+
+  pair2 should "have closesCtxOfClass_min of 1" in {
+    val closesCtxOfClass_min = rowForPair2.featureGroups(rowForPair2.featureGroupNames.indexOf("closesCtxOfClass_min"))
+    closesCtxOfClass_min should be (1.0)
+  }
+
+  val pair3 = "PMC3608085,in195from6to12,tissuelist:TS-0500"
+  val resourcesPathToPair3 = s"${resourcesPath}/PMC3608085/AggregatedRow_PMC3608085_in195from6to12_tissuelist:TS-0500.txt"
+  val urlPathToPair3 = readFileNameFromResource(resourcesPathToPair3)
+  val rowForPair3 = readAggRowFromFile(urlPathToPair3)
+
+  pair3 should "have prediction 1" in {
+    val pred = trainedSVMInstance.predict(Seq(rowForPair3))(0)
+    pred should be (1)
+  }
+
+  pair3 should "have sentenceDistance_min of 2" in {
+    val sentenceDistance_min = rowForPair3.featureGroups(rowForPair3.featureGroupNames.indexOf("sentenceDistance_min"))
+    sentenceDistance_min should be (2.0)
+  }
+
+  pair3 should "have dependencyDistance_min of 2" in {
+    val dependencyDistance_min = rowForPair3.featureGroups(rowForPair3.featureGroupNames.indexOf("dependencyDistance_min"))
+    dependencyDistance_min should be (2.0)
+  }
+
+  pair3 should "have context_frequency_min of 26" in {
+    val contextFrequency_min = rowForPair3.featureGroups(rowForPair3.featureGroupNames.indexOf("context_frequency_min"))
+    contextFrequency_min should be (26.0)
+  }
+
+  pair3 should "have closesCtxOfClass_min of 0" in {
+    val closesCtxOfClass_min = rowForPair3.featureGroups(rowForPair3.featureGroupNames.indexOf("closesCtxOfClass_min"))
+    closesCtxOfClass_min should be (0.0)
+  }
+
+  val pair4 = "PMC3608085,in195from6to12,uberon:UBERON:0000105"
+  val resourcesPathToPair4 = s"${resourcesPath}/PMC3608085/AggregatedRow_PMC3608085_in195from6to12_uberon:UBERON:0000105.txt"
+  val urlPathToPair4 = readFileNameFromResource(resourcesPathToPair4)
+  val rowForPair4 = readAggRowFromFile(urlPathToPair4)
+
+  pair4 should "have prediction 0" in {
+    val pred = trainedSVMInstance.predict(Seq(rowForPair4))(0)
+    pred should be (0)
+  }
+
+  pair4 should "have sentenceDistance_min of 1" in {
+    val sentenceDistance_min = rowForPair4.featureGroups(rowForPair4.featureGroupNames.indexOf("sentenceDistance_min"))
+    sentenceDistance_min should be (1.0)
+  }
+
+  pair4 should "have dependencyDistance_min of 0" in {
+    val dependencyDistance_min = rowForPair4.featureGroups(rowForPair4.featureGroupNames.indexOf("dependencyDistance_min"))
+    dependencyDistance_min should be (0.0)
+  }
+
+  pair4 should "have context_frequency_min of 1" in {
+    val contextFrequency_min = rowForPair4.featureGroups(rowForPair4.featureGroupNames.indexOf("context_frequency_min"))
+    contextFrequency_min should be (1.0)
+  }
+
+  pair4 should "have closesCtxOfClass_min of 1" in {
+    val closesCtxOfClass_min = rowForPair4.featureGroups(rowForPair4.featureGroupNames.indexOf("closesCtxOfClass_min"))
+    closesCtxOfClass_min should be (1.0)
+  }
+
+  def readFileNameFromResource(resourcePath: String):String = {
+    val url = getClass.getResource(resourcePath)
+    // strip the "file:" prefix from the resource URL to obtain a plain filesystem path
+    url.toString.replace("file:","")
+  }
+
+  def readAggRowFromFile(fileName: String):AggregatedContextInstance = {
+    val is = new ObjectInputStream(new FileInputStream(fileName))
+    val c = is.readObject().asInstanceOf[AggregatedContextInstance]
+    is.close()
+    c
+  }
+
+}
diff --git a/main/src/test/scala/org/clulab/reach/context/TestSVMTrainingScript.scala b/main/src/test/scala/org/clulab/reach/context/TestSVMTrainingScript.scala
new file mode 100644
index 000000000..1b7b1cea9
--- /dev/null
+++ b/main/src/test/scala/org/clulab/reach/context/TestSVMTrainingScript.scala
@@ -0,0 +1,37 @@
+package org.clulab.reach.context
+
+import java.io.File
+
+import org.clulab.reach.context.svm_scripts.TrainSVMContextClassifier
+import org.scalatest.{FlatSpec, Matchers}
+
+class TestSVMTrainingScript extends FlatSpec with Matchers {
+
+  "SVM training script" should "create a .dat file to save the trained SVM instance to" in {
+    val resourcesPath = "/inputs/aggregated-context-instance"
+    val resourcePathToDataFrame = s"${resourcesPath}/grouped_features.csv.gz"
+    val urlPathToDataframe = readFileNameFromResource(resourcePathToDataFrame)
+    val resourcePathToSpecificFeatures = s"${resourcesPath}/specific_nondependency_featurenames.txt"
+    val urlPathToSpecificFeaturenames = readFileNameFromResource(resourcePathToSpecificFeatures)
+    // build the output path for the trained model next to the other test resources
+    val pathToSVMModelToTest = urlPathToSpecificFeaturenames.replace("specific_nondependency_featurenames.txt","svm_model_temp.dat")
+    new TrainSVMContextClassifier(urlPathToDataframe, pathToSVMModelToTest, urlPathToSpecificFeaturenames)
+    val svmModelFile = new File(pathToSVMModelToTest)
+    svmModelFile.exists() should be (true)
+    svmModelFile.deleteOnExit()
+  }
+
+  def readFileNameFromResource(resourcePath: String):String = {
+    val url = getClass.getResource(resourcePath)
+    url.toString.replace("file:","")
+  }
+
+}