BayesSharp – Blame information for rev 1

Subversion Repositories:
Rev:
Rev Author Line No. Line
1 office 1 using System;
2 using System.Collections.Generic;
3 using System.IO;
4 using System.Linq;
5 using System.Text;
6 using BayesSharp.Combiners;
7 using BayesSharp.Tokenizers;
8 using Newtonsoft.Json;
9  
10 namespace BayesSharp
11 {
/// <summary>
/// A naive-Bayes text classifier. Tokens extracted by the supplied
/// <see cref="ITokenizer{TTokenType}"/> are counted per tag and combined into
/// per-tag probabilities by the supplied <see cref="ICombiner"/>.
/// </summary>
/// <typeparam name="TTokenType">Type of the tokens produced by the tokenizer.</typeparam>
/// <typeparam name="TTagType">Type of the tag identifiers (must be comparable so tags can be sorted).</typeparam>
public class BayesClassifier<TTokenType, TTagType> where TTagType : IComparable
{
    // All trained data: per-tag token counts plus the aggregate SystemTag.
    private TagDictionary<TTokenType, TTagType> _tags = new TagDictionary<TTokenType, TTagType>();

    // Per-tag token probabilities, rebuilt lazily by CreateCacheAndGetTags.
    private TagDictionary<TTokenType, TTagType> _cache;

    private readonly ITokenizer<TTokenType> _tokenizer;
    private readonly ICombiner _combiner;

    // Set by every mutating operation; cleared once the cache is rebuilt.
    private bool _mustRecache;

    // Epsilon for floating-point count comparisons, and the clamp bounds for
    // cached probabilities (kept inside (Tolerance, 1 - Tolerance)).
    private const double Tolerance = 0.0001;

    // Tokens whose probability lies within Threshold of the neutral 0.5 carry
    // no useful signal and are excluded from the cache.
    private const double Threshold = 0.1;

    /// <summary>
    /// Create a classifier that combines probabilities with the default
    /// <see cref="RobinsonCombiner"/>.
    /// </summary>
    /// <param name="tokenizer">Tokenizer used to split input strings into tokens.</param>
    public BayesClassifier(ITokenizer<TTokenType> tokenizer)
        : this(tokenizer, new RobinsonCombiner())
    {
    }

    /// <summary>
    /// Create a classifier with an explicit tokenizer and combiner.
    /// </summary>
    /// <param name="tokenizer">Tokenizer used to split input strings into tokens.</param>
    /// <param name="combiner">Strategy that folds per-token probabilities into one score.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public BayesClassifier(ITokenizer<TTokenType> tokenizer, ICombiner combiner)
    {
        if (tokenizer == null) throw new ArgumentNullException(nameof(tokenizer));
        if (combiner == null) throw new ArgumentNullException(nameof(combiner));

        _tokenizer = tokenizer;
        _combiner = combiner;

        _tags.SystemTag = new TagData<TTokenType>();
        _mustRecache = true;
    }

    /// <summary>
    /// Create a new tag, without actually doing any training.
    /// </summary>
    /// <param name="tagId">Tag Id</param>
    public void AddTag(TTagType tagId)
    {
        GetAndAddIfNotFound(_tags.Items, tagId);
        _mustRecache = true;
    }

    /// <summary>
    /// Remove a tag. Removing an unknown tag is a no-op.
    /// </summary>
    /// <param name="tagId">Tag Id</param>
    public void RemoveTag(TTagType tagId)
    {
        _tags.Items.Remove(tagId);
        _mustRecache = true;
    }

    /// <summary>
    /// Change the Id of a tag.
    /// </summary>
    /// <param name="oldTagId">Old Tag Id</param>
    /// <param name="newTagId">New Tag Id</param>
    /// <exception cref="KeyNotFoundException"><paramref name="oldTagId"/> does not exist.</exception>
    public void ChangeTagId(TTagType oldTagId, TTagType newTagId)
    {
        // Guard: renaming a tag to itself must not destroy it (the original
        // code assigned and then removed the very same entry).
        if (EqualityComparer<TTagType>.Default.Equals(oldTagId, newTagId))
        {
            return;
        }
        _tags.Items[newTagId] = _tags.Items[oldTagId];
        RemoveTag(oldTagId);
        _mustRecache = true;
    }

    /// <summary>
    /// Merge an existing tag into another: all token counts of the source tag
    /// are added to the destination tag, then the source tag is removed.
    /// </summary>
    /// <param name="sourceTagId">Tag to be merged into destTagId and removed</param>
    /// <param name="destTagId">Destination tag Id</param>
    /// <exception cref="KeyNotFoundException">Either tag does not exist.</exception>
    public void MergeTags(TTagType sourceTagId, TTagType destTagId)
    {
        var sourceTag = _tags.Items[sourceTagId];
        var destTag = _tags.Items[destTagId];
        foreach (var token in sourceTag.Items)
        {
            // Bug fix: add the source tag's count for this token. The original
            // code added a running loop counter, so the merged counts depended
            // on enumeration order rather than on the trained data.
            if (destTag.Items.ContainsKey(token.Key))
            {
                destTag.Items[token.Key] += token.Value;
            }
            else
            {
                destTag.Items[token.Key] = token.Value;
                destTag.TokenCount += 1;
            }
        }
        RemoveTag(sourceTagId);
        _mustRecache = true;
    }

    /// <summary>
    /// Return the TagData object for the given Tag Id, or null if unknown.
    /// </summary>
    /// <param name="tagId">Tag Id</param>
    public TagData<TTokenType> GetTagById(TTagType tagId)
    {
        TagData<TTokenType> tag;
        return _tags.Items.TryGetValue(tagId, out tag) ? tag : null;
    }

    /// <summary>
    /// Save Bayes Text Classifier into a file (UTF-8 JSON, overwriting).
    /// </summary>
    /// <param name="path">The file to write to</param>
    public void Save(string path)
    {
        using (var streamWriter = new StreamWriter(path, false, Encoding.UTF8))
        {
            JsonSerializer.Create().Serialize(streamWriter, _tags);
        }
    }

    /// <summary>
    /// Load Bayes Text Classifier from a file previously written by <see cref="Save"/>.
    /// </summary>
    /// <param name="path">The file to open for reading</param>
    public void Load(string path)
    {
        using (var streamReader = new StreamReader(path, Encoding.UTF8))
        {
            using (var jsonTextReader = new JsonTextReader(streamReader))
            {
                _tags = JsonSerializer.Create().Deserialize<TagDictionary<TTokenType, TTagType>>(jsonTextReader);
            }
        }
        _mustRecache = true;
    }

    /// <summary>
    /// Import Bayes Text Classifier from a json string.
    /// </summary>
    /// <param name="json">The json content to be loaded</param>
    public void ImportJsonData(string json)
    {
        _tags = JsonConvert.DeserializeObject<TagDictionary<TTokenType, TTagType>>(json);
        _mustRecache = true;
    }

    /// <summary>
    /// Export Bayes Text Classifier to a json string.
    /// </summary>
    public string ExportJsonData()
    {
        return JsonConvert.SerializeObject(_tags);
    }

    /// <summary>
    /// Return a sorted list of Tag Ids.
    /// </summary>
    public IEnumerable<TTagType> TagIds()
    {
        return _tags.Items.Keys.OrderBy(p => p);
    }

    /// <summary>
    /// Train Bayes by telling him that input belongs in tag.
    /// Creates the tag on first use.
    /// </summary>
    /// <param name="tagId">Tag Id</param>
    /// <param name="input">Input to be trained</param>
    public void Train(TTagType tagId, string input)
    {
        var tokens = _tokenizer.Tokenize(input);
        var tag = GetAndAddIfNotFound(_tags.Items, tagId);
        _train(tag, tokens);
        _tags.SystemTag.TrainCount += 1;
        tag.TrainCount += 1;
        _mustRecache = true;
    }

    /// <summary>
    /// Untrain Bayes by telling him that input no more belongs in tag.
    /// Untraining an unknown tag is a no-op.
    /// </summary>
    /// <param name="tagId">Tag Id</param>
    /// <param name="input">Input to be untrained</param>
    public void Untrain(TTagType tagId, string input)
    {
        // Bug fix: the dictionary indexer throws on a missing key, so the
        // original null check could never fire. TryGetValue restores the
        // intended "silently ignore unknown tag" behavior.
        TagData<TTokenType> tag;
        if (!_tags.Items.TryGetValue(tagId, out tag))
        {
            return;
        }
        var tokens = _tokenizer.Tokenize(input);
        _untrain(tag, tokens);
        // Bug fix: untraining reverses a training pass, so the train counters
        // must be decremented (the original incremented them).
        _tags.SystemTag.TrainCount -= 1;
        tag.TrainCount -= 1;
        _mustRecache = true;
    }

    /// <summary>
    /// Returns the scores in each tag for the provided input, ordered from
    /// most to least probable. Tags with no matching tokens are omitted.
    /// </summary>
    /// <param name="input">Input to be classified</param>
    public Dictionary<TTagType, double> Classify(string input)
    {
        var tokens = _tokenizer.Tokenize(input).ToList();
        var tags = CreateCacheAndGetTags();

        var stats = new Dictionary<TTagType, double>();

        foreach (var tag in tags.Items)
        {
            var probs = GetProbabilities(tag.Value, tokens).ToList();
            if (probs.Count > 0)
            {
                stats[tag.Key] = _combiner.Combine(probs);
            }
        }
        return stats.OrderByDescending(s => s.Value).ToDictionary(s => s.Key, pair => pair.Value);
    }

    #region Private Methods

    // Add one occurrence of every token to both the tag and the aggregate
    // SystemTag, keeping the per-tag and global token totals in sync.
    private void _train(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
    {
        var tokenCount = 0;
        foreach (var token in tokens)
        {
            tag.Items[token] = tag.Get(token, 0) + 1;
            _tags.SystemTag.Items[token] = _tags.SystemTag.Get(token, 0) + 1;
            tokenCount += 1;
        }
        tag.TokenCount += tokenCount;
        _tags.SystemTag.TokenCount += tokenCount;
    }

    // Remove one occurrence of every token from both the tag and the aggregate
    // SystemTag; a token whose count reaches zero is deleted entirely.
    private void _untrain(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
    {
        foreach (var token in tokens)
        {
            var count = tag.Get(token, 0);
            if (count > 0)
            {
                // Counts are doubles, so "equals one" is a tolerance check.
                if (Math.Abs(count - 1) < Tolerance)
                {
                    tag.Items.Remove(token);
                }
                else
                {
                    tag.Items[token] = count - 1;
                }
                tag.TokenCount -= 1;
            }
            count = _tags.SystemTag.Get(token, 0);
            if (count > 0)
            {
                if (Math.Abs(count - 1) < Tolerance)
                {
                    _tags.SystemTag.Items.Remove(token);
                }
                else
                {
                    _tags.SystemTag.Items[token] = count - 1;
                }
                _tags.SystemTag.TokenCount -= 1;
            }
        }
    }

    // Fetch the TagData for key, creating and registering an empty one on
    // first use (single dictionary lookup on the hit path).
    private static TagData<TTokenType> GetAndAddIfNotFound(IDictionary<TTagType, TagData<TTokenType>> dic, TTagType key)
    {
        TagData<TTokenType> data;
        if (!dic.TryGetValue(key, out data))
        {
            data = new TagData<TTokenType>();
            dic[key] = data;
        }
        return data;
    }

    // Rebuild (if dirty) and return the probability cache: for every tag and
    // every token, a Robinson-style probability that the token indicates the
    // tag, clamped to (Tolerance, 1 - Tolerance). Near-neutral tokens
    // (within Threshold of 0.5) are skipped.
    private TagDictionary<TTokenType, TTagType> CreateCacheAndGetTags()
    {
        if (!_mustRecache) return _cache;

        _cache = new TagDictionary<TTokenType, TTagType> { SystemTag = _tags.SystemTag };
        foreach (var tag in _tags.Items)
        {
            var thisTagTokenCount = tag.Value.TokenCount;
            // Guard against division by zero when this tag holds all tokens.
            var otherTagsTokenCount = Math.Max(_tags.SystemTag.TokenCount - thisTagTokenCount, 1);
            var cachedTag = GetAndAddIfNotFound(_cache.Items, tag.Key);

            foreach (var systemTagItem in _tags.SystemTag.Items)
            {
                var thisTagTokenFreq = tag.Value.Get(systemTagItem.Key, 0.0);
                // Token never seen in this tag: contributes nothing.
                if (Math.Abs(thisTagTokenFreq) < Tolerance)
                {
                    continue;
                }
                var otherTagsTokenFreq = systemTagItem.Value - thisTagTokenFreq;

                var goodMetric = thisTagTokenCount == 0 ? 1.0 : Math.Min(1.0, otherTagsTokenFreq / thisTagTokenCount);
                var badMetric = Math.Min(1.0, thisTagTokenFreq / otherTagsTokenCount);
                var f = badMetric / (goodMetric + badMetric);

                // Only keep tokens that deviate meaningfully from neutral.
                if (Math.Abs(f - 0.5) >= Threshold)
                {
                    cachedTag.Items[systemTagItem.Key] = Math.Max(Tolerance, Math.Min(1 - Tolerance, f));
                }
            }
        }
        _mustRecache = false;
        return _cache;
    }

    // Look up the cached probability of each known token, strongest evidence
    // first, capped at 2048 values to bound the combiner's work.
    private static IEnumerable<double> GetProbabilities(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
    {
        var probs = tokens.Where(tag.Items.ContainsKey).Select(t => tag.Items[t]);
        return probs.OrderByDescending(p => p).Take(2048);
    }

    #endregion

}
323 }