-
Notifications
You must be signed in to change notification settings - Fork 81
/
db_query.py
185 lines (147 loc) · 7.57 KB
/
db_query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from collections import defaultdict
from dialogue_config import no_query_keys, usersim_default_key
import copy
class DBQuery:
    """Queries the database for the state tracker.

    Query results are memoised in two caches keyed by a frozenset of the
    (slot, value) constraint pairs, so constraint values must be hashable.
    """

    def __init__(self, database):
        """
        The constructor for DBQuery.

        Parameters:
            database (dict): The database in the format dict(long: dict)
        """

        self.database = database
        # {frozenset: {string: int}} A dict of dicts
        self.cached_db_slot = defaultdict(dict)
        # {frozenset: {'#': {'slot': 'value'}}} A dict of dicts of dicts, a dict of DB sub-dicts.
        # A value of None records that the constraint set matched nothing.
        self.cached_db = defaultdict(dict)
        self.no_query = no_query_keys
        self.match_key = usersim_default_key

    def fill_inform_slot(self, inform_slot_to_fill, current_inform_slots):
        """
        Given the current informs/constraints fill the inform that needs to be filled with a value from the database.

        Searches the database items that satisfy the other current constraints and picks, for the requested slot,
        the value that occurs in the most matching items.

        Parameters:
            inform_slot_to_fill (dict): Inform slot to fill with a value (exactly one key)
            current_inform_slots (dict): Current inform slots with values from the StateTracker (not modified)

        Returns:
            dict: {slot: most common matching value}, or {slot: 'no match available'} when no matching item
                  carries that slot
        """

        # For this simple system only one inform slot should ever be passed in
        assert len(inform_slot_to_fill) == 1

        key = next(iter(inform_slot_to_fill))

        # Drop the slot being filled from the constraints so it can be re-queried. A filtered shallow copy
        # is enough: the caller's dict is left untouched and the values are never mutated here.
        current_informs = {k: v for k, v in current_inform_slots.items() if k != key}

        # db_results is a dict of dict in the same exact format as the db, it is just a subset of the db
        db_results = self.get_db_results(current_informs)

        values_dict = self._count_slot_values(key, db_results)
        if values_dict:
            # Get key with max value (ie slot value with highest count of available results)
            return {key: max(values_dict, key=values_dict.get)}
        return {key: 'no match available'}

    def _count_slot_values(self, key, db_subdict):
        """
        Return a dict of the different values and occurrences of each, given a key, from a sub-dict of database.

        Parameters:
            key (string): The key to be counted
            db_subdict (dict): A sub-dict of the database

        Returns:
            dict: The values and their occurrences given the key (items lacking the key are ignored)
        """

        slot_values = defaultdict(int)  # missing values start at 0
        for option in db_subdict.values():
            if key in option:
                slot_values[option[key]] += 1
        return slot_values

    def get_db_results(self, constraints):
        """
        Get all items in the database that fit the current constraints.

        An item matches when it contains every constrained slot and each value is equal to the constraint value
        after both are converted to lower-cased strings.

        Parameters:
            constraints (dict): The current informs

        Returns:
            dict: The available items in the database ({} when nothing matches)
        """

        # Filter non-queryable items and keys with the value 'anything' since those are inconsequential to the
        # constraints. Compare by equality: an identity test against a string literal is not reliable.
        new_constraints = {k: v for k, v in constraints.items()
                           if k not in self.no_query and v != 'anything'}

        inform_items = frozenset(new_constraints.items())

        # .get avoids inserting an empty placeholder into the defaultdict on a cache miss
        cache_return = self.cached_db.get(inform_items, {})
        if cache_return is None:
            # None was cached: a previous query found no matches for these constraints
            return {}
        if cache_return:
            return cache_return

        # Normalise the constraint values once instead of once per database item
        wanted = {k: str(v).lower() for k, v in new_constraints.items()}

        available_options = {}
        for item_id, option in self.database.items():
            # Note: if a constrained slot is not found in the db item then that item is not a match
            if all(k in option and str(option[k]).lower() == v for k, v in wanted.items()):
                available_options[item_id] = option

        # Cache a separate dict so the one handed back to the caller is not the cache entry itself;
        # an empty result is cached as None.
        self.cached_db[inform_items] = dict(available_options) if available_options else None

        return available_options

    def get_db_results_for_slots(self, current_informs):
        """
        Counts occurrences of each current inform slot (key and value) in the database items.

        For each item in the database and each current inform slot, if that slot is in the database item (matches
        key and value, compared as lower-cased strings) then the count for that key is incremented by 1. Slots in
        the no-query list are skipped, and a value of 'anything' counts as a match for every item.

        Parameters:
            current_informs (dict): The current informs/constraints

        Returns:
            dict: Each key in current_informs with the count of the number of matches for that key, plus
                  'matching_all_constraints' with the number of items matching every constraint
        """

        # The items (key, value) of the current informs are used as a key to the cached_db_slot
        inform_items = frozenset(current_informs.items())

        cache_return = self.cached_db_slot.get(inform_items)
        if cache_return:
            return cache_return

        # A new query: init all key values with 0
        db_results = {key: 0 for key in current_informs}
        db_results['matching_all_constraints'] = 0

        for option in self.database.values():
            all_slots_match = True
            for ci_key, ci_value in current_informs.items():
                # Skip if a no query item and all_slots_match stays true
                if ci_key in self.no_query:
                    continue
                # If anything all_slots_match stays true AND the specific key slot gets a +1
                if ci_value == 'anything':
                    db_results[ci_key] += 1
                    continue
                # str() keeps non-string slot values (e.g. ints) from raising on .lower()
                if ci_key in option and str(ci_value).lower() == str(option[ci_key]).lower():
                    db_results[ci_key] += 1
                else:
                    all_slots_match = False
            if all_slots_match:
                db_results['matching_all_constraints'] += 1

        # Update cache with a copy so later mutation of the returned dict cannot alter the cached counts
        self.cached_db_slot[inform_items] = dict(db_results)
        return db_results