// BayesSharp – BayesClassifier (recovered from a revision-3 blame view of the file)
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using BayesSharp.Combiners;
using BayesSharp.Tokenizers;
using Newtonsoft.Json;

namespace BayesSharp
{
    /// <summary>
    /// A naive-Bayes style text classifier. Input strings are split into tokens by the
    /// supplied <see cref="ITokenizer{TTokenType}"/>; per-tag token probabilities are
    /// folded into a single score per tag by the supplied <see cref="ICombiner"/>.
    /// </summary>
    /// <typeparam name="TTokenType">Type of the tokens produced by the tokenizer.</typeparam>
    /// <typeparam name="TTagType">Type of the tag identifiers; must be comparable so tag ids can be sorted.</typeparam>
    public class BayesClassifier<TTokenType, TTagType> where TTagType : IComparable
    {
        // All training data: one TagData per tag, plus a SystemTag aggregating every token.
        private TagDictionary<TTokenType, TTagType> _tags = new TagDictionary<TTokenType, TTagType>();
        // Lazily rebuilt probability cache used by Classify (see CreateCacheAndGetTags).
        private TagDictionary<TTokenType, TTagType> _cache;

        private readonly ITokenizer<TTokenType> _tokenizer;
        private readonly ICombiner _combiner;

        // Set whenever training data changes; the cache is rebuilt on the next Classify call.
        private bool _mustRecache;
        private const double Tolerance = 0.0001;
        private const double Threshold = 0.1;

        /// <summary>
        /// Create a classifier using the default <see cref="RobinsonCombiner"/>.
        /// </summary>
        /// <param name="tokenizer">Tokenizer used to split input strings into tokens</param>
        public BayesClassifier(ITokenizer<TTokenType> tokenizer)
            : this(tokenizer, new RobinsonCombiner())
        {
        }

        /// <summary>
        /// Create a classifier with an explicit tokenizer and probability combiner.
        /// </summary>
        /// <param name="tokenizer">Tokenizer used to split input strings into tokens</param>
        /// <param name="combiner">Strategy that folds per-token probabilities into one score</param>
        /// <exception cref="ArgumentNullException">If either argument is null</exception>
        public BayesClassifier(ITokenizer<TTokenType> tokenizer, ICombiner combiner)
        {
            if (tokenizer == null) throw new ArgumentNullException(nameof(tokenizer));
            if (combiner == null) throw new ArgumentNullException(nameof(combiner));

            _tokenizer = tokenizer;
            _combiner = combiner;

            _tags.SystemTag = new TagData<TTokenType>();
            _mustRecache = true;
        }

        /// <summary>
        /// Create a new tag, without actually doing any training.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void AddTag(TTagType tagId)
        {
            GetAndAddIfNotFound(_tags.Items, tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Remove a tag
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void RemoveTag(TTagType tagId)
        {
            _tags.Items.Remove(tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Change the Id of a tag
        /// </summary>
        /// <param name="oldTagId">Old Tag Id</param>
        /// <param name="newTagId">New Tag Id</param>
        /// <exception cref="KeyNotFoundException">If <paramref name="oldTagId"/> does not exist</exception>
        public void ChangeTagId(TTagType oldTagId, TTagType newTagId)
        {
            _tags.Items[newTagId] = _tags.Items[oldTagId];
            RemoveTag(oldTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Merge an existing tag into another. Token frequencies of the source tag are
        /// added to the destination tag, then the source tag is removed.
        /// </summary>
        /// <param name="sourceTagId">Tag to be merged into destTagId and removed</param>
        /// <param name="destTagId">Destination tag Id</param>
        /// <exception cref="KeyNotFoundException">If either tag id does not exist</exception>
        public void MergeTags(TTagType sourceTagId, TTagType destTagId)
        {
            var sourceTag = _tags.Items[sourceTagId];
            var destTag = _tags.Items[destTagId];
            foreach (var tok in sourceTag.Items)
            {
                // BUG FIX: the previous version added a running loop counter instead of the
                // token's actual frequency (tok.Value), corrupting the merged counts.
                if (destTag.Items.ContainsKey(tok.Key))
                {
                    destTag.Items[tok.Key] += tok.Value;
                }
                else
                {
                    destTag.Items[tok.Key] = tok.Value;
                    // NOTE(review): TokenCount is bumped by 1 per newly-seen token here, while
                    // _train counts every occurrence — kept as-is to preserve existing behavior.
                    destTag.TokenCount += 1;
                }
            }
            RemoveTag(sourceTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Return a TagData object of a Tag Id informed
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <returns>The tag's data, or null when the tag does not exist</returns>
        public TagData<TTokenType> GetTagById(TTagType tagId)
        {
            TagData<TTokenType> tag;
            return _tags.Items.TryGetValue(tagId, out tag) ? tag : null;
        }

        /// <summary>
        /// Save Bayes Text Classifier into a file
        /// </summary>
        /// <param name="path">The file to write to</param>
        public void Save(string path)
        {
            using (var streamWriter = new StreamWriter(path, false, Encoding.UTF8))
            {
                JsonSerializer.Create().Serialize(streamWriter, _tags);
            }
        }

        /// <summary>
        /// Load Bayes Text Classifier from a file
        /// </summary>
        /// <param name="path">The file to open for reading</param>
        public void Load(string path)
        {
            using (var streamReader = new StreamReader(path, Encoding.UTF8))
            {
                using (var jsonTextReader = new JsonTextReader(streamReader))
                {
                    _tags = JsonSerializer.Create().Deserialize<TagDictionary<TTokenType, TTagType>>(jsonTextReader);
                }
            }
            _mustRecache = true;
        }

        /// <summary>
        /// Import Bayes Text Classifier from a json string
        /// </summary>
        /// <param name="json">The json content to be loaded</param>
        public void ImportJsonData(string json)
        {
            var result = JsonConvert.DeserializeObject<TagDictionary<TTokenType, TTagType>>(json);
            if (result != null)
            {
                _tags = result;
            }
            else
            {
                // BUG FIX: the fallback dictionary previously had no SystemTag (later
                // NullReferenceException in training) and left the stale cache in place.
                _tags = new TagDictionary<TTokenType, TTagType> { SystemTag = new TagData<TTokenType>() };
            }
            _mustRecache = true;
        }

        /// <summary>
        /// Export Bayes Text Classifier to a json string
        /// </summary>
        /// <returns>Serialized classifier state, or an empty string when there is nothing to export</returns>
        public string ExportJsonData()
        {
            return _tags?.Items != null && _tags.Items.Any()
                ? JsonConvert.SerializeObject(_tags)
                : string.Empty;
        }

        /// <summary>
        /// Return a sorted list of Tag Ids
        /// </summary>
        public IEnumerable<TTagType> TagIds()
        {
            return _tags.Items.Keys.OrderBy(p => p);
        }

        /// <summary>
        /// Train Bayes by telling him that input belongs in tag.
        /// </summary>
        /// <param name="tagId">Tag Id (created on first use)</param>
        /// <param name="input">Input to be trained</param>
        public void Train(TTagType tagId, string input)
        {
            var tokens = _tokenizer.Tokenize(input);
            var tag = GetAndAddIfNotFound(_tags.Items, tagId);
            _train(tag, tokens);
            _tags.SystemTag.TrainCount += 1;
            tag.TrainCount += 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Untrain Bayes by telling him that input no more belongs in tag.
        /// A no-op when the tag does not exist.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be untrained</param>
        public void Untrain(TTagType tagId, string input)
        {
            TagData<TTokenType> tag;
            // BUG FIX: the indexer threw KeyNotFoundException for an unknown tag, making
            // the intended "return when missing" null-guard dead code.
            if (!_tags.Items.TryGetValue(tagId, out tag) || tag == null)
            {
                return;
            }
            var tokens = _tokenizer.Tokenize(input);
            _untrain(tag, tokens);
            // NOTE(review): TrainCount is incremented on untrain as well — preserved from the
            // original code; confirm whether a decrement was intended.
            _tags.SystemTag.TrainCount += 1;
            tag.TrainCount += 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Returns the scores in each tag the provided input
        /// </summary>
        /// <param name="input">Input to be classified</param>
        /// <returns>Tag → score map, ordered by descending score</returns>
        public Dictionary<TTagType, double> Classify(string input)
        {
            var tokens = _tokenizer.Tokenize(input).ToList();
            var tags = CreateCacheAndGetTags();

            var stats = new Dictionary<TTagType, double>();

            foreach (var tag in tags.Items)
            {
                var probs = GetProbabilities(tag.Value, tokens).ToList();
                if (probs.Count != 0)
                {
                    stats[tag.Key] = _combiner.Combine(probs);
                }
            }
            return stats.OrderByDescending(s => s.Value).ToDictionary(s => s.Key, pair => pair.Value);
        }

        #region Private Methods

        // Add each token occurrence to the tag and to the global SystemTag tallies.
        private void _train(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            var tokenCount = 0;
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                tag.Items[token] = count + 1;
                count = _tags.SystemTag.Get(token, 0);
                _tags.SystemTag.Items[token] = count + 1;
                tokenCount += 1;
            }
            tag.TokenCount += tokenCount;
            _tags.SystemTag.TokenCount += tokenCount;
        }

        // Reverse of _train: decrement each token occurrence, removing entries that reach zero.
        private void _untrain(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                if (count > 0)
                {
                    // Counts are stored as doubles, so "equals 1" is checked with a tolerance.
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        tag.Items.Remove(token);
                    }
                    else
                    {
                        tag.Items[token] = count - 1;
                    }
                    tag.TokenCount -= 1;
                }
                count = _tags.SystemTag.Get(token, 0);
                if (count > 0)
                {
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        _tags.SystemTag.Items.Remove(token);
                    }
                    else
                    {
                        _tags.SystemTag.Items[token] = count - 1;
                    }
                    _tags.SystemTag.TokenCount -= 1;
                }
            }
        }

        // Fetch the tag data for key, creating an empty entry on first use (single lookup).
        private static TagData<TTokenType> GetAndAddIfNotFound(IDictionary<TTagType, TagData<TTokenType>> dic, TTagType key)
        {
            TagData<TTokenType> data;
            if (!dic.TryGetValue(key, out data))
            {
                data = new TagData<TTokenType>();
                dic[key] = data;
            }
            return data;
        }

        // Rebuild (when dirty) and return the per-tag token-probability cache used by Classify.
        // For each token, f estimates how strongly the token indicates this tag versus all
        // other tags; only tokens whose f is at least Threshold away from the neutral 0.5
        // are cached, clamped into (Tolerance, 1 - Tolerance).
        private TagDictionary<TTokenType, TTagType> CreateCacheAndGetTags()
        {
            if (!_mustRecache) return _cache;

            _cache = new TagDictionary<TTokenType, TTagType> { SystemTag = _tags.SystemTag };
            foreach (var tag in _tags.Items)
            {
                var thisTagTokenCount = tag.Value.TokenCount;
                // Max(..., 1) avoids division by zero when this tag holds all known tokens.
                var otherTagsTokenCount = Math.Max(_tags.SystemTag.TokenCount - thisTagTokenCount, 1);
                var cachedTag = GetAndAddIfNotFound(_cache.Items, tag.Key);

                foreach (var systemTagItem in _tags.SystemTag.Items)
                {
                    var thisTagTokenFreq = tag.Value.Get(systemTagItem.Key, 0.0);
                    if (Math.Abs(thisTagTokenFreq) < Tolerance)
                    {
                        continue; // token never seen in this tag
                    }
                    var otherTagsTokenFreq = systemTagItem.Value - thisTagTokenFreq;

                    var goodMetric = thisTagTokenCount == 0 ? 1.0 : Math.Min(1.0, otherTagsTokenFreq / thisTagTokenCount);
                    var badMetric = Math.Min(1.0, thisTagTokenFreq / otherTagsTokenCount);
                    var f = badMetric / (goodMetric + badMetric);

                    // Keep only tokens that discriminate: f far enough from the neutral 0.5.
                    if (Math.Abs(f - 0.5) >= Threshold)
                    {
                        cachedTag.Items[systemTagItem.Key] = Math.Max(Tolerance, Math.Min(1 - Tolerance, f));
                    }
                }
            }
            _mustRecache = false;
            return _cache;
        }

        // Return the cached probabilities of the tokens known to this tag, strongest first,
        // capped at 2048 values to bound the combiner's work.
        private static IEnumerable<double> GetProbabilities(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            var probs = tokens.Where(tag.Items.ContainsKey).Select(t => tag.Items[t]);
            return probs.OrderByDescending(p => p).Take(2048);
        }

        #endregion Private Methods
    }
}