Skip to content

Commit

Permalink
Query param to enable AND filtering for facets
Browse files Browse the repository at this point in the history
  • Loading branch information
pudo committed Jan 13, 2025
1 parent 3f61b61 commit 23f4407
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 22 deletions.
14 changes: 14 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ def test_search_filter_countries_remove():
assert len(results) == 0, results


def test_search_filter_countries_operator():
res = client.get("/search/default?q=vladimir putin&countries=ke&countries=ru")
assert res.status_code == 200, res
results = res.json()["results"]
assert len(results) > 0, results

res = client.get(
"/search/default?q=vladimir putin&filter_op=and&countries=ke&countries=ru"
)
assert res.status_code == 200, res
results = res.json()["results"]
assert len(results) == 0, results


def test_search_facet_datasets_default():
res = client.get("/search/default")
assert res.status_code == 200, res
Expand Down
1 change: 1 addition & 0 deletions yente/routers/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ async def match(
fuzzy: bool = Query(
settings.MATCH_FUZZY,
title="Use slow matching for candidate generation, does not affect scores",
deprecated=True,
),
changed_since: Optional[str] = Query(
None,
Expand Down
32 changes: 24 additions & 8 deletions yente/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from yente.provider import SearchProvider, get_provider
from yente.search.queries import parse_sorts, text_query
from yente.search.queries import facet_aggregations
from yente.search.queries import FilterDict
from yente.search.queries import FilterDict, Operator
from yente.search.search import get_entity, search_entities
from yente.search.search import result_entities, result_facets, result_total
from yente.search.nested import serialize_entity
Expand All @@ -28,7 +28,6 @@ class Facet(str, enum.Enum):
DATASETS = "datasets"
SCHEMA = "schema"
COUNTRIES = "countries"
NAMES = "names"
IDENTIFIERS = "identifiers"
TOPICS = "topics"
GENDERS = "genders"
Expand All @@ -55,10 +54,13 @@ async def search(
settings.BASE_SCHEMA, title="Types of entities that can match the search"
),
include_dataset: List[str] = Query(
[], title="Only include the given datasets in results"
[],
title="Restrict the search scope to datasets (that are in the given scope) to search entities within.",
description="Limit the results to entities that are part of at least one of the given datasets.",
),
exclude_dataset: List[str] = Query(
[], title="Remove the given datasets from results"
[],
title="Remove specific datasets (that are in the given scope) from the search scope.",
),
exclude_schema: List[str] = Query(
[], title="Remove the given types of entities from results"
Expand All @@ -72,21 +74,34 @@ async def search(
topics: List[str] = Query(
[], title="Filter by entity topics (e.g. sanction, role.pep)"
),
datasets: List[str] = Query([], title="Use `include_dataset` instead"),
datasets: List[str] = Query(
[],
title="Filter by dataset names, for faceting use (respects operator choice).",
),
limit: int = Query(
settings.DEFAULT_PAGE, title="Number of results to return", le=settings.MAX_PAGE
),
offset: int = Query(
0, title="Start at result with given offset", le=settings.MAX_OFFSET
),
sort: List[str] = Query([], title="Sorting criteria"),
target: Optional[bool] = Query(None, title="Include only targeted entities"),
target: Optional[bool] = Query(
None,
title="Include only targeted entities",
description="Please specify a list of topics of concern, instead.",
deprecated=True,
),
fuzzy: bool = Query(False, title="Allow fuzzy query syntax"),
simple: bool = Query(False, title="Use simple syntax for user-facing query boxes"),
facets: List[Facet] = Query(
DEFAULT_FACETS,
title="Facet counts to include in response.",
),
filter_op: Operator = Query(
"OR",
title="Define behaviour of multiple filters on one field",
description="Logic to use when combining multiple filters on the same field (topics, countries, datasets). Please specify AND for new integrations (to override a legacy default) and when building a faceted user interface.",
),
provider: SearchProvider = Depends(get_provider),
) -> SearchResponse:
"""Search endpoint for matching entities based on a simple piece of text, e.g.
Expand All @@ -105,8 +120,8 @@ async def search(
filters: FilterDict = {
"countries": countries,
"topics": topics,
"datasets": datasets,
}
include_dataset.extend(datasets)
if target is not None:
filters["target"] = target
query = text_query(
Expand All @@ -117,9 +132,10 @@ async def search(
fuzzy=fuzzy,
simple=simple,
include_dataset=include_dataset,
exclude_schema=exclude_schema,
exclude_dataset=exclude_dataset,
exclude_schema=exclude_schema,
changed_since=changed_since,
filter_op=filter_op,
)
aggregations = facet_aggregations([f.value for f in facets])
resp = await search_entities(
Expand Down
2 changes: 1 addition & 1 deletion yente/routers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


PATH_DATASET = Path(
description="Data source or collection name to be queries",
description="Data source or collection name to scope the query to.",
examples=["default"],
)
QUERY_PREFIX = Query("", min_length=0, description="Search prefix")
Expand Down
41 changes: 28 additions & 13 deletions yente/search/queries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
from pprint import pprint # noqa
from typing import Any, Dict, Generator, List, Tuple, Union, Optional
from typing import Any, Dict, Generator, List, Set, Tuple, Union, Optional
from followthemoney.schema import Schema
from followthemoney.proxy import EntityProxy
from followthemoney.types import registry
Expand All @@ -16,31 +17,38 @@
Clause = Dict[str, Any]


class Operator(str, enum.Enum):
AND = "AND"
OR = "OR"


def filter_query(
shoulds: List[Clause],
dataset: Optional[Dataset] = None,
scope_dataset: Optional[Dataset] = None,
schema: Optional[Schema] = None,
filters: FilterDict = {},
include_dataset: List[str] = [],
exclude_schema: List[str] = [],
exclude_dataset: List[str] = [],
changed_since: Optional[str] = None,
filter_op: Operator = Operator.AND,
) -> Clause:
filterqs: List[Clause] = []
must_not: List[Clause] = []
datasets: List[str] = include_dataset
if not len(datasets) and dataset is not None:
datasets = dataset.dataset_names
for exclude_ds in exclude_dataset:

datasets: Set[str] = set(scope_dataset.dataset_names)
if len(include_dataset):
datasets = datasets.intersection(include_dataset)
if len(exclude_dataset):
# This is logically a bit more consistent, but doesn't describe the use
# case of wanting to screen all the entities from datasets X, Y but not Z:
# must_not.append({"term": {"datasets": exclude_ds}})
if exclude_ds in datasets:
datasets.remove(exclude_ds)
datasets = datasets.difference(exclude_dataset)
if len(datasets):
filterqs.append({"terms": {"datasets": datasets}})
filterqs.append({"terms": {"datasets": list(datasets)}})
else:
filterqs.append({"match_none": {}})

if schema is not None:
schemata = schema.matchable_schemata
if not schema.matchable:
Expand All @@ -53,7 +61,12 @@ def filter_query(
continue
values = [v for v in values if len(v)]
if len(values):
filterqs.append({"terms": {field: values}})
if filter_op == Operator.OR:
filterqs.append({"terms": {field: values}})
continue
elif filter_op == Operator.AND:
for v in values:
filterqs.append({"term": {field: v}})
if changed_since is not None:
filterqs.append({"range": {"last_change": {"gt": changed_since}}})

Expand Down Expand Up @@ -121,7 +134,7 @@ def entity_query(
return filter_query(
shoulds,
filters=filters,
dataset=dataset,
scope_dataset=dataset,
schema=entity.schema,
include_dataset=include_dataset,
exclude_schema=exclude_schema,
Expand All @@ -141,6 +154,7 @@ def text_query(
exclude_schema: List[str] = [],
exclude_dataset: List[str] = [],
changed_since: Optional[str] = None,
filter_op: Operator = Operator.AND,
) -> Clause:
if not len(query.strip()):
should: Clause = {"match_all": {}}
Expand Down Expand Up @@ -168,13 +182,14 @@ def text_query(
# log.info("Query", should=should)
return filter_query(
[should],
dataset=dataset,
scope_dataset=dataset,
schema=schema,
filters=filters,
include_dataset=include_dataset,
exclude_schema=exclude_schema,
exclude_dataset=exclude_dataset,
changed_since=changed_since,
filter_op=filter_op,
)


Expand All @@ -186,7 +201,7 @@ def prefix_query(
should: Clause = {"match_none": {}}
else:
should = {"match_phrase_prefix": {"names": {"query": prefix, "slop": 2}}}
return filter_query([should], dataset=dataset)
return filter_query([should], scope_dataset=dataset)


def facet_aggregations(fields: List[str] = []) -> Clause:
Expand Down

0 comments on commit 23f4407

Please sign in to comment.