From f1ca6f717a6d73554212589c0651c0ec47827304 Mon Sep 17 00:00:00 2001 From: Jack Boylan <70636379+jackboyla@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:41:29 +0300 Subject: [PATCH] Aggregation functions (`COUNT`, `SUM`, `MIN`, `MAX`, `AVG`) (#45) * Adds support for multigraphs * Refactors `_is_edge_attr_match` * Filters relations by __label__ during `_lookup` * Bundles relation attributes together for lookup * Refactors and adds inline docs * Adds tests for multigraph support * Cleans up inline docs * Removes slicing list twice to avoid two copies in memory * Supports WHERE clause for relationships in multigraphs * Adds test for multigraph with WHERE clause on single edge * Accounts for WHERE with string node attributes in MultiDiGraphs * Unifies all unit tests to work with both DiGraphs and MultiDiGraphs * Completes multidigraph test for WHERE on node attribute * Supports logical OR for relationship matching * Adds tests for logical OR in MATCH for relationships * Implements aggregation functions * Removes unused code * Adds agg function results to `_return_requests` * Handles `None` values appropriately for MIN and MAX * Adds tests for agg functions and adjusts existing tests to new output * Adds examples page * Adds test for multiple agg functions * Removes commented code --- README.md | 3 + examples.md | 66 +++++++++++++++ grandcypher/__init__.py | 92 +++++++++++++++++++-- grandcypher/test_queries.py | 156 ++++++++++++++++++++++++++++++++---- 4 files changed, 296 insertions(+), 21 deletions(-) create mode 100644 examples.md diff --git a/README.md b/README.md index 07a583f..c82ab92 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ RETURN A.club, B.club """) ``` +See [examples.md](examples.md) for more! + ### Example Usage with SQL Create your own "Sqlite for Neo4j"! This example uses [grand-graph](https://github.com/aplbrain/grand) to run queries in SQL: @@ -81,6 +83,7 @@ RETURN | Graph mutations (e.g. `DELETE`, `SET`,...) | 🛣 | | | `DISTINCT` | ✅ Thanks @jackboyla! | | | `ORDER BY` | ✅ Thanks @jackboyla! | | +| Aggregation functions (`COUNT`, `SUM`, `MIN`, `MAX`, `AVG`) | ✅ Thanks @jackboyla! | | | | | | | -------------- | -------------- | ---------------- | diff --git a/examples.md b/examples.md new file mode 100644 index 0000000..0122329 --- /dev/null +++ b/examples.md @@ -0,0 +1,66 @@ + +## Multigraph + +```python +from grandcypher import GrandCypher +import networkx as nx + +host = nx.MultiDiGraph() +host.add_node("a", name="Alice", age=25) +host.add_node("b", name="Bob", age=30) +host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") +host.add_edge("b", "a", __labels__={"paid"}, amount=6) +host.add_edge("b", "a", __labels__={"paid"}, value=14) +host.add_edge("a", "b", __labels__={"friends"}, years=9) +host.add_edge("a", "b", __labels__={"paid"}, amount=40) + +qry = """ +MATCH (n)-[r:paid]->(m) +RETURN n.name, m.name, r.amount +""" +res = GrandCypher(host).run(qry) +print(res) + +''' +{ + 'n.name': ['Alice', 'Bob'], + 'm.name': ['Bob', 'Alice'], + 'r.amount': [{(0, 'paid'): 12, (1, 'friends'): None, (2, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}] +} +''' +``` + +## Aggregation Functions + +```python +from grandcypher import GrandCypher +import networkx as nx + +host = nx.MultiDiGraph() +host.add_node("a", name="Alice", age=25) +host.add_node("b", name="Bob", age=30) +host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") +host.add_edge("b", "a", __labels__={"paid"}, amount=6) +host.add_edge("b", "a", __labels__={"paid"}, value=14) +host.add_edge("a", "b", __labels__={"friends"}, years=9) +host.add_edge("a", "b", __labels__={"paid"}, amount=40) + +qry = """ +MATCH (n)-[r:paid]->(m) +RETURN n.name, m.name, SUM(r.amount) +""" +res = GrandCypher(host).run(qry) +print(res) + +''' +{ + 'n.name': ['Alice', 'Bob'], + 'm.name': ['Bob', 'Alice'], + 'SUM(r.amount)': [{'paid': 52, 'friends': 0}, {'paid': 6}] +} +''' +``` + + + + diff --git a/grandcypher/__init__.py b/grandcypher/__init__.py index 4c3259b..e582af2 100644 --- a/grandcypher/__init__.py +++ b/grandcypher/__init__.py @@ -81,7 +81,13 @@ -return_clause : "return"i distinct_return? entity_id ("," entity_id)* + +return_clause : "return"i distinct_return? return_item ("," return_item)* +return_item : entity_id | aggregation_function | entity_id "." attribute_id + +aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")" +AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN" +attribute_id : CNAME distinct_return : "DISTINCT"i limit_clause : "limit"i NUMBER @@ -282,6 +288,7 @@ def _get_entity_from_host( edge_data = host.get_edge_data(*entity_name) if not edge_data: return None # print(f"Nothing found for {entity_name} {entity_attribute}") + if entity_attribute: # looking for edge attribute: if isinstance(host, nx.MultiDiGraph): @@ -376,6 +383,7 @@ def __init__(self, target_graph: nx.Graph, limit=None): self._matche_paths = None self._return_requests = [] self._return_edges = {} + self._aggregate_functions = [] self._distinct = False self._order_by = None self._order_by_attributes = set() @@ -483,9 +491,10 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: for r in ret: r_attr = {} for i, v in r.items(): - r_attr[i] = v.get(entity_attribute, None) + r_attr[(i, list(v.get('__labels__'))[0])] = v.get(entity_attribute, None) + # eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}] ret_with_attr.append(r_attr) - + ret = ret_with_attr result[data_path] = list(ret)[offset_limit] @@ -497,9 +506,19 @@ def return_clause(self, clause): # collect all entity identifiers to be returned for item in clause: if item: - if not isinstance(item, str): - item = str(item.value) - self._return_requests.append(item) + item = item.children[0] if isinstance(item, Tree) else item + if isinstance(item, Tree) and item.data == "aggregation_function": + func = str(item.children[0].value) # AGGREGATE_FUNC + entity = str(item.children[1].value) + if len(item.children) > 2: + entity += "." + str(item.children[2].children[0].value) + self._aggregate_functions.append((func, entity)) + self._return_requests.append(entity) + else: + if not isinstance(item, str): + item = str(item.value) + self._return_requests.append(item) + def order_clause(self, order_clause): self._order_by = [] @@ -525,12 +544,73 @@ def skip_clause(self, skip): skip = int(skip[-1]) self._skip = skip + + def aggregate(self, func, results, entity, group_keys): + # Collect data based on group keys + grouped_data = {} + for i in range(len(results[entity])): + group_tuple = tuple(results[key][i] for key in group_keys if key in results) + if group_tuple not in grouped_data: + grouped_data[group_tuple] = [] + grouped_data[group_tuple].append(results[entity][i]) + + def _collate_data(data, unique_labels, func): + # for ["COUNT", "SUM", "AVG"], we treat None as 0 + if func in ["COUNT", "SUM", "AVG"]: + collated_data = { + label: [(v or 0) for rel in data for k, v in rel.items() if k[1] == label] for label in unique_labels + } + # for ["MAX", "MIN"], we treat None as non-existent + elif func in ["MAX", "MIN"]: + collated_data = { + label: [v for rel in data for k, v in rel.items() if (k[1] == label and v is not None)] for label in unique_labels + } + + return collated_data + + # Apply aggregation function + aggregate_results = {} + for group, data in grouped_data.items(): + # data => [{(0, 'paid'): 70, (1, 'paid'): 90}] + unique_labels = set([k[1] for rel in data for k in rel.keys()]) + collated_data = _collate_data(data, unique_labels, func) + if func == "COUNT": + count_data = {label: len(data) for label, data in collated_data.items()} + aggregate_results[group] = count_data + elif func == "SUM": + sum_data = {label: sum(data) for label, data in collated_data.items()} + aggregate_results[group] = sum_data + elif func == "AVG": + sum_data = {label: sum(data) for label, data in collated_data.items()} + count_data = {label: len(data) for label, data in collated_data.items()} + avg_data = {label: sum_data[label] / count_data[label] if count_data[label] > 0 else 0 for label in sum_data} + aggregate_results[group] = avg_data + elif func == "MAX": + max_data = {label: max(data) for label, data in collated_data.items()} + aggregate_results[group] = max_data + elif func == "MIN": + min_data = {label: min(data) for label, data in collated_data.items()} + aggregate_results[group] = min_data + + aggregate_results = [v for v in aggregate_results.values()] + return aggregate_results + def returns(self, ignore_limit=False): results = self._lookup( self._return_requests + list(self._order_by_attributes), offset_limit=slice(0, None), ) + if len(self._aggregate_functions) > 0: + group_keys = [key for key in results.keys() if not any(key.endswith(func[1]) for func in self._aggregate_functions)] + + aggregated_results = {} + for func, entity in self._aggregate_functions: + aggregated_data = self.aggregate(func, results, entity, group_keys) + func_key = f"{func}({entity})" + aggregated_results[func_key] = aggregated_data + self._return_requests.append(func_key) + results.update(aggregated_results) if self._order_by: results = self._apply_order_by(results) if self._distinct: diff --git a/grandcypher/test_queries.py b/grandcypher/test_queries.py index 9fc0bfb..afea528 100644 --- a/grandcypher/test_queries.py +++ b/grandcypher/test_queries.py @@ -909,8 +909,8 @@ def test_multiple_edges_specific_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice"] assert res["b.name"] == ["Bob"] - assert res["r.years"] == [{0: 3, 1: 5, 2: None}] # should return None when attr is missing - + assert res["r.years"] == [{(0, 'colleague'): 3, (1, 'friend'): 5, (2, 'enemy'): None}] # should return None when attr is missing + def test_edge_directionality(self): host = nx.MultiDiGraph() host.add_node("a", name="Alice", age=25) @@ -926,9 +926,8 @@ def test_edge_directionality(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice", "Bob"] assert res["b.name"] == ["Bob", "Alice"] - assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}, 1: {'mentor'}}] - assert res["r.years"] == [{0: 1}, {0: 2, 1: 4}] - + assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}, (1, 'mentor'): {'mentor'}}] + assert res["r.years"] == [{(0, 'friend'): 1}, {(0, 'colleague'): 2, (1, 'mentor'): 4}] def test_query_with_missing_edge_attribute(self): host = nx.MultiDiGraph() @@ -947,7 +946,7 @@ def test_query_with_missing_edge_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice", "Bob"] assert res["b.name"] == ["Charlie", "Charlie"] - assert res["r.duration"] == [{0: None}, {0: 10, 1: None}] # should return None when attr is missing + assert res["r.duration"] == [{(0, 'colleague'): None}, {(0, 'colleague'): 10, (1, 'mentor'): None}] qry = """ MATCH (a)-[r:colleague]->(b) @@ -956,7 +955,7 @@ def test_query_with_missing_edge_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice", "Bob"] assert res["b.name"] == ["Charlie", "Charlie"] - assert res["r.years"] == [{0: 10}, {0: None, 1: 2}] + assert res["r.years"] == [{(0, 'colleague'): 10}, {(0, 'colleague'): None, (1, 'mentor'): 2}] qry = """ MATCH (a)-[r]->(b) @@ -965,8 +964,8 @@ def test_query_with_missing_edge_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ['Alice', 'Alice', 'Bob'] assert res["b.name"] == ['Bob', 'Charlie', 'Charlie'] - assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}}, {0: {'colleague'}, 1: {'mentor'}}] - assert res["r.duration"] == [{0: None}, {0: None}, {0: 10, 1: None}] + assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}}, {(0, 'colleague'): {'colleague'}, (1, 'mentor'): {'mentor'}}] + assert res["r.duration"] == [{(0, 'friend'): None}, {(0, 'colleague'): None}, {(0, 'colleague'): 10, (1, 'mentor'): None}] def test_multigraph_single_edge_where(self): host = nx.MultiDiGraph() @@ -986,9 +985,9 @@ def test_multigraph_single_edge_where(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice", "Bob"] assert res["b.name"] == ["Bob", "Alice"] - assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}, 1: {'mentor'}}] - assert res["r.years"] == [{0: 1}, {0: 2, 1: 4}] - assert res["r.friendly"] == [{0: 'very'}, {0: None, 1: None}] + assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}, (1, 'mentor'): {'mentor'}}] + assert res["r.years"] == [{(0, 'friend'): 1}, {(0, 'colleague'): 2, (1, 'mentor'): 4}] + assert res["r.friendly"] == [{(0, 'friend'): 'very'}, {(0, 'colleague'): None, (1, 'mentor'): None}] def test_multigraph_where_node_attribute(self): host = nx.MultiDiGraph() @@ -1008,9 +1007,136 @@ def test_multigraph_where_node_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice"] assert res["b.name"] == ["Bob"] - assert res["r.__labels__"] == [{0: {'friend'}}] - assert res["r.years"] == [{0: 1}] - assert res["r.friendly"] == [{0: 'very'}] + assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}] + assert res["r.years"] == [{(0, 'friend'): 1}] + assert res["r.friendly"] == [{(0, 'friend'): 'very'}] + + def test_multigraph_multiple_same_edge_labels(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"friends"}, years=9) + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, r.amount + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ["Alice", "Bob"] + assert res["m.name"] == ["Bob", "Alice"] + # the second "paid" edge between Bob -> Alice has no "amount" attribute, so it should be None + assert res["r.amount"] == [{(0, 'paid'): 12, (1, 'friends'): None, (2, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}] + + def test_multigraph_aggregation_function_sum(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"friends"}, years=9) + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, SUM(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res['SUM(r.amount)'] == [{'friends': 0, 'paid': 52}, {'paid': 6}] + + def test_multigraph_aggregation_function_avg(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") + host.add_edge("b", "a", __labels__={"paid"}, amount=6, message="Thanks") + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, AVG(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["AVG(r.amount)"] == [{'paid': 26}, {'paid': 6}] + + def test_multigraph_aggregation_function_min(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + host.add_edge("a", "b", __labels__={"paid"}, value=4) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, MIN(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["MIN(r.amount)"] == [{'paid': 12}, {'paid': 6}] + + def test_multigraph_aggregation_function_max(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Christine") + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + host.add_edge("a", "c", __labels__={"owes"}, amount=39) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, MAX(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["MAX(r.amount)"] == [{'paid': 40}, {'paid': 6}] + + qry = """ + MATCH (n)-[r:owes]->(m) + RETURN n.name, m.name, MAX(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["MAX(r.amount)"] == [{'owes': 39}] + + def test_multigraph_aggregation_function_count(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Christine") + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + host.add_edge("a", "c", __labels__={"owes"}, amount=39) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, COUNT(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["COUNT(r.amount)"] == [{'paid': 2}, {'paid': 1}] + + def test_multigraph_multiple_aggregation_functions(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Christine") + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + host.add_edge("a", "c", __labels__={"owes"}, amount=39) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, COUNT(r.amount), SUM(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["COUNT(r.amount)"] == [{'paid': 2}, {'paid': 1}] + assert res["SUM(r.amount)"] == [{'paid': 52}, {'paid': 6}] class TestVariableLengthRelationship: