// Step -1: make Saul available to this kernel. The repository has to be
// registered before the dependency is added, so keep these two calls in order.
// Registers the CogComp Maven repository as an additional place to resolve artifacts from.
classpath.addRepository("http://cogcomp.cs.illinois.edu/m2repo")
// Adds the Saul artifact (group "edu.illinois.cs.cogcomp", version 0.5.5) to the notebook classpath.
classpath.add("edu.illinois.cs.cogcomp" %% "saul" % "0.5.5") 
import scala.io.Source
import java.io.File

// Root of the e-mail corpus, relative to this notebook. The trailing slash is
// relied upon by the string concatenations below.
val spamDataBasePath = "../../data/EmailSpam/"
val trainDataPath = spamDataBasePath + "train/"
val testDataPath = spamDataBasePath + "test/"

// Fail fast, and name the offending path, when the corpus has not been checked
// out next to the notebook.
val dir = new File(trainDataPath + "ham")
require(dir.exists() && dir.isDirectory(), s"Training data directory not found: ${dir.getPath}")

// Show what a raw document looks like: print the first few lines of one ham
// e-mail. Files are sorted by name so the same document is shown on every run,
// and the handle is closed even if printing fails.
val sampleDoc = Option(dir.listFiles).getOrElse(Array.empty[File])
  .filter(_.isFile)
  .sortBy(_.getName)
  .headOption

sampleDoc.foreach { file =>
  println("Sample Document:")
  val source = Source.fromFile(file)
  try source.getLines.take(10).foreach(println(_))
  finally source.close()
}
/** A single e-mail document.
  *
  * @param words whitespace-delimited tokens of the message, in document order
  * @param label class label supplied by whoever loaded the document
  */
case class Email(words: Seq[String], label: String)

/** Loads the regular files of a directory as labelled [[Email]]s. */
object DataReader {

  /** Reads every non-empty regular file directly inside `dirName`.
    *
    * Sub-directories are ignored and files with no content are skipped. Files
    * are visited in file-name order so that repeated runs see the documents in
    * the same sequence (`File.listFiles` itself promises no ordering).
    *
    * @param dirName path of the directory holding one e-mail per file
    * @param label   label attached to every e-mail read from that directory
    * @return the parsed e-mails
    * @throws IllegalArgumentException if `dirName` is not an existing directory
    */
  def apply(dirName: String, label: String): Iterable[Email] = {
    val dir = new File(dirName)
    require(dir.exists() && dir.isDirectory, s"Not a directory: $dirName")

    // listFiles returns null (not an empty array) when the directory cannot be read.
    Option(dir.listFiles).getOrElse(Array.empty[File])
      .filter(_.isFile)
      .sortBy(_.getName)
      .flatMap(file => parseEmail(file.getAbsolutePath, label))
      .toList
  }

  /** Tokenises one file on whitespace; returns `None` for a file with no content. */
  private def parseEmail(fileName: String, label: String): Option[Email] = {
    val source = Source.fromFile(fileName)
    try {
      if (source.hasNext) {
        // Materialise eagerly: a lazy Stream would still be reading from the
        // file after it has been closed in the finally block.
        val words = source.getLines
          .flatMap(_.split("\\s+"))
          .toVector
        Some(Email(words, label))
      } else {
        None
      }
    } finally {
      // The original never closed the source, leaking one handle per document.
      source.close()
    }
  }
}
import edu.illinois.cs.cogcomp.saul.datamodel.DataModel

/** Saul data model for the spam task: one node type holding [[Email]]
  * instances, plus the properties computed from each e-mail.
  */
object SpamDataModel extends DataModel {
  val email = node[Email]

  // Features

  /** Every token of the e-mail, in document order. */
  val words = property(email) {
    doc: Email => doc.words.toList
  }

  /** Adjacent token pairs, each joined with "-". */
  val bigrams = property(email) {
    doc: Email =>
      // sliding(2) emits a single one-element window when the e-mail has just
      // one token; without the size filter that lone unigram would be reported
      // as a bigram.
      doc.words.sliding(2).filter(_.size == 2).map(_.mkString("-")).toList
  }

  /** The label the e-mail was loaded with; used as the prediction target. */
  val spamLabel = property(email) {
    doc: Email => doc.label
  }
}

import edu.illinois.cs.cogcomp.lbjava.learn.SupportVectorMachine
import edu.illinois.cs.cogcomp.saul.classifier.Learnable
import SpamDataModel._

/** Classifier over `email` nodes: predicts `spamLabel` with a support vector
  * machine, using the `words` and `bigrams` properties as features.
  */
object SpamClassifier extends Learnable(email) {
  def label = spamLabel
  override lazy val classifier = new SupportVectorMachine()
  override def feature = using(words, bigrams)
}
Recall F1 LCount PCount\n", + "-----------------------------------------------\n", + "ham 100.000 60.000 75.000 5 3\n", + "spam 71.429 100.000 83.333 5 7\n", + "-----------------------------------------------\n", + "Accuracy 80.000 - - - 10\n" + ] + }, + { + "data": { + "text/plain": [ + "\u001b[36mtrainData\u001b[0m: \u001b[32mIterable\u001b[0m[\u001b[32mEmail\u001b[0m] = \u001b[33mArraySeq\u001b[0m(\n", + " \u001b[33mEmail\u001b[0m(\n", + " \u001b[33mStream\u001b[0m(\n", + " \u001b[32m\"Subject:\"\u001b[0m,\n", + " \u001b[32m\"double\"\u001b[0m,\n", + " \u001b[32m\"your\"\u001b[0m,\n", + " \u001b[32m\"life\"\u001b[0m,\n", + " \u001b[32m\"insurance\"\u001b[0m,\n", + " \u001b[32m\"at\"\u001b[0m,\n", + " \u001b[32m\"no\"\u001b[0m,\n", + " \u001b[32m\"extra\"\u001b[0m,\n", + " \u001b[32m\"cost\"\u001b[0m,\n", + " \u001b[32m\"!\"\u001b[0m,\n", + " \u001b[32m\"29155\"\u001b[0m,\n", + " \u001b[32m\"the\"\u001b[0m,\n", + " \u001b[32m\"lowest\"\u001b[0m,\n", + " \u001b[32m\"life\"\u001b[0m,\n", + " \u001b[32m\"insurance\"\u001b[0m,\n", + " \u001b[32m\"quotes\"\u001b[0m,\n", + " \u001b[32m\"without\"\u001b[0m,\n", + "\u001b[33m...\u001b[0m\n", + "\u001b[36mtestData\u001b[0m: \u001b[32mIterable\u001b[0m[\u001b[32mEmail\u001b[0m] = \u001b[33mArraySeq\u001b[0m(\n", + " \u001b[33mEmail\u001b[0m(\n", + " \u001b[33mStream\u001b[0m(\n", + " \u001b[32m\"Subject:\"\u001b[0m,\n", + " \u001b[32m\"slotting\"\u001b[0m,\n", + " \u001b[32m\"order\"\u001b[0m,\n", + " \u001b[32m\"confirmation\"\u001b[0m,\n", + " \u001b[32m\"may\"\u001b[0m,\n", + " \u001b[32m\"18\"\u001b[0m,\n", + " \u001b[32m\",\"\u001b[0m,\n", + " \u001b[32m\"2004\"\u001b[0m,\n", + " \u001b[32m\"etacitne\"\u001b[0m,\n", + " \u001b[32m\"{\"\u001b[0m,\n", + " \u001b[32m\"%\"\u001b[0m,\n", + " \u001b[32m\"begin\"\u001b[0m,\n", + " \u001b[32m\"_\"\u001b[0m,\n", + " \u001b[32m\"split\"\u001b[0m,\n", + " \u001b[32m\"76\"\u001b[0m,\n", + " \u001b[32m\"%\"\u001b[0m,\n", + " \u001b[32m\"}\"\u001b[0m,\n", + 
// Number of training iterations passed to the learner.
val trainingIterations = 30

// Reads the "spam" sub-directory and then the "ham" sub-directory of `basePath`
// (which must end in a path separator), labelling each e-mail after the
// directory it came from, and concatenates the two collections in that order.
def loadLabelled(basePath: String): Iterable[Email] =
  Seq("spam", "ham")
    .map(label => DataReader(basePath + label, label))
    .reduce(_ ++ _)

val trainData: Iterable[Email] = loadLabelled(trainDataPath)
val testData: Iterable[Email] = loadLabelled(testDataPath)

// Training instances use populate's default; the test split is flagged explicitly.
SpamDataModel.email.populate(trainData)
SpamDataModel.email.populate(testData, train = false)

// Train, then evaluate; the evaluation result is the cell's value.
SpamClassifier.learn(trainingIterations)
SpamClassifier.test()