BayesSharp – Diff between revs 1 and 3

Subversion Repositories:
Rev:
Only display areas with differences · Ignore whitespace
Rev 1 Rev 3
1 using System; 1 using System;
2 using System.Collections.Generic; 2 using System.Collections.Generic;
3 using System.IO; 3 using System.IO;
4 using System.Linq; 4 using System.Linq;
5 using System.Text; 5 using System.Text;
6 using BayesSharp.Combiners; 6 using BayesSharp.Combiners;
7 using BayesSharp.Tokenizers; 7 using BayesSharp.Tokenizers;
8 using Newtonsoft.Json; 8 using Newtonsoft.Json;
9   9  
namespace BayesSharp
{
    /// <summary>
    /// Naive Bayes text classifier. Input strings are split into tokens by the
    /// supplied <see cref="ITokenizer{TTokenType}"/> and per-token probabilities
    /// are folded into a single per-tag score by the supplied <see cref="ICombiner"/>.
    /// </summary>
    /// <typeparam name="TTokenType">Type of the tokens produced by the tokenizer.</typeparam>
    /// <typeparam name="TTagType">Type of the tag (category) identifiers.</typeparam>
    public class BayesClassifier<TTokenType, TTagType> where TTagType : IComparable
    {
        private TagDictionary<TTokenType, TTagType> _tags = new TagDictionary<TTokenType, TTagType>();

        // Lazily rebuilt probability cache; see CreateCacheAndGetTags().
        private TagDictionary<TTokenType, TTagType> _cache;

        private readonly ITokenizer<TTokenType> _tokenizer;
        private readonly ICombiner _combiner;

        // Set whenever training data changes so the cache is rebuilt on next Classify().
        private bool _mustRecache;
        private const double Tolerance = 0.0001;
        private const double Threshold = 0.1;

        /// <summary>
        /// Create a classifier using the default <see cref="RobinsonCombiner"/>.
        /// </summary>
        /// <param name="tokenizer">Tokenizer used to split inputs into tokens</param>
        public BayesClassifier(ITokenizer<TTokenType> tokenizer)
            : this(tokenizer, new RobinsonCombiner())
        {
        }

        /// <summary>
        /// Create a classifier with an explicit probability combiner.
        /// </summary>
        /// <param name="tokenizer">Tokenizer used to split inputs into tokens</param>
        /// <param name="combiner">Strategy that folds token probabilities into one score</param>
        /// <exception cref="ArgumentNullException">Either argument is null</exception>
        public BayesClassifier(ITokenizer<TTokenType> tokenizer, ICombiner combiner)
        {
            if (tokenizer == null) throw new ArgumentNullException(nameof(tokenizer));
            if (combiner == null) throw new ArgumentNullException(nameof(combiner));

            _tokenizer = tokenizer;
            _combiner = combiner;

            _tags.SystemTag = new TagData<TTokenType>();
            _mustRecache = true;
        }

        /// <summary>
        /// Create a new tag, without actually doing any training.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void AddTag(TTagType tagId)
        {
            GetAndAddIfNotFound(_tags.Items, tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Remove a tag. Unknown tags are ignored.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void RemoveTag(TTagType tagId)
        {
            _tags.Items.Remove(tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Change the Id of a tag.
        /// </summary>
        /// <param name="oldTagId">Old Tag Id</param>
        /// <param name="newTagId">New Tag Id</param>
        public void ChangeTagId(TTagType oldTagId, TTagType newTagId)
        {
            // Guard: renaming a tag to itself would otherwise assign and then
            // remove the same entry, silently deleting the tag's data.
            if (oldTagId.CompareTo(newTagId) == 0)
            {
                return;
            }
            _tags.Items[newTagId] = _tags.Items[oldTagId];
            RemoveTag(oldTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Merge an existing tag into another, summing token counts, then remove the source tag.
        /// </summary>
        /// <param name="sourceTagId">Tag to be merged into destTagId and removed</param>
        /// <param name="destTagId">Destination tag Id</param>
        public void MergeTags(TTagType sourceTagId, TTagType destTagId)
        {
            var sourceTag = _tags.Items[sourceTagId];
            var destTag = _tags.Items[destTagId];
            foreach (var tokenItem in sourceTag.Items)
            {
                // BUG FIX: the previous implementation added a running loop
                // counter instead of the source token's own count, corrupting
                // the destination tag's statistics.
                if (destTag.Items.ContainsKey(tokenItem.Key))
                {
                    destTag.Items[tokenItem.Key] += tokenItem.Value;
                }
                else
                {
                    destTag.Items[tokenItem.Key] = tokenItem.Value;
                    destTag.TokenCount += 1;
                }
            }
            RemoveTag(sourceTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Return the TagData object for the given Tag Id, or null when unknown.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public TagData<TTokenType> GetTagById(TTagType tagId)
        {
            TagData<TTokenType> tag;
            return _tags.Items.TryGetValue(tagId, out tag) ? tag : null;
        }

        /// <summary>
        /// Save Bayes Text Classifier into a file (UTF-8 JSON).
        /// </summary>
        /// <param name="path">The file to write to</param>
        public void Save(string path)
        {
            using (var streamWriter = new StreamWriter(path, false, Encoding.UTF8))
            {
                JsonSerializer.Create().Serialize(streamWriter, _tags);
            }
        }

        /// <summary>
        /// Load Bayes Text Classifier from a file previously written by <see cref="Save"/>.
        /// </summary>
        /// <param name="path">The file to open for reading</param>
        public void Load(string path)
        {
            using (var streamReader = new StreamReader(path, Encoding.UTF8))
            {
                using (var jsonTextReader = new JsonTextReader(streamReader))
                {
                    _tags = JsonSerializer.Create().Deserialize<TagDictionary<TTokenType, TTagType>>(jsonTextReader);
                }
            }
            _mustRecache = true;
        }

        /// <summary>
        /// Import Bayes Text Classifier from a json string. A payload that
        /// deserializes to null resets the classifier to an empty state.
        /// </summary>
        /// <param name="json">The json content to be loaded</param>
        public void ImportJsonData(string json)
        {
            var result = JsonConvert.DeserializeObject<TagDictionary<TTokenType, TTagType>>(json);
            _tags = result ?? new TagDictionary<TTokenType, TTagType>();
            // BUG FIX: the cache must be invalidated on both paths; the previous
            // revision skipped this when the payload was null, so Classify()
            // could keep serving results from the replaced training data.
            _mustRecache = true;
        }

        /// <summary>
        /// Export Bayes Text Classifier to a json string.
        /// Returns an empty string when there is no training data.
        /// </summary>
        public string ExportJsonData()
        {
            return _tags?.Items != null && _tags.Items.Any()
                ? JsonConvert.SerializeObject(_tags)
                : string.Empty;
        }

        /// <summary>
        /// Return a sorted list of Tag Ids.
        /// </summary>
        public IEnumerable<TTagType> TagIds()
        {
            return _tags.Items.Keys.OrderBy(p => p);
        }

        /// <summary>
        /// Train Bayes by telling him that input belongs in tag.
        /// Creates the tag on first use.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be trained</param>
        public void Train(TTagType tagId, string input)
        {
            var tokens = _tokenizer.Tokenize(input);
            var tag = GetAndAddIfNotFound(_tags.Items, tagId);
            TrainTokens(tag, tokens);
            _tags.SystemTag.TrainCount += 1;
            tag.TrainCount += 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Untrain Bayes by telling him that input no more belongs in tag.
        /// Unknown tags are silently ignored.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be untrained</param>
        public void Untrain(TTagType tagId, string input)
        {
            var tokens = _tokenizer.Tokenize(input);
            // BUG FIX: the dictionary indexer throws KeyNotFoundException for an
            // unknown tag, which made the original "tag == null" check unreachable.
            // TryGetValue restores the evident ignore-unknown-tag intent.
            TagData<TTokenType> tag;
            if (!_tags.Items.TryGetValue(tagId, out tag) || tag == null)
            {
                return;
            }
            UntrainTokens(tag, tokens);
            // NOTE(review): both repository revisions INCREMENT TrainCount here;
            // untraining arguably ought to decrement it — confirm intent before changing.
            _tags.SystemTag.TrainCount += 1;
            tag.TrainCount += 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Returns the score in each tag for the provided input,
        /// ordered from most to least likely.
        /// </summary>
        /// <param name="input">Input to be classified</param>
        public Dictionary<TTagType, double> Classify(string input)
        {
            var tokens = _tokenizer.Tokenize(input).ToList();
            var tags = CreateCacheAndGetTags();

            var stats = new Dictionary<TTagType, double>();

            foreach (var tag in tags.Items)
            {
                var probs = GetProbabilities(tag.Value, tokens).ToList();
                if (probs.Count > 0)
                {
                    stats[tag.Key] = _combiner.Combine(probs);
                }
            }
            return stats.OrderByDescending(s => s.Value).ToDictionary(s => s.Key, s => s.Value);
        }

        #region Private Methods

        // Add each token's occurrence to the tag and to the global SystemTag totals.
        private void TrainTokens(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            var tokenCount = 0;
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                tag.Items[token] = count + 1;
                count = _tags.SystemTag.Get(token, 0);
                _tags.SystemTag.Items[token] = count + 1;
                tokenCount += 1;
            }
            tag.TokenCount += tokenCount;
            _tags.SystemTag.TokenCount += tokenCount;
        }

        // Reverse of TrainTokens: decrement counts, removing tokens that reach zero.
        private void UntrainTokens(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                if (count > 0)
                {
                    // A count of ~1 means this was the token's last occurrence.
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        tag.Items.Remove(token);
                    }
                    else
                    {
                        tag.Items[token] = count - 1;
                    }
                    tag.TokenCount -= 1;
                }
                count = _tags.SystemTag.Get(token, 0);
                if (count > 0)
                {
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        _tags.SystemTag.Items.Remove(token);
                    }
                    else
                    {
                        _tags.SystemTag.Items[token] = count - 1;
                    }
                    _tags.SystemTag.TokenCount -= 1;
                }
            }
        }

        // Fetch the TagData for key, creating and registering an empty one when missing.
        private static TagData<TTokenType> GetAndAddIfNotFound(IDictionary<TTagType, TagData<TTokenType>> dic, TTagType key)
        {
            TagData<TTokenType> data;
            if (!dic.TryGetValue(key, out data))
            {
                data = new TagData<TTokenType>();
                dic[key] = data;
            }
            return data;
        }

        // Rebuild (when stale) and return the per-tag token probability cache.
        // Probabilities near 0.5 (within Threshold) carry no signal and are omitted.
        private TagDictionary<TTokenType, TTagType> CreateCacheAndGetTags()
        {
            if (!_mustRecache) return _cache;

            _cache = new TagDictionary<TTokenType, TTagType> { SystemTag = _tags.SystemTag };
            foreach (var tag in _tags.Items)
            {
                var thisTagTokenCount = tag.Value.TokenCount;
                // Max(..., 1) avoids division by zero for the "all other tags" pool.
                var otherTagsTokenCount = Math.Max(_tags.SystemTag.TokenCount - thisTagTokenCount, 1);
                var cachedTag = GetAndAddIfNotFound(_cache.Items, tag.Key);

                foreach (var systemTagItem in _tags.SystemTag.Items)
                {
                    var thisTagTokenFreq = tag.Value.Get(systemTagItem.Key, 0.0);
                    if (Math.Abs(thisTagTokenFreq) < Tolerance)
                    {
                        continue;
                    }
                    var otherTagsTokenFreq = systemTagItem.Value - thisTagTokenFreq;

                    var goodMetric = thisTagTokenCount == 0 ? 1.0 : Math.Min(1.0, otherTagsTokenFreq / thisTagTokenCount);
                    var badMetric = Math.Min(1.0, thisTagTokenFreq / otherTagsTokenCount);
                    var f = badMetric / (goodMetric + badMetric);

                    if (Math.Abs(f - 0.5) >= Threshold)
                    {
                        // Clamp into (Tolerance, 1 - Tolerance) so the combiner never sees 0 or 1.
                        cachedTag.Items[systemTagItem.Key] = Math.Max(Tolerance, Math.Min(1 - Tolerance, f));
                    }
                }
            }
            _mustRecache = false;
            return _cache;
        }

        // Cached probabilities for the tokens present in the tag, strongest first,
        // capped at 2048 values to bound the combiner's work.
        private static IEnumerable<double> GetProbabilities(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            var probs = tokens.Where(tag.Items.ContainsKey).Select(t => tag.Items[t]);
            return probs.OrderByDescending(p => p).Take(2048);
        }

        #endregion Private Methods
    }
}
324   336  
325
Generated by GNU Enscript 1.6.5.90.
-  
326   -  
327   -  
328   -