BayesSharp – Blame information for rev 3

Subversion Repositories:
Rev:
Rev Author Line No. Line
1 office 1 using System;
2 using System.Collections.Generic;
3 using System.IO;
4 using System.Linq;
5 using System.Text;
6 using BayesSharp.Combiners;
7 using BayesSharp.Tokenizers;
8 using Newtonsoft.Json;
9  
namespace BayesSharp
{
    /// <summary>
    /// Naive-Bayes classifier over token streams. Input strings are split by an
    /// <see cref="ITokenizer{TTokenType}"/>; per-token probabilities are folded into a
    /// per-tag score by an <see cref="ICombiner"/>. Training data is held in
    /// <c>_tags</c>; <c>_cache</c> holds precomputed token probabilities and is rebuilt
    /// lazily whenever <c>_mustRecache</c> is set.
    /// </summary>
    public class BayesClassifier<TTokenType, TTagType> where TTagType : IComparable
    {
        private TagDictionary<TTokenType, TTagType> _tags = new TagDictionary<TTokenType, TTagType>();
        private TagDictionary<TTokenType, TTagType> _cache;

        private readonly ITokenizer<TTokenType> _tokenizer;
        private readonly ICombiner _combiner;

        private bool _mustRecache;
        // Epsilon for "count is effectively zero/one" comparisons and probability clamping.
        private const double Tolerance = 0.0001;
        // Tokens whose probability is within Threshold of 0.5 carry no signal and are not cached.
        private const double Threshold = 0.1;

        /// <summary>
        /// Create a classifier using the default <see cref="RobinsonCombiner"/>.
        /// </summary>
        /// <param name="tokenizer">Tokenizer used to split input strings</param>
        public BayesClassifier(ITokenizer<TTokenType> tokenizer)
            : this(tokenizer, new RobinsonCombiner())
        {
        }

        /// <summary>
        /// Create a classifier with an explicit probability combiner.
        /// </summary>
        /// <param name="tokenizer">Tokenizer used to split input strings</param>
        /// <param name="combiner">Strategy that folds per-token probabilities into one score</param>
        /// <exception cref="ArgumentNullException">Either argument is null</exception>
        public BayesClassifier(ITokenizer<TTokenType> tokenizer, ICombiner combiner)
        {
            if (tokenizer == null) throw new ArgumentNullException(nameof(tokenizer));
            if (combiner == null) throw new ArgumentNullException(nameof(combiner));

            _tokenizer = tokenizer;
            _combiner = combiner;

            // SystemTag aggregates counts across all tags and backs the probability cache.
            _tags.SystemTag = new TagData<TTokenType>();
            _mustRecache = true;
        }

        /// <summary>
        /// Create a new tag, without actually doing any training.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void AddTag(TTagType tagId)
        {
            GetAndAddIfNotFound(_tags.Items, tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Remove a tag. Removing an unknown tag is a no-op.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void RemoveTag(TTagType tagId)
        {
            _tags.Items.Remove(tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Change the Id of a tag.
        /// </summary>
        /// <param name="oldTagId">Old Tag Id</param>
        /// <param name="newTagId">New Tag Id</param>
        /// <exception cref="KeyNotFoundException">oldTagId does not exist</exception>
        public void ChangeTagId(TTagType oldTagId, TTagType newTagId)
        {
            // BUGFIX: renaming a tag to its own id previously copied the entry onto
            // itself and then removed it, silently deleting the tag.
            if (oldTagId.CompareTo(newTagId) == 0)
            {
                return;
            }
            _tags.Items[newTagId] = _tags.Items[oldTagId];
            RemoveTag(oldTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Merge an existing tag into another.
        /// </summary>
        /// <param name="sourceTagId">Tag to merged to destTagId and removed</param>
        /// <param name="destTagId">Destination tag Id</param>
        /// <exception cref="KeyNotFoundException">Either tag does not exist</exception>
        public void MergeTags(TTagType sourceTagId, TTagType destTagId)
        {
            var sourceTag = _tags.Items[sourceTagId];
            var destTag = _tags.Items[destTagId];
            foreach (var token in sourceTag.Items)
            {
                // BUGFIX: the previous implementation added a running loop counter
                // instead of the token's own count, corrupting the destination tag.
                if (destTag.Items.ContainsKey(token.Key))
                {
                    destTag.Items[token.Key] += token.Value;
                }
                else
                {
                    destTag.Items[token.Key] = token.Value;
                }
                // Keep TokenCount consistent with _train, which counts occurrences.
                destTag.TokenCount += (int)token.Value;
            }
            RemoveTag(sourceTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Return a TagData object of a Tag Id informed, or null if the tag is unknown.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public TagData<TTokenType> GetTagById(TTagType tagId)
        {
            TagData<TTokenType> tag;
            return _tags.Items.TryGetValue(tagId, out tag) ? tag : null;
        }

        /// <summary>
        /// Save Bayes Text Classifier into a file (UTF-8 JSON, overwrites existing content).
        /// </summary>
        /// <param name="path">The file to write to</param>
        public void Save(string path)
        {
            using (var streamWriter = new StreamWriter(path, false, Encoding.UTF8))
            {
                JsonSerializer.Create().Serialize(streamWriter, _tags);
            }
        }

        /// <summary>
        /// Load Bayes Text Classifier from a file previously written by <see cref="Save"/>.
        /// </summary>
        /// <param name="path">The file to open for reading</param>
        public void Load(string path)
        {
            using (var streamReader = new StreamReader(path, Encoding.UTF8))
            {
                using (var jsonTextReader = new JsonTextReader(streamReader))
                {
                    _tags = JsonSerializer.Create().Deserialize<TagDictionary<TTokenType, TTagType>>(jsonTextReader);
                }
            }
            _mustRecache = true;
        }

        /// <summary>
        /// Import Bayes Text Classifier from a json string. Unparseable input
        /// resets the classifier to an empty state.
        /// </summary>
        /// <param name="json">The json content to be loaded</param>
        public void ImportJsonData(string json)
        {
            var result = JsonConvert.DeserializeObject<TagDictionary<TTokenType, TTagType>>(json);
            if (result != null)
            {
                _tags = result;
            }
            else
            {
                _tags = new TagDictionary<TTokenType, TTagType>();
            }
            // BUGFIX: the cache must be invalidated in both branches; previously a
            // null result replaced _tags but left the stale cache in use.
            _mustRecache = true;
        }

        /// <summary>
        /// Export Bayes Text Classifier to a json string. Returns an empty string
        /// when there is no trained data to export.
        /// </summary>
        public string ExportJsonData()
        {
            return _tags?.Items != null && _tags.Items.Any()
                ? JsonConvert.SerializeObject(_tags)
                : string.Empty;
        }

        /// <summary>
        /// Return a sorted list of Tag Ids.
        /// </summary>
        public IEnumerable<TTagType> TagIds()
        {
            return _tags.Items.Keys.OrderBy(p => p);
        }

        /// <summary>
        /// Train Bayes by telling him that input belongs in tag.
        /// The tag is created on first use.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be trained</param>
        public void Train(TTagType tagId, string input)
        {
            var tokens = _tokenizer.Tokenize(input);
            var tag = GetAndAddIfNotFound(_tags.Items, tagId);
            _train(tag, tokens);
            _tags.SystemTag.TrainCount += 1;
            tag.TrainCount += 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Untrain Bayes by telling him that input no more belongs in tag.
        /// Untraining an unknown tag is a no-op.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be untrained</param>
        public void Untrain(TTagType tagId, string input)
        {
            // BUGFIX: the indexer throws KeyNotFoundException on a missing tag, so
            // the old "if (tag == null) return;" guard was unreachable. TryGetValue
            // makes the intended no-op for unknown tags actually work.
            TagData<TTokenType> tag;
            if (!_tags.Items.TryGetValue(tagId, out tag))
            {
                return;
            }
            var tokens = _tokenizer.Tokenize(input);
            _untrain(tag, tokens);
            // NOTE(review): untraining increments TrainCount exactly like Train does,
            // mirroring the original implementation — confirm this is intended.
            _tags.SystemTag.TrainCount += 1;
            tag.TrainCount += 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Returns the scores in each tag the provided input.
        /// </summary>
        /// <param name="input">Input to be classified</param>
        /// <returns>Tag → score map, ordered from best to worst score</returns>
        public Dictionary<TTagType, double> Classify(string input)
        {
            var tokens = _tokenizer.Tokenize(input).ToList();
            var tags = CreateCacheAndGetTags();

            var stats = new Dictionary<TTagType, double>();

            foreach (var tag in tags.Items)
            {
                var probs = GetProbabilities(tag.Value, tokens).ToList();
                if (probs.Count != 0)
                {
                    stats[tag.Key] = _combiner.Combine(probs);
                }
            }
            return stats.OrderByDescending(s => s.Value).ToDictionary(s => s.Key, pair => pair.Value);
        }

        #region Private Methods

        // Add each token occurrence to the tag and to the global SystemTag totals.
        private void _train(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            var tokenCount = 0;
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                tag.Items[token] = count + 1;
                count = _tags.SystemTag.Get(token, 0);
                _tags.SystemTag.Items[token] = count + 1;
                tokenCount += 1;
            }
            tag.TokenCount += tokenCount;
            _tags.SystemTag.TokenCount += tokenCount;
        }

        // Remove each token occurrence from the tag and from the SystemTag totals,
        // deleting a token entry entirely once its count reaches one.
        private void _untrain(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                if (count > 0)
                {
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        tag.Items.Remove(token);
                    }
                    else
                    {
                        tag.Items[token] = count - 1;
                    }
                    tag.TokenCount -= 1;
                }
                count = _tags.SystemTag.Get(token, 0);
                if (count > 0)
                {
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        _tags.SystemTag.Items.Remove(token);
                    }
                    else
                    {
                        _tags.SystemTag.Items[token] = count - 1;
                    }
                    _tags.SystemTag.TokenCount -= 1;
                }
            }
        }

        // Fetch the TagData for key, creating and registering a fresh one if absent.
        private static TagData<TTokenType> GetAndAddIfNotFound(IDictionary<TTagType, TagData<TTokenType>> dic, TTagType key)
        {
            // Single lookup via TryGetValue instead of ContainsKey + indexer.
            TagData<TTokenType> data;
            if (!dic.TryGetValue(key, out data))
            {
                data = new TagData<TTokenType>();
                dic[key] = data;
            }
            return data;
        }

        // Rebuild (if dirty) and return the cache of per-tag token probabilities,
        // computed Robinson-style from this tag's counts vs. all other tags' counts.
        private TagDictionary<TTokenType, TTagType> CreateCacheAndGetTags()
        {
            if (!_mustRecache) return _cache;

            _cache = new TagDictionary<TTokenType, TTagType> { SystemTag = _tags.SystemTag };
            foreach (var tag in _tags.Items)
            {
                var thisTagTokenCount = tag.Value.TokenCount;
                // Clamp to 1 so the badMetric division below never divides by zero.
                var otherTagsTokenCount = Math.Max(_tags.SystemTag.TokenCount - thisTagTokenCount, 1);
                var cachedTag = GetAndAddIfNotFound(_cache.Items, tag.Key);

                foreach (var systemTagItem in _tags.SystemTag.Items)
                {
                    var thisTagTokenFreq = tag.Value.Get(systemTagItem.Key, 0.0);
                    // Token never seen in this tag: contributes nothing, skip.
                    if (Math.Abs(thisTagTokenFreq) < Tolerance)
                    {
                        continue;
                    }
                    var otherTagsTokenFreq = systemTagItem.Value - thisTagTokenFreq;

                    var goodMetric = thisTagTokenCount == 0 ? 1.0 : Math.Min(1.0, otherTagsTokenFreq / thisTagTokenCount);
                    var badMetric = Math.Min(1.0, thisTagTokenFreq / otherTagsTokenCount);
                    var f = badMetric / (goodMetric + badMetric);

                    // Only cache tokens that carry signal (far enough from 0.5),
                    // clamped away from 0 and 1 so combiners never see certainties.
                    if (Math.Abs(f - 0.5) >= Threshold)
                    {
                        cachedTag.Items[systemTagItem.Key] = Math.Max(Tolerance, Math.Min(1 - Tolerance, f));
                    }
                }
            }
            _mustRecache = false;
            return _cache;
        }

        // Return the cached probabilities for the input tokens known to this tag,
        // strongest first, capped at 2048 values to bound combiner work.
        private static IEnumerable<double> GetProbabilities(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            var probs = tokens.Where(tag.Items.ContainsKey).Select(t => tag.Items[t]);
            return probs.OrderByDescending(p => p).Take(2048);
        }

        #endregion Private Methods
    }
}