-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnaive_bayes.py
63 lines (51 loc) · 1.97 KB
/
naive_bayes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
###########################################################
# DO NOT EDIT ANYTHING
###########################################################
from probabilities import *
class NaiveBayesSpamFilter:
    """Binary spam/ham classifier built on the helpers from ``probabilities``.

    Label convention used by ``predict`` and ``evaluate``: 1 means spam,
    0 means ham.
    """

    def __init__(self, vocab_size):
        """Create an untrained filter.

        Args:
            vocab_size: size of the training-set vocabulary; stored here and
                forwarded to ``compute_p_word_given_class`` by ``train``.

        Every probability attribute starts as an empty/zero placeholder and is
        replaced when ``train`` runs.
        """
        self.vocab_size = vocab_size
        # Per-class word-probability mappings; empty dict placeholders until
        # train() overwrites them (presumably keyed by word -- confirm in
        # the probabilities module).
        self.p_word_given_spam = {}
        self.p_word_given_ham = {}
        # Class priors; zero placeholders until train() overwrites them.
        self.p_spam = 0
        self.p_ham = 0

    def train(self, X_paths_spam, X_paths_ham):
        """Fit the model from labelled example emails.

        Args:
            X_paths_spam: paths to emails known to be spam.
            X_paths_ham: paths to emails known to be ham.

        Word probabilities for each class come from
        ``compute_p_word_given_class``. Each prior comes from
        ``compute_p_class``, called with that class's example count first and
        the other class's count second.
        """
        self.p_word_given_spam = compute_p_word_given_class(
            X_paths_spam, self.vocab_size
        )
        self.p_word_given_ham = compute_p_word_given_class(
            X_paths_ham, self.vocab_size
        )

        n_spam, n_ham = len(X_paths_spam), len(X_paths_ham)
        self.p_spam = compute_p_class(n_spam, n_ham)
        self.p_ham = compute_p_class(n_ham, n_spam)

    def predict(self, X_instance_path):
        """Classify a single email.

        Args:
            X_instance_path: path to the email to classify.

        Returns:
            1 (spam) when the spam score from ``compute_p_class_given_input``
            is strictly greater than the ham score, otherwise 0 (ham). A tie
            therefore resolves to ham.
        """
        spam_score = compute_p_class_given_input(
            X_instance_path, self.p_word_given_spam, self.p_spam
        )
        ham_score = compute_p_class_given_input(
            X_instance_path, self.p_word_given_ham, self.p_ham
        )
        return 1 if spam_score > ham_score else 0

    def evaluate(self, X_paths, ground_truth_class):
        """Measure ``predict`` accuracy on emails that share one true label.

        Args:
            X_paths: paths to emails that all belong to ``ground_truth_class``.
            ground_truth_class: the string ``'spam'`` maps to label 1; any
                other value is treated as ham (label 0).

        Returns:
            The fraction, as a float, of ``X_paths`` whose prediction equals
            that label. An empty ``X_paths`` raises ZeroDivisionError.
        """
        expected = 1 if ground_truth_class == 'spam' else 0
        correct = sum(1 for path in X_paths if self.predict(path) == expected)
        return float(correct) / len(X_paths)