From 3c71b9ae64789d0b14d76423c8f73a3e7caf9676 Mon Sep 17 00:00:00 2001 From: Jonathan de Bruin Date: Mon, 23 Oct 2023 17:26:16 +0200 Subject: [PATCH] Refactor pagination and add support for offset pagination (#31) --- README.md | 31 +++-- pyalex/api.py | 265 ++++++++++++++++++++++--------------------- tests/test_paging.py | 94 +++++++++++++++ tests/test_pyalex.py | 81 ------------- 4 files changed, 251 insertions(+), 220 deletions(-) create mode 100644 tests/test_paging.py diff --git a/README.md b/README.md index 9291981..fb4dee1 100644 --- a/README.md +++ b/README.md @@ -260,20 +260,14 @@ Works().filter(institutions={"country_code": "fr|gb"}).get() #### Paging -OpenAlex offers two methods for paging: [basic paging](https://docs.openalex.org/how-to-use-the-api/get-lists-of-entities/paging#basic-paging) and [cursor paging](https://docs.openalex.org/how-to-use-the-api/get-lists-of-entities/paging#cursor-paging). Both methods are supported by -PyAlex, although cursor paging seems to be easier to implement and less error-prone. +OpenAlex offers two methods for paging: [basic (offset) paging](https://docs.openalex.org/how-to-use-the-api/get-lists-of-entities/paging#basic-paging) and [cursor paging](https://docs.openalex.org/how-to-use-the-api/get-lists-of-entities/paging#cursor-paging). Both methods are supported by PyAlex. -##### Basic paging - -See limitations of [basic paging](https://docs.openalex.org/how-to-use-the-api/get-lists-of-entities/paging#basic-paging) in the OpenAlex documentation. -It's relatively easy to implement basic paging with PyAlex, however it is -advised to use the built-in pager based on cursor paging. - -##### Cursor paging +##### Cursor paging (default) -Use `paginate()` for paging results. Each page is a list of records, with a -maximum of `per_page` (default 25). By default, `paginate`s argument `n_max` -is set to 10000. Use `None` to retrieve all results. +Use the method `paginate()` to paginate results. Each returned page is a list +of records, with a maximum of `per_page` (default 25). By default, +`paginate`s argument `n_max` is set to 10000. Use `None` to retrieve all +results. ```python from pyalex import Authors @@ -296,6 +290,19 @@ for record in chain(*query.paginate(per_page=200)): print(record["id"]) ``` +##### Basic paging + +See limitations of [basic paging](https://docs.openalex.org/how-to-use-the-api/get-lists-of-entities/paging#basic-paging) in the OpenAlex documentation. + +```python +from pyalex import Authors + +pager = Authors().search_filter(display_name="einstein").paginate(method="page", per_page=200) + +for page in pager: + print(len(page)) +``` + ### Get N-grams diff --git a/pyalex/api.py b/pyalex/api.py index 5b2649d..053e110 100644 --- a/pyalex/api.py +++ b/pyalex/api.py @@ -76,13 +76,7 @@ def _params_merge(params, add_params): params[k] = add_params[k] -def invert_abstract(inv_index): - if inv_index is not None: - l_inv = [(w, p) for w, pos in inv_index.items() for p in pos] - return " ".join(map(lambda x: x[0], sorted(l_inv, key=lambda x: x[1]))) - - -def get_requests_session(): +def _get_requests_session(): # create an Requests Session with automatic retry: requests_session = requests.Session() retries = Retry( @@ -98,109 +92,75 @@ def get_requests_session(): return requests_session -class QueryError(ValueError): - pass - - -class OpenAlexEntity(dict): - pass - - -class Work(OpenAlexEntity): - """OpenAlex work object.""" - - def __getitem__(self, key): - if key == "abstract": - return invert_abstract(self["abstract_inverted_index"]) - - return super().__getitem__(key) - - def ngrams(self, return_meta=False): - openalex_id = self["id"].split("/")[-1] - - res = get_requests_session().get( - f"{config.openalex_url}/works/{openalex_id}/ngrams", - headers={"User-Agent": "pyalex/" + __version__, "email": config.email}, - ) - res.raise_for_status() - results = res.json() - - # return result and metadata - if return_meta: - return results["ngrams"], results["meta"] - else: - return results["ngrams"] - - -class Author(OpenAlexEntity): - pass - - -class Source(OpenAlexEntity): - pass - - -class Institution(OpenAlexEntity): - pass - - -class Concept(OpenAlexEntity): - pass +def invert_abstract(inv_index): + if inv_index is not None: + l_inv = [(w, p) for w, pos in inv_index.items() for p in pos] + return " ".join(map(lambda x: x[0], sorted(l_inv, key=lambda x: x[1]))) -class Publisher(OpenAlexEntity): +class QueryError(ValueError): pass -class Funder(OpenAlexEntity): +class OpenAlexEntity(dict): pass -# deprecated - - -def Venue(*args, **kwargs): - # warn about deprecation - warnings.warn( - "Venue is deprecated. Use Sources instead.", - DeprecationWarning, - stacklevel=2, - ) - - return Source(*args, **kwargs) - +class Paginator: + VALUE_CURSOR_START = "*" + VALUE_NUMBER_START = 1 -class CursorPaginator: - def __init__(self, alex_class=None, per_page=None, cursor="*", n_max=None): - self.alex_class = alex_class + def __init__( + self, endpoint_class, method="cursor", value=None, per_page=None, n_max=None + ): + self.method = method + self.endpoint_class = endpoint_class + self.value = value self.per_page = per_page - self.cursor = cursor self.n_max = n_max + self._next_value = value + def __iter__(self): self.n = 0 return self - def __next__(self): + def _is_max(self): if self.n_max and self.n >= self.n_max: + return True + return False + + def __next__(self): + if self._next_value is None or self._is_max(): raise StopIteration - r, m = self.alex_class.get( - return_meta=True, per_page=self.per_page, cursor=self.cursor + if self.method == "cursor": + pagination_params = {"cursor": self._next_value} + elif self.method == "page": + pagination_params = {"page": self._next_value} + else: + raise ValueError() + + results, meta = self.endpoint_class.get( + return_meta=True, per_page=self.per_page, **pagination_params ) - if m["next_cursor"] is None: - raise StopIteration + if self.method == "cursor": + self._next_value = meta["next_cursor"] + + if self.method == "page": + if len(results) > 0: + self._next_value = meta["page"] + 1 + else: + self._next_value = None - self.n = self.n + len(r) - self.cursor = m["next_cursor"] + self.n = self.n + len(results) - return r + return results class BaseOpenAlex: - """Base class for OpenAlex objects.""" def __init__(self, params=None): @@ -230,17 +190,9 @@ def __getitem__(self, record_id): if isinstance(record_id, list): return self._get_multi_items(record_id) - url = self._full_collection_name() + "/" + record_id - params = {"api_key": config.api_key} if config.api_key else {} - res = get_requests_session().get( - url, - headers={"User-Agent": "pyalex/" + __version__, "email": config.email}, - params=params, + return self._get_from_url( + self._full_collection_name() + "/" + record_id, return_meta=False ) - res.raise_for_status() - res_json = res.json() - - return self.resource_class(res_json) @property def url(self): @@ -269,38 +221,35 @@ def count(self): return m["count"] - def get(self, return_meta=False, page=None, per_page=None, cursor=None): - if per_page is not None and (per_page < 1 or per_page > 200): - raise ValueError("per_page should be a number between 1 and 200.") - - self._add_params("per-page", per_page) - self._add_params("page", page) - self._add_params("cursor", cursor) - + def _get_from_url(self, url, return_meta=False): params = {"api_key": config.api_key} if config.api_key else {} - res = get_requests_session().get( - self.url, + + res = _get_requests_session().get( + url, headers={"User-Agent": "pyalex/" + __version__, "email": config.email}, params=params, ) # handle query errors if res.status_code == 403: - res_json = res.json() if ( - isinstance(res_json["error"], str) - and "query parameters" in res_json["error"] + isinstance(res.json()["error"], str) + and "query parameters" in res.json()["error"] ): - raise QueryError(res_json["message"]) - res.raise_for_status() + raise QueryError(res.json()["message"]) + res.raise_for_status() res_json = res.json() # group-by or results page - if "group-by" in self.params: + if self.params and "group-by" in self.params: results = res_json["group_by"] - else: + elif "results" in res_json: results = [self.resource_class(ent) for ent in res_json["results"]] + elif "id" in res_json: + results = self.resource_class(res_json) + else: + raise ValueError("Unknown response format") # return result and metadata if return_meta: @@ -308,24 +257,27 @@ def get(self, return_meta=False, page=None, per_page=None, cursor=None): else: return results - def paginate(self, per_page=None, cursor="*", n_max=10000): - """Used for paging results of large responses using cursor paging. + def get(self, return_meta=False, page=None, per_page=None, cursor=None): + if per_page is not None and (per_page < 1 or per_page > 200): + raise ValueError("per_page should be a number between 1 and 200.") + + self._add_params("per-page", per_page) + self._add_params("page", page) + self._add_params("cursor", cursor) - OpenAlex offers two methods for paging: basic paging and cursor paging. - Both methods are supported by PyAlex, although cursor paging seems to be - easier to implement and less error-prone. + return self._get_from_url(self.url, return_meta=return_meta) - Args: - per_page (_type_, optional): Entries per page to return. Defaults to None. - cursor (str, optional): _description_. Defaults to "*". - n_max (int, optional): Number of max results (not pages) to return. - Defaults to 10000. + def paginate(self, method="cursor", page=1, per_page=None, cursor="*", n_max=10000): + if method == "cursor": + value = cursor + elif method == "page": + value = page + else: + raise ValueError("Method should be 'cursor' or 'page'") - Returns: - CursorPaginator: Iterator to use for returning and processing each page - result in sequence. - """ - return CursorPaginator(self, per_page=per_page, cursor=cursor, n_max=n_max) + return Paginator( + self, method=method, value=value, per_page=per_page, n_max=n_max + ) def random(self): return self.__getitem__("random") @@ -370,38 +322,97 @@ def select(self, s): return self +# The API + + +class Work(OpenAlexEntity): + def __getitem__(self, key): + if key == "abstract": + return invert_abstract(self["abstract_inverted_index"]) + + return super().__getitem__(key) + + def ngrams(self, return_meta=False): + openalex_id = self["id"].split("/")[-1] + + res = _get_requests_session().get( + f"{config.openalex_url}/works/{openalex_id}/ngrams", + headers={"User-Agent": "pyalex/" + __version__, "email": config.email}, + ) + res.raise_for_status() + results = res.json() + + # return result and metadata + if return_meta: + return results["ngrams"], results["meta"] + else: + return results["ngrams"] + + class Works(BaseOpenAlex): resource_class = Work +class Author(OpenAlexEntity): + pass + + class Authors(BaseOpenAlex): resource_class = Author +class Source(OpenAlexEntity): + pass + + class Sources(BaseOpenAlex): resource_class = Source +class Institution(OpenAlexEntity): + pass + + class Institutions(BaseOpenAlex): resource_class = Institution +class Concept(OpenAlexEntity): + pass + + class Concepts(BaseOpenAlex): resource_class = Concept +class Publisher(OpenAlexEntity): + pass + + class Publishers(BaseOpenAlex): resource_class = Publisher +class Funder(OpenAlexEntity): + pass + + class Funders(BaseOpenAlex): resource_class = Funder -# deprecated +def Venue(*args, **kwargs): # deprecated + # warn about deprecation + warnings.warn( + "Venue is deprecated. Use Sources instead.", + DeprecationWarning, + stacklevel=2, + ) + + return Source(*args, **kwargs) -def Venues(*args, **kwargs): +def Venues(*args, **kwargs): # deprecated # warn about deprecation warnings.warn( "Venues is deprecated. Use Sources instead.", diff --git a/tests/test_paging.py b/tests/test_paging.py new file mode 100644 index 0000000..2e52c4b --- /dev/null +++ b/tests/test_paging.py @@ -0,0 +1,94 @@ +from pyalex import Authors +from pyalex.api import Paginator + + +def test_cursor(): + query = Authors().search_filter(display_name="einstein") + + # store the results + results = [] + + next_cursor = "*" + + # loop till next_cursor is None + while next_cursor is not None: + # get the results + r, m = query.get(return_meta=True, per_page=200, cursor=next_cursor) + + # results + results.extend(r) + + # set the next cursor + next_cursor = m["next_cursor"] + + assert len(results) > 200 + + +def test_page(): + query = Authors().search_filter(display_name="einstein") + + # set the page + page = 1 + + # store the results + results = [] + + # loop till page is None + while page is not None: + # get the results + r, m = query.get(return_meta=True, per_page=200, page=page) + + # results + results.extend(r) + page = None if len(r) == 0 else m["page"] + 1 + + assert len(results) > 200 + + +def test_paginate_counts(): + _, m = Authors().search_filter(display_name="einstein").get(return_meta=True) + + p_default = Authors().search_filter(display_name="einstein").paginate(per_page=200) + n_p_default = sum(len(page) for page in p_default) + + p_cursor = ( + Authors() + .search_filter(display_name="einstein") + .paginate(method="cursor", per_page=200) + ) + n_p_cursor = sum(len(page) for page in p_cursor) + + p_page = ( + Authors() + .search_filter(display_name="einstein") + .paginate(method="page", per_page=200) + ) + n_p_page = sum(len(page) for page in p_page) + + assert m["count"] == n_p_page >= n_p_default == n_p_cursor + + +def test_paginate_instance(): + p_default = Authors().search_filter(display_name="einstein").paginate(per_page=200) + assert isinstance(p_default, Paginator) + assert p_default.method == "cursor" + + +def test_paginate_cursor_n_max(): + p = ( + Authors() + .search_filter(display_name="einstein") + .paginate(per_page=200, n_max=400) + ) + + assert sum(len(page) for page in p) == 400 + + +def test_cursor_paging_n_max_none(): + p = ( + Authors() + .search_filter(display_name="einstein") + .paginate(per_page=200, n_max=None) + ) + + sum(len(page) for page in p) diff --git a/tests/test_pyalex.py b/tests/test_pyalex.py index 1002ba6..f3d5867 100644 --- a/tests/test_pyalex.py +++ b/tests/test_pyalex.py @@ -181,87 +181,6 @@ def test_search_filter(): assert r["meta"]["count"] == m["count"] -def test_cursor_by_hand(): - # example query - query = Authors().search_filter(display_name="einstein") - - # store the results - results = [] - - next_cursor = "*" - - # loop till next_cursor is None - while next_cursor is not None: - # get the results - r, m = query.get(return_meta=True, per_page=200, cursor=next_cursor) - - # results - results.extend(r) - - # set the next cursor - next_cursor = m["next_cursor"] - - assert len(results) > 200 - - -def test_basic_paging(): - # example query - query = Authors().search_filter(display_name="einstein") - - # set the page - page = 1 - - # store the results - results = [] - - # loop till page is None - while page is not None: - # get the results - r, m = query.get(return_meta=True, per_page=200, page=page) - - # results - results.extend(r) - page = None if len(r) == 0 else m["page"] + 1 - - assert len(results) > 200 - - -def test_cursor_paging(): - # example query - pager = Authors().search_filter(display_name="einstein").paginate(per_page=200) - - for page in pager: - assert len(page) >= 1 and len(page) <= 200 - - -def test_cursor_paging_n_max(): - # example query - pager = ( - Authors() - .search_filter(display_name="einstein") - .paginate(per_page=200, n_max=400) - ) - - n = 0 - for page in pager: - n = n + len(page) - - assert n == 400 - - -def test_cursor_paging_n_max_none(): - # example query - pager = ( - Authors() - .search_filter(display_name="einstein") - .paginate(per_page=200, n_max=None) - ) - - n = 0 - for page in pager: - n = n + len(page) - - def test_referenced_works(): # the work to extract the referenced works of w = Works()["W2741809807"]