Skip to content

Commit

Permalink
Add support for inspect.iscoroutinefunction() in Coroutine provider
Browse files Browse the repository at this point in the history
  • Loading branch information
ZipFile committed Nov 5, 2024
1 parent abf2a25 commit ba3f0dd
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/dependency_injector/providers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,22 @@ except ImportError:
# Python 2.7
import __builtin__ as builtins

try:
from inspect import _is_coroutine_marker
except ImportError:
_is_coroutine_marker = True

try:
import asyncio
except ImportError:
asyncio = None
_is_coroutine_marker = None
_is_coroutine = None
else:
if sys.version_info >= (3, 5, 3):
import asyncio.coroutines
_is_coroutine_marker = asyncio.coroutines._is_coroutine
_is_coroutine = asyncio.coroutines._is_coroutine
else:
_is_coroutine_marker = True
_is_coroutine = True

try:
import ConfigParser as iniconfigparser
Expand Down Expand Up @@ -1434,7 +1439,8 @@ cdef class Coroutine(Callable):
some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4)
"""

_is_coroutine = _is_coroutine_marker
_is_coroutine_marker = _is_coroutine_marker # Python >=3.12
_is_coroutine = _is_coroutine # Python <3.16

def set_provides(self, provides):
"""Set provider provides."""
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/providers/coroutines/test_coroutine_py35.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Coroutine provider tests."""
import sys

from dependency_injector import providers, errors
from pytest import mark, raises
Expand Down Expand Up @@ -208,3 +209,17 @@ def test_repr():
"<dependency_injector.providers."
"Coroutine({0}) at {1}>".format(repr(example), hex(id(provider)))
)


@mark.skipif(sys.version_info > (3, 15), reason="requires Python<3.16")
def test_asyncio_iscoroutinefunction() -> None:
from asyncio.coroutines import iscoroutinefunction

assert iscoroutinefunction(providers.Coroutine(example))


@mark.skipif(sys.version_info < (3, 12), reason="requires Python>=3.12")
def test_inspect_iscoroutinefunction() -> None:
from inspect import iscoroutinefunction

assert iscoroutinefunction(providers.Coroutine(example))

0 comments on commit ba3f0dd

Please sign in to comment.