Something came up at work recently that sparked off my interest in this stuff again, and it also meshes nicely with my objective of working through my TMAP book, so this week, I decided to explore building a binary Naive Bayes Classifier with Lucene. This post is a result of that exploration.
The math behind Naive Bayesian Classifiers is explained in detail here, but I needed to work backward through it to figure out what values to collect and use for training the classifier and to do the actual classification. So...
For classification purposes, document D is in category C if:

    r = p(C|D) / p(¬C|D) > 1

    where:
      r       = likelihood ratio
      p(C|D)  = probability that the category is C given document D
      p(¬C|D) = probability that the category is not C given document D

We can rewrite p(C|D) as follows:

    p(C|D) = p(D ∩ C) / p(D)          ... by definition of conditional probability
           = p(C) * p(D|C) / p(D)     ... by Bayes' theorem

Similarly, we can rewrite p(¬C|D):

    p(¬C|D) = p(¬C) * p(D|¬C) / p(D)

Here p(C) and p(¬C) can be expressed in terms of the number of documents in categories C and ¬C divided by the total number of documents in the training set:

    p(C)  = μ(C) / n
    p(¬C) = μ(¬C) / n

    where:
      μ(C)  = number of docs in category C in the training set
      μ(¬C) = number of docs not in category C in the training set
      n     = total number of docs in the training set

If the words in the documents in the training set are represented by the set {w0, w1, ..., wk}, then, by the naive independence assumption:

    p(D|C)  = Πi=0..k p(wi|C)
    p(D|¬C) = Πi=0..k p(wi|¬C)

Since p(D) appears in both the numerator and the denominator, it cancels, so our likelihood ratio can be rewritten as:

    r = (μ(C) / μ(¬C)) * (Πi=0..k p(wi|C)) / (Πi=0..k p(wi|¬C))
So, during our training phase, we need to compute the ratio of the number of documents in C to the number in ¬C, and for each individual word, we need to find the probability of the word in category C and in ¬C. The probability of an individual word can be found by dividing the frequency of the word in each category by the total number of words in the training set.
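As a quick sanity check on the formula, here is a toy example with made-up numbers (not taken from the actual training data). Suppose the training set has 20 documents, 5 in category C and 15 not in C, and the incoming document contains only two words w1 and w2 that appear in the training set, with:

    p(w1|C) = 0.04, p(w1|¬C) = 0.01
    p(w2|C) = 0.02, p(w2|¬C) = 0.02

Then:

    r = (5 / 15) * (0.04 / 0.01) * (0.02 / 0.02) ≈ 0.33 * 4 * 1 ≈ 1.33

Since r > 1, the document would be assigned to category C.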
My classifier uses Lucene, so it expects an index of pre-classified documents as its input, as well as the field name that contains the classification information. Because it is a binary classifier, it also needs to know the class that should be treated as positive (i.e., C). Here is the code for the classifier; the training portion is shown first.
// Source: src/main/java/com/mycompany/myapp/classifiers/LuceneNaiveBayesClassifier.java
package com.mycompany.myapp.classifiers;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Index;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermDocs;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.IndexWriter.MaxFieldLength;
import org.apache.lucene.search.CachingWrapperFilter;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.QueryWrapperFilter;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.RAMDirectory;

import com.mycompany.myapp.summarizers.SummaryAnalyzer;

public class LuceneNaiveBayesClassifier {

  private final Log log = LogFactory.getLog(getClass());

  private String indexDir;
  private String categoryFieldName;
  private String matchCategoryValue;
  private boolean selectTopFeatures = false;
  private boolean preventOverfitting = false;
  private Analyzer analyzer = new StandardAnalyzer();

  private static final double ALMOST_ZERO_PROBABILITY = 0.00001D;

  private Map<String,double[]> trainingSet;
  private double categoryDocRatio;

  public void setIndexDir(String indexDir) {
    this.indexDir = indexDir;
  }

  public void setCategoryFieldName(String categoryFieldName) {
    this.categoryFieldName = categoryFieldName;
  }

  public void setMatchCategoryValue(String matchCategoryValue) {
    this.matchCategoryValue = matchCategoryValue;
  }

  public void setSelectTopFeatures(boolean selectTopFeatures) {
    this.selectTopFeatures = selectTopFeatures;
  }

  public void setPreventOverfitting(boolean preventOverfitting) {
    this.preventOverfitting = preventOverfitting;
  }

  public void setAnalyzer(Analyzer analyzer) {
    this.analyzer = analyzer;
  }

  /**
   * Creates an array of terms and their positive and negative probabilities
   * and the ratio of documents in a certain category. Expects a Lucene
   * index created with the tokenized document bodies, and a category
   * field that is specified in the setters and populated with the specified
   * category value.
   * @throws Exception if one is thrown.
   */
  public void train() throws Exception {
    this.trainingSet = new HashMap<String,double[]>();
    IndexReader reader = null;
    try {
      reader = IndexReader.open(FSDirectory.getDirectory(indexDir));
      Set<Integer> matchedDocIds = computeMatchedDocIds(reader);
      double matchedDocs = (double) matchedDocIds.size();
      double nDocs = (double) reader.numDocs();
      this.categoryDocRatio = matchedDocs / (nDocs - matchedDocs);
      TermEnum termEnum = reader.terms();
      double nWords = 0.0D;
      double nUniqueWords = 0.0D;
      while (termEnum.next()) {
        double nWordInCategory = 0.0D;
        double nWordNotInCategory = 0.0D;
        Term term = termEnum.term();
        TermDocs termDocs = reader.termDocs(term);
        while (termDocs.next()) {
          int docId = termDocs.doc();
          int frequency = termDocs.freq();
          if (matchedDocIds.contains(docId)) {
            nWordInCategory += frequency;
          } else {
            nWordNotInCategory += frequency;
          }
          nWords += frequency;
          nUniqueWords++;
        }
        double[] pWord = new double[2];
        if (trainingSet.containsKey(term.text())) {
          pWord = trainingSet.get(term.text());
        }
        pWord[0] += (double) nWordInCategory;
        pWord[1] += (double) nWordNotInCategory;
        trainingSet.put(term.text(), pWord);
      }
      // once we have gone through all our terms, we normalize our
      // trainingSet so the values are probabilities, not numbers
      for (String term : trainingSet.keySet()) {
        double[] pWord = trainingSet.get(term);
        for (int i = 0; i < pWord.length; i++) {
          if (preventOverfitting) {
            // apply smoothening formula
            pWord[i] = ((pWord[i] + 1) / (nWords + nUniqueWords));
          } else {
            pWord[i] /= nWords;
          }
        }
      }
      if (selectTopFeatures) {
        InfoGainFeatureSelector featureSelector = new InfoGainFeatureSelector();
        featureSelector.setWordProbabilities(trainingSet);
        featureSelector.setPCategory(matchedDocs / nDocs);
        Map<String,double[]> topFeatures = featureSelector.selectFeatures();
        this.trainingSet = topFeatures;
      }
    } finally {
      if (reader != null) {
        reader.close();
      }
    }
  }

  public Map<String,double[]> getTrainingSet() {
    return trainingSet;
  }

  public double getCategoryDocRatio() {
    return this.categoryDocRatio;
  }

  ...

  private Set<Integer> computeMatchedDocIds(IndexReader reader)
      throws IOException {
    Filter categoryFilter = new CachingWrapperFilter(
      new QueryWrapperFilter(new TermQuery(
      new Term(categoryFieldName, matchCategoryValue))));
    DocIdSet docIdSet = categoryFilter.getDocIdSet(reader);
    DocIdSetIterator docIdSetIterator = docIdSet.iterator();
    Set<Integer> matchedDocIds = new HashSet<Integer>();
    while (docIdSetIterator.next()) {
      matchedDocIds.add(docIdSetIterator.doc());
    }
    return matchedDocIds;
  }
}
The caller sets the index directory, the category field name and value, and the other parameters (which are explained in more detail later), then calls the train() method. Once the training is complete, the word probabilities and the category document ratio are available via getters. This models how one would normally do this in real life: the training phase is typically time- and resource-intensive, but once it is done, its output can be reused for any number of classification calls.
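To make the calling sequence concrete, here is a minimal caller sketch. The index path and the sample text are placeholders; the setters, train() and the getters are the ones defined above, and classify() is shown further below.

// minimal caller sketch -- the index path and sample text are placeholders
LuceneNaiveBayesClassifier classifier = new LuceneNaiveBayesClassifier();
classifier.setIndexDir("/path/to/preclassified/index");
classifier.setCategoryFieldName("category");
classifier.setMatchCategoryValue("cocoa");
classifier.setAnalyzer(new SummaryAnalyzer());
classifier.train();

// train once, then reuse the computed values for any number of classifications
Map<String,double[]> wordProbabilities = classifier.getTrainingSet();
double categoryDocRatio = classifier.getCategoryDocRatio();
boolean isCocoa = classifier.classify(wordProbabilities, categoryDocRatio,
  "Cocoa prices rose sharply ...");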
In the real world, though, we would probably be dealing with far larger volumes of pre-classified documents, so using an in-memory Map<String,double[]> probably wouldn't scale too well. In that case, we may want to think about using a database table to store these values. I tried that in an earlier attempt using the Classifier4J Naive Bayes classifier, so you may want to check that post out if you are interested in going that route.
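Purely as an illustration of that idea (this is not part of the classifier above), the database store could be as simple as one row per term. The table name, columns and DataSource wiring below are hypothetical:

// Hypothetical schema:
//   CREATE TABLE word_probabilities (
//     term VARCHAR(64) PRIMARY KEY,
//     p_in_cat DOUBLE,
//     p_not_in_cat DOUBLE
//   );
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.Map;
import javax.sql.DataSource;

public class JdbcWordProbabilityStore {

  private final DataSource dataSource;

  public JdbcWordProbabilityStore(DataSource dataSource) {
    this.dataSource = dataSource;
  }

  /** Persists the map produced by train() instead of holding it in memory. */
  public void save(Map<String,double[]> trainingSet) throws Exception {
    Connection conn = dataSource.getConnection();
    try {
      PreparedStatement ps = conn.prepareStatement(
        "insert into word_probabilities (term, p_in_cat, p_not_in_cat) " +
        "values (?, ?, ?)");
      for (Map.Entry<String,double[]> entry : trainingSet.entrySet()) {
        ps.setString(1, entry.getKey());
        ps.setDouble(2, entry.getValue()[0]);
        ps.setDouble(3, entry.getValue()[1]);
        ps.addBatch();
      }
      ps.executeBatch();
      ps.close();
    } finally {
      conn.close();
    }
  }
}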
In the classification phase, we feed in a body of text, and for each word that is found in the training set, multiply its p(w|C)/p(w|¬C) value into the likelihood ratio. Notice that the frequency of the word in the input document does not matter, since it gets cancelled out when calculating the ratio. Finally, we multiply in the category document ratio to get the likelihood. Here is the classify() method - it's part of the same class as above; I have separated it out for purposes of explanation.
// Source: src/main/java/com/mycompany/myapp/classifiers/LuceneNaiveBayesClassifier.java
// (cont'd from above)
...
public class LuceneNaiveBayesClassifier {
  ...
  public boolean classify(Map<String,double[]> wordProbabilities,
      double categoryDocRatio, String text) throws Exception {
    RAMDirectory ramdir = new RAMDirectory();
    IndexWriter writer = null;
    IndexReader reader = null;
    try {
      writer = new IndexWriter(ramdir, analyzer, MaxFieldLength.UNLIMITED);
      Document doc = new Document();
      doc.add(new Field("text", text, Store.NO, Index.ANALYZED));
      writer.addDocument(doc);
      writer.commit();
      writer.close();
      double likelihoodRatio = categoryDocRatio;
      reader = IndexReader.open(ramdir);
      TermEnum termEnum = reader.terms();
      while (termEnum.next()) {
        Term term = termEnum.term();
        TermDocs termDocs = reader.termDocs(term);
        String word = term.text();
        if (trainingSet.containsKey(word)) {
          // we don't care about the frequency since they cancel out
          // when computing p(w|C) and p(w|-C) for the same number of w
          double[] probabilities = trainingSet.get(word);
          if (probabilities[1] == 0.0D) {
            // this means that the word is a very good discriminator word,
            // we put in an artificially low value instead of 0 (preventing
            // a divide by 0) and keeping the term
            likelihoodRatio *= (probabilities[0] / ALMOST_ZERO_PROBABILITY);
          } else {
            likelihoodRatio *= (probabilities[0] / probabilities[1]);
          }
        }
      }
      return (likelihoodRatio > 1.0D);
    } finally {
      if (writer != null && IndexWriter.isLocked(ramdir)) {
        IndexWriter.unlock(ramdir);
        writer.rollback();
        writer.close();
      }
      if (reader != null) {
        reader.close();
      }
    }
  }
}
New TokenFilter to remove numbers
The training data I used (see below for more information about the data) contained a lot of numbers which did not seem to have much discriminatory value, so I decided to filter them out of the words the classifier was being trained with. It was also a good excuse to learn how to build a Lucene TokenFilter :-), which is shown below. As before, Marcus Tripp's blog post was very helpful.
// Source: src/main/java/com/mycompany/myapp/tokenizers/lucene/NumericTokenFilter.java
package com.mycompany.myapp.tokenizers.lucene;

import java.io.IOException;

import org.apache.commons.lang.math.NumberUtils;
import org.apache.lucene.analysis.Token;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;

/**
 * Filters out numeric tokens from the TokenStream.
 */
public class NumericTokenFilter extends TokenFilter {

  public NumericTokenFilter(TokenStream input) {
    super(input);
  }

  @Override
  public Token next(Token token) throws IOException {
    while ((token = input.next(token)) != null) {
      String term = token.term();
      term = term.replaceAll(",", "");
      if (! NumberUtils.isNumber(term)) {
        return token;
      }
    }
    return null;
  }
}
I should probably have built a different analyzer for this purpose, but I figured that numbers are a good thing to remove in most cases anyway, so I just stuck it into the SummaryAnalyzer I described in my post last week. The tokenStream() method is the only one that changed, so I show it here:
// Source: src/main/java/com/mycompany/myapp/summarizers/SummaryAnalyzer.java
// (changed since last week, only changed method shown)
...
public class SummaryAnalyzer extends Analyzer {
  ...
  @Override
  public TokenStream tokenStream(String fieldName, Reader reader) {
    return new PorterStemFilter(
      new StopFilter(
        new LowerCaseFilter(
          new NumericTokenFilter(
            new StandardFilter(
              new StandardTokenizer(reader)))), stopset));
  }
  ...
}
You can hook up other Analyzer implementations by calling the setAnalyzer() method of LuceneNaiveBayesClassifier. The default (if not specified) is StandardAnalyzer.
Feature Selection with Information Gain
If you look at the results, you will see that in the first test, we used all the words we found (after tokenization) for our probability calculations. This can sometimes lead to misleading results, since there may be a large number of words with little discriminatory power (i.e., words which occur with similar frequency in both categories). To remedy this, we compute the information gain for each word, and consider only the √k words with the highest information gain for classification, where k is the total number of unique words (for example, if there were 2,500 unique words, only the 50 with the highest information gain would be kept).
Information gain is computed using the following formula:

    I(wi) = p(wi|C) * log( p(wi|C) / (p(wi) * p(C)) )
We can enable feature selection in our LuceneNaiveBayesClassifier by setting the selectTopFeatures property to true. This will invoke the InfoGainFeatureSelector shown below:
// Source: src/main/java/com/mycompany/myapp/classifiers/InfoGainFeatureSelector.java
package com.mycompany.myapp.classifiers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class InfoGainFeatureSelector {

  private double pCategory;
  private List<Word> wordProbabilities;

  public void setPCategory(double pCategory) {
    this.pCategory = pCategory;
  }

  public void setWordProbabilities(
      Map<String,double[]> wordProbabilities) {
    this.wordProbabilities = new ArrayList<Word>();
    for (String word : wordProbabilities.keySet()) {
      double[] probabilities = wordProbabilities.get(word);
      this.wordProbabilities.add(
        new Word(word, probabilities[0], probabilities[1]));
    }
  }

  public Map<String,double[]> selectFeatures() throws Exception {
    for (Word word : wordProbabilities) {
      if (word.pInCat > 0.0D) {
        word.infoGain = word.pInCat * Math.log(
          word.pInCat / ((word.pInCat + word.pNotInCat) * pCategory));
      } else {
        word.infoGain = 0.0D;
      }
    }
    Collections.sort(wordProbabilities);
    List<Word> topFeaturesList = wordProbabilities.subList(
      0, (int) Math.round(Math.sqrt(wordProbabilities.size())));
    Map<String,double[]> topFeatures = new HashMap<String,double[]>();
    for (Word topFeature : topFeaturesList) {
      topFeatures.put(topFeature.term,
        new double[] {topFeature.pInCat, topFeature.pNotInCat});
    }
    return topFeatures;
  }

  private class Word implements Comparable<Word> {

    private String term;
    private double pInCat;
    private double pNotInCat;
    public double infoGain;

    public Word(String term, double pInCat, double pNotInCat) {
      this.term = term;
      this.pInCat = pInCat;
      this.pNotInCat = pNotInCat;
    }

    public int compareTo(Word o) {
      if (infoGain == o.infoGain) {
        return 0;
      } else {
        return infoGain > o.infoGain ? -1 : 1;
      }
    }

    public String toString() {
      return term + "(" + pInCat + "," + pNotInCat + ")=" + infoGain;
    }
  }
}
Prevent Overfitting
When our training data is small, we can smooth the word probability distribution by adding constants to the numerator and denominator when calculating the word probabilities. Specifically, we add 1 to the numerator and k to the denominator, where k is the number of unique words in our training set. Thus:
    p(wi|C)  = (μ(wi|C) + 1) / (n + k)
    p(wi|¬C) = (μ(wi|¬C) + 1) / (n + k)

    where:
      n = number of words in the training set
      k = number of unique words in the training set
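As a toy illustration with made-up counts: if a word occurs 3 times in the category C documents, and the training set contains n = 1000 words of which k = 200 are unique, then p(wi|C) = (3 + 1) / (1000 + 200) ≈ 0.0033. A word that never occurs in C still gets the small non-zero probability 1 / 1200 instead of 0, so a single unseen word cannot zero out the product of probabilities.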
This can be turned on in the LuceneNaiveBayesClassifier by setting the preventOverfitting property to true. In my (admittedly limited) testing, I did not see any changes in results after doing this, however.
Test code and data
The training data I used is the set of 54 files from Reuters that is hardcoded in the cluster.pl file in the TextMine project (on which the TMAP book is based). It is split into three categories - "cocoa", "coffee" and "sugar". The information in these files seems to be primarily aimed at commodity market investors or people within the respective industries.
For classification, I chose some files, such as this file on cocoa and this file on coffee, from the Reuters site. I train the classifier with the cocoa and "not cocoa" (i.e., the coffee and sugar) documents, then try to classify some cocoa documents and some coffee documents.
Here is my JUnit test. As you can see, the test builds the index out of the input file before the tests run (@BeforeClass) and deletes it afterwards (@AfterClass). Four tests are run, each with a different combination of feature selection and overfitting prevention settings, and each test attempts to classify 5 documents (3 cocoa and 2 coffee).
// Source: src/test/java/com/mycompany/myapp/classifiers/ClassifierTest.java
package com.mycompany.myapp.classifiers;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Index;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriter.MaxFieldLength;
import org.apache.lucene.store.FSDirectory;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import com.mycompany.myapp.summarizers.SummaryAnalyzer;

public class ClassifierTest {

  private static final Log log = LogFactory.getLog(ClassifierTest.class);

  private static String INPUT_FILE =
    "src/test/resources/data/sugar-coffee-cocoa-docs.txt";
  private static String INDEX_DIR = "src/test/resources/data/scc-index";
  private static String[] DOCS_TO_CLASSIFY = new String[] {
    "src/test/resources/data/cocoa.txt",
    "src/test/resources/data/cocoa1.txt",
    "src/test/resources/data/cocoa2.txt",
    "src/test/resources/data/coffee.txt",
    "src/test/resources/data/coffee1.txt"
  };

  @BeforeClass
  public static void buildIndex() throws Exception {
    BufferedReader reader = new BufferedReader(new FileReader(INPUT_FILE));
    IndexWriter writer = new IndexWriter(
      FSDirectory.getDirectory(INDEX_DIR),
      new SummaryAnalyzer(), MaxFieldLength.UNLIMITED);
    String line = null;
    int lno = 0;
    StringBuilder bodybuf = new StringBuilder();
    String category = null;
    while ((line = reader.readLine()) != null) {
      if (line.endsWith(".sgm")) {
        // header line
        if (lno > 0) {
          // not the very first line, so dump current body buffer and
          // reinit the buffer.
          writeToIndex(writer, category, bodybuf.toString());
          bodybuf = new StringBuilder();
        }
        category = StringUtils.trim(StringUtils.split(line, ":")[1]);
        continue;
      } else {
        // not a header line, accumulate line into bodybuf
        bodybuf.append(line).append(" ");
      }
      lno++;
    }
    // last record
    writeToIndex(writer, category, bodybuf.toString());
    reader.close();
    writer.commit();
    writer.optimize();
    writer.close();
  }

  private static void writeToIndex(IndexWriter writer, String category,
      String body) throws Exception {
    Document doc = new Document();
    doc.add(new Field("category", category, Store.YES, Index.NOT_ANALYZED));
    doc.add(new Field("body", body, Store.NO, Index.ANALYZED));
    writer.addDocument(doc);
  }

  @AfterClass
  public static void deleteIndex() throws Exception {
    log.info("Deleting index directory...");
    FileUtils.deleteDirectory(new File(INDEX_DIR));
  }

  @Test
  public void testLuceneNaiveBayesClassifier() throws Exception {
    LuceneNaiveBayesClassifier classifier = train(false, false);
    categorize(classifier, DOCS_TO_CLASSIFY);
  }

  @Test
  public void testLuceneNaiveBayesClassifier2() throws Exception {
    LuceneNaiveBayesClassifier classifier = train(true, false);
    categorize(classifier, DOCS_TO_CLASSIFY);
  }

  @Test
  public void testLuceneNaiveBayesClassifier3() throws Exception {
    LuceneNaiveBayesClassifier classifier = train(true, true);
    categorize(classifier, DOCS_TO_CLASSIFY);
  }

  @Test
  public void testLuceneNaiveBayesClassifier4() throws Exception {
    LuceneNaiveBayesClassifier classifier = train(false, true);
    categorize(classifier, DOCS_TO_CLASSIFY);
  }

  private LuceneNaiveBayesClassifier train(boolean enableFeatureSelection,
      boolean preventOverfitting) throws Exception {
    System.out.println(">>> Training (featureSelection=" +
      enableFeatureSelection + ", preventOverfitting=" +
      preventOverfitting + ")");
    LuceneNaiveBayesClassifier classifier = new LuceneNaiveBayesClassifier();
    classifier.setIndexDir(INDEX_DIR);
    classifier.setCategoryFieldName("category");
    classifier.setMatchCategoryValue("cocoa");
    classifier.setSelectTopFeatures(enableFeatureSelection);
    classifier.setPreventOverfitting(preventOverfitting);
    classifier.setAnalyzer(new SummaryAnalyzer());
    classifier.train();
    return classifier;
  }

  private void categorize(LuceneNaiveBayesClassifier classifier,
      String[] testDocs) throws Exception {
    Map<String,double[]> trainingSet = classifier.getTrainingSet();
    double categoryDocRatio = classifier.getCategoryDocRatio();
    // classify new document
    for (String testDoc : testDocs) {
      File f = new File(testDoc);
      boolean isCocoa = classifier.classify(trainingSet, categoryDocRatio,
        FileUtils.readFileToString(f, "UTF-8"));
      System.out.println(">>> File: " + f.getName() +
        " in category:'cocoa'? " + isCocoa);
    }
  }
}
The results of the tests are shown below.
Training (featureSelection=false, preventOverfitting=false)
File: cocoa.txt in category:'cocoa'? true
File: cocoa1.txt in category:'cocoa'? false
File: cocoa2.txt in category:'cocoa'? false
File: coffee.txt in category:'cocoa'? false
File: coffee1.txt in category:'cocoa'? false

Training (featureSelection=true, preventOverfitting=false)
File: cocoa.txt in category:'cocoa'? true
File: cocoa1.txt in category:'cocoa'? true
File: cocoa2.txt in category:'cocoa'? true
File: coffee.txt in category:'cocoa'? false
File: coffee1.txt in category:'cocoa'? false

Training (featureSelection=true, preventOverfitting=true)
File: cocoa.txt in category:'cocoa'? true
File: cocoa1.txt in category:'cocoa'? true
File: cocoa2.txt in category:'cocoa'? true
File: coffee.txt in category:'cocoa'? false
File: coffee1.txt in category:'cocoa'? false

Training (featureSelection=false, preventOverfitting=true)
File: cocoa.txt in category:'cocoa'? true
File: cocoa1.txt in category:'cocoa'? false
File: cocoa2.txt in category:'cocoa'? false
File: coffee.txt in category:'cocoa'? false
File: coffee1.txt in category:'cocoa'? false
As you can see, at least for this data set, feature selection was necessary for it to classify all the test documents correctly (tests 2 and 3). Overfitting prevention did not seem to have any effect in these tests.