using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using BayesSharp.Combiners;
using BayesSharp.Tokenizers;
using Newtonsoft.Json;

namespace BayesSharp
{
    public class BayesClassifier<TTokenType, TTagType> where TTagType : IComparable
    {
        private TagDictionary<TTokenType, TTagType> _tags = new TagDictionary<TTokenType, TTagType>();
        private TagDictionary<TTokenType, TTagType> _cache;

        private readonly ITokenizer<TTokenType> _tokenizer;
        private readonly ICombiner _combiner;

        private bool _mustRecache;
        private const double Tolerance = 0.0001;
        private const double Threshold = 0.1;

        public BayesClassifier(ITokenizer<TTokenType> tokenizer)
            : this(tokenizer, new RobinsonCombiner())
        {
        }

        public BayesClassifier(ITokenizer<TTokenType> tokenizer, ICombiner combiner)
        {
            if (tokenizer == null) throw new ArgumentNullException("tokenizer");
            if (combiner == null) throw new ArgumentNullException("combiner");

            _tokenizer = tokenizer;
            _combiner = combiner;

            _tags.SystemTag = new TagData<TTokenType>();
            _mustRecache = true;
        }
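
        // A minimal usage sketch. WhitespaceTokenizer is hypothetical here;
        // substitute any concrete ITokenizer<string> implementation:
        //
        //     var classifier = new BayesClassifier<string, string>(new WhitespaceTokenizer());
        //     classifier.Train("spam", "buy cheap pills now");
        //     classifier.Train("ham", "meeting agenda for tomorrow");
        //     var scores = classifier.Classify("cheap pills tomorrow");
        //     // scores: tag id -> combined probability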

        /// <summary>
        /// Create a new tag, without actually doing any training.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void AddTag(TTagType tagId)
        {
            GetAndAddIfNotFound(_tags.Items, tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Remove a tag
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void RemoveTag(TTagType tagId)
        {
            _tags.Items.Remove(tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Change the Id of a tag
        /// </summary>
        /// <param name="oldTagId">Old Tag Id</param>
        /// <param name="newTagId">New Tag Id</param>
        public void ChangeTagId(TTagType oldTagId, TTagType newTagId)
        {
            _tags.Items[newTagId] = _tags.Items[oldTagId];
            RemoveTag(oldTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Merge an existing tag into another
        /// </summary>
        /// <param name="sourceTagId">Tag to be merged into destTagId and then removed</param>
        /// <param name="destTagId">Destination tag Id</param>
        public void MergeTags(TTagType sourceTagId, TTagType destTagId)
        {
            var sourceTag = _tags.Items[sourceTagId];
            var destTag = _tags.Items[destTagId];
            foreach (var tagItem in sourceTag.Items)
            {
                // Add the source tag's stored count for this token to the
                // destination tag.
                if (destTag.Items.ContainsKey(tagItem.Key))
                {
                    destTag.Items[tagItem.Key] += tagItem.Value;
                }
                else
                {
                    destTag.Items[tagItem.Key] = tagItem.Value;
                }
            }
            // TokenCount tracks total token occurrences (see _train), so
            // carry the aggregate count over as well.
            destTag.TokenCount += sourceTag.TokenCount;
            RemoveTag(sourceTagId);
            _mustRecache = true;
        }
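
        // Tag management, sketched with string tag ids:
        //
        //     classifier.AddTag("spam");
        //     classifier.AddTag("junk");
        //     classifier.ChangeTagId("junk", "unwanted");  // rename a tag
        //     classifier.MergeTags("unwanted", "spam");    // fold counts into "spam", then remove it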

        /// <summary>
        /// Return the TagData for the given Tag Id, or null if it does not exist
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public TagData<TTokenType> GetTagById(TTagType tagId)
        {
            return _tags.Items.ContainsKey(tagId) ? _tags.Items[tagId] : null;
        }

        /// <summary>
        /// Save the Bayes Text Classifier into a file
        /// </summary>
        /// <param name="path">The file to write to</param>
        public void Save(string path)
        {
            using (var streamWriter = new StreamWriter(path, false, Encoding.UTF8))
            {
                JsonSerializer.Create().Serialize(streamWriter, _tags);
            }
        }

        /// <summary>
        /// Load the Bayes Text Classifier from a file
        /// </summary>
        /// <param name="path">The file to open for reading</param>
        public void Load(string path)
        {
            using (var streamReader = new StreamReader(path, Encoding.UTF8))
            {
                using (var jsonTextReader = new JsonTextReader(streamReader))
                {
                    _tags = JsonSerializer.Create().Deserialize<TagDictionary<TTokenType, TTagType>>(jsonTextReader);
                }
            }
            _mustRecache = true;
        }

        /// <summary>
        /// Import the Bayes Text Classifier from a JSON string
        /// </summary>
        /// <param name="json">The JSON content to be loaded</param>
        public void ImportJsonData(string json)
        {
            _tags = JsonConvert.DeserializeObject<TagDictionary<TTokenType, TTagType>>(json);
            _mustRecache = true;
        }

        /// <summary>
        /// Export the Bayes Text Classifier to a JSON string
        /// </summary>
        public string ExportJsonData()
        {
            return JsonConvert.SerializeObject(_tags);
        }
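
        // Persistence round trip, as a sketch ("bayes.json" is an arbitrary
        // path; tokenizer is any ITokenizer<string> instance):
        //
        //     classifier.Save("bayes.json");
        //     var restored = new BayesClassifier<string, string>(tokenizer);
        //     restored.Load("bayes.json");
        //
        //     // Or keep the state in memory as a string:
        //     restored.ImportJsonData(classifier.ExportJsonData());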

        /// <summary>
        /// Return a sorted list of Tag Ids
        /// </summary>
        public IEnumerable<TTagType> TagIds()
        {
            return _tags.Items.Keys.OrderBy(p => p);
        }

        /// <summary>
        /// Train Bayes by telling it that the input belongs in the tag.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be trained</param>
        public void Train(TTagType tagId, string input)
        {
            var tokens = _tokenizer.Tokenize(input);
            var tag = GetAndAddIfNotFound(_tags.Items, tagId);
            _train(tag, tokens);
            _tags.SystemTag.TrainCount += 1;
            tag.TrainCount += 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Untrain Bayes by telling it that the input no longer belongs in the tag.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be untrained</param>
        public void Untrain(TTagType tagId, string input)
        {
            var tokens = _tokenizer.Tokenize(input);
            // Indexing _tags.Items directly would throw for an unknown tag,
            // so probe with TryGetValue and bail out quietly instead.
            TagData<TTokenType> tag;
            if (!_tags.Items.TryGetValue(tagId, out tag))
            {
                return;
            }
            _untrain(tag, tokens);
            // Reverse the bookkeeping done in Train.
            _tags.SystemTag.TrainCount -= 1;
            tag.TrainCount -= 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Return the score of the provided input for each tag
        /// </summary>
        /// <param name="input">Input to be classified</param>
        public Dictionary<TTagType, double> Classify(string input)
        {
            var tokens = _tokenizer.Tokenize(input).ToList();
            var tags = CreateCacheAndGetTags();

            var stats = new Dictionary<TTagType, double>();

            foreach (var tag in tags.Items)
            {
                var probs = GetProbabilities(tag.Value, tokens).ToList();
                if (probs.Count != 0)
                {
                    stats[tag.Key] = _combiner.Combine(probs);
                }
            }
            return stats.OrderByDescending(s => s.Value).ToDictionary(s => s.Key, pair => pair.Value);
        }
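
        // Reading the result (sketch): the dictionary is filled in descending
        // score order, but Dictionary<,> enumeration order is an
        // implementation detail, so rank explicitly when the best tag matters:
        //
        //     var best = classifier.Classify("cheap pills tomorrow")
        //         .OrderByDescending(kv => kv.Value)
        //         .First();  // KeyValuePair of (tag id, score)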

        #region Private Methods

        private void _train(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            // Count every token occurrence both in the tag and in the
            // corpus-wide SystemTag, which aggregates all training data.
            var tokenCount = 0;
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                tag.Items[token] = count + 1;
                count = _tags.SystemTag.Get(token, 0);
                _tags.SystemTag.Items[token] = count + 1;
                tokenCount += 1;
            }
            tag.TokenCount += tokenCount;
            _tags.SystemTag.TokenCount += tokenCount;
        }

        private void _untrain(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                if (count > 0)
                {
                    // Remove the token entirely once its count reaches zero.
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        tag.Items.Remove(token);
                    }
                    else
                    {
                        tag.Items[token] = count - 1;
                    }
                    tag.TokenCount -= 1;
                }
                count = _tags.SystemTag.Get(token, 0);
                if (count > 0)
                {
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        _tags.SystemTag.Items.Remove(token);
                    }
                    else
                    {
                        _tags.SystemTag.Items[token] = count - 1;
                    }
                    _tags.SystemTag.TokenCount -= 1;
                }
            }
        }

        private static TagData<TTokenType> GetAndAddIfNotFound(IDictionary<TTagType, TagData<TTokenType>> dic, TTagType key)
        {
            if (dic.ContainsKey(key))
            {
                return dic[key];
            }
            dic[key] = new TagData<TTokenType>();
            return dic[key];
        }
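
        // Sketch of the math behind the cache below (a Graham-style
        // token-probability heuristic). For each token known to a tag:
        //
        //     good = min(1, otherTagsTokenFreq / thisTagTokenCount)
        //     bad  = min(1, thisTagTokenFreq / otherTagsTokenCount)
        //     f    = bad / (good + bad)
        //
        // f approximates the probability that the token indicates the tag.
        // It is cached only when it lies at least Threshold away from the
        // neutral 0.5, clamped to [Tolerance, 1 - Tolerance] so the combiner
        // never sees exactly 0 or 1.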
        private TagDictionary<TTokenType, TTagType> CreateCacheAndGetTags()
        {
            if (!_mustRecache) return _cache;

            _cache = new TagDictionary<TTokenType, TTagType> { SystemTag = _tags.SystemTag };
            foreach (var tag in _tags.Items)
            {
                var thisTagTokenCount = tag.Value.TokenCount;
                var otherTagsTokenCount = Math.Max(_tags.SystemTag.TokenCount - thisTagTokenCount, 1);
                var cachedTag = GetAndAddIfNotFound(_cache.Items, tag.Key);

                foreach (var systemTagItem in _tags.SystemTag.Items)
                {
                    var thisTagTokenFreq = tag.Value.Get(systemTagItem.Key, 0.0);
                    if (Math.Abs(thisTagTokenFreq) < Tolerance)
                    {
                        continue;
                    }
                    var otherTagsTokenFreq = systemTagItem.Value - thisTagTokenFreq;

                    var goodMetric = thisTagTokenCount == 0 ? 1.0 : Math.Min(1.0, otherTagsTokenFreq / thisTagTokenCount);
                    var badMetric = Math.Min(1.0, thisTagTokenFreq / otherTagsTokenCount);
                    var f = badMetric / (goodMetric + badMetric);

                    // Keep only tokens whose probability is meaningfully far
                    // from the neutral 0.5, clamped away from 0 and 1.
                    if (Math.Abs(f - 0.5) >= Threshold)
                    {
                        cachedTag.Items[systemTagItem.Key] = Math.Max(Tolerance, Math.Min(1 - Tolerance, f));
                    }
                }
            }
            _mustRecache = false;
            return _cache;
        }

        private static IEnumerable<double> GetProbabilities(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            // Collect the cached probabilities for the tokens this tag knows
            // about, keeping at most the 2048 largest values.
            var probs = tokens.Where(tag.Items.ContainsKey).Select(t => tag.Items[t]);
            return probs.OrderByDescending(p => p).Take(2048);
        }

        #endregion

    }
}