-
Notifications
You must be signed in to change notification settings - Fork 1
/
Conversation.py
93 lines (75 loc) · 3.29 KB
/
Conversation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-
from WordMarkovChain import WordMarkovChain
import operator
import random
class Conversation:
    """Per-chat state: a forward and a reverse word chain plus a set of names.

    Messages are fed to ``chain`` in their original word order and to
    ``reverse_chain`` in reversed word order, so text can be generated
    forwards from a word, backwards from a word, or around a group of words.
    """

    def __init__(self, chat_id):
        """Create an empty conversation identified by ``chat_id``."""
        self.chat_id = chat_id
        self.chain = WordMarkovChain()
        self.reverse_chain = WordMarkovChain()
        self.someones = set()

    def add_message(self, message):
        """Lower-case ``message``, split it on whitespace and feed both chains."""
        words = message.lower().split()
        self.chain.add_message(words)
        # Hand the reverse chain its own list so the list already given to the
        # forward chain is never reversed behind its back.
        self.reverse_chain.add_message(list(reversed(words)))

    def add_someone(self, someone):
        """Remember ``someone``; duplicates are ignored because a set is used."""
        self.someones.add(someone)

    def is_there_someone(self):
        """Return True when at least one entry has been registered."""
        return len(self.someones) > 0

    def get_someone(self, quantity=False):
        """Pick random registered entries.

        With a truthy ``quantity``, return a list of that many distinct
        entries, or None when fewer are registered. Otherwise return a single
        entry, or None when nothing is registered.
        """
        # random.sample() rejects sets on Python 3.11+, so sample from a list.
        pool = list(self.someones)
        if quantity:
            if len(pool) >= quantity:
                return random.sample(pool, quantity)
            return None
        if pool:
            return random.sample(pool, 1)[0]
        return None

    def get_someones(self):
        """Return every registered entry joined by ", " (set iteration order)."""
        return ", ".join(self.someones)

    def generate_message(self):
        """Return whatever the forward chain builds without a seed word."""
        return self.chain.build_message()

    def generate_message_beginning_with(self, words):
        """Return a lower-cased message starting with the list ``words``.

        The forward chain is seeded with the last given word. Returns None
        when ``words`` is empty.
        """
        if len(words) > 0:
            generated_message = self.chain.build_message(words[-1]).split()
            # The first generated token is dropped; presumably it is the seed
            # word itself, which ``words`` already ends with - confirm against
            # WordMarkovChain.build_message.
            return " ".join(words + generated_message[1:]).lower()
        return None

    def generate_message_ending_with(self, words):
        """Return a lower-cased message ending with the list ``words``.

        The reverse chain is seeded with the first given word and its output
        is flipped back into reading order. Returns None when ``words`` is
        empty.
        """
        if len(words) > 0:
            generated_message = self.reverse_chain.build_message(words[0]).split()
            generated_message.reverse()
            # After flipping, the last token is dropped; presumably it is the
            # seed word, which ``words`` already starts with.
            return " ".join(generated_message[:-1] + words).lower()
        return None

    def generate_message_containing(self, words):
        """Return a lower-cased message with the list ``words`` in the middle.

        The part before comes from the reverse chain (seeded with the first
        word), the part after from the forward chain (seeded with the last
        word). Returns None when ``words`` is empty.
        """
        if len(words) > 0:
            message_beginning = self.reverse_chain.build_message(words[0]).split()
            message_beginning.reverse()
            message_end = self.chain.build_message(words[-1]).split()
            return " ".join(message_beginning[:-1] + words + message_end[1:]).lower()
        return None

    @staticmethod
    def _as_text(value):
        """Decode UTF-8 byte strings to text; return anything else unchanged."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    def print_chain(self, word, reverse=False):
        """Describe what may follow (or, with ``reverse``, precede) ``word``.

        Entries are listed from highest to lowest probability. An entry whose
        key is falsy is shown as the end (or beginning) of a message. When the
        chain returns nothing for the word, a "not in my database" sentence is
        returned instead.
        """
        arg = word.lower()
        if not reverse:
            probabilities = self.chain.probabilities_for(arg)
            header = "Probabilities to appear after '" + arg + "':"
            message_extreme = "End of message"
        else:
            probabilities = self.reverse_chain.probabilities_for(arg)
            header = "Probabilities to appear before '" + arg + "':"
            message_extreme = "Beginning of message"

        if not probabilities:
            return "The word '" + arg + "' doesn't seem to be in my database"

        lines = [header]
        # reversed(sorted(...)) rather than sorted(reverse=True): the two
        # order equal probabilities differently, and the former is kept.
        ranked = reversed(sorted(probabilities.items(), key=operator.itemgetter(1)))
        for candidate, prob in ranked:
            if candidate:
                lines.append(" - '" + self._as_text(candidate) + "': " + str(prob))
            else:
                lines.append(" - " + message_extreme + ": " + str(prob))
        return "\n".join(lines)

    def set_randomness(self, p):
        """Forward ``p`` to both chains.

        Raises:
            ValueError: if ``p`` is not within the closed interval [0, 1].
        """
        if not 0 <= p <= 1:
            raise ValueError("Randomness should be a number between 0 and 1")
        self.chain.set_randomness(p)
        self.reverse_chain.set_randomness(p)

    def import_chain(self, filename):
        """Load both chains from ``filename``; the reverse chain is loaded with ``reverse=True``."""
        self.chain.import_chain(filename)
        self.reverse_chain.import_chain(filename, reverse=True)

    def export_chain(self, filename):
        """Write the forward chain to ``filename`` (the reverse chain is not exported)."""
        self.chain.export_chain(filename)