114 lines
3.3 KiB
Python
Executable file
114 lines
3.3 KiB
Python
Executable file
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextvars
|
|
from typing import Any, Union, TYPE_CHECKING, Generic
|
|
from typing_extensions import TypeAlias, TypeVar, Self
|
|
import sys
|
|
|
|
if sys.version_info >= (3, 9):
|
|
from collections.abc import Generator, Coroutine, Awaitable
|
|
else:
|
|
from typing import Generator, Coroutine, Awaitable
|
|
|
|
# General-purpose type parameters: one invariant, one covariant.
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

# Parameters describing a generator/coroutine: what it yields, what can be
# sent into it, and what it returns.  The send/return parameters come in two
# flavours: with a ``None`` default ...
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)

# ... and the "_nd" twins, declared without any default.
_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)

# Kept as a string so the ``|`` union is never evaluated at runtime.
_TaskYieldType: TypeAlias = "asyncio.Future[object] | None"
|
|
|
|
# What ``Task`` accepts as its ``coro`` argument, per interpreter version.
# NOTE(review): the three shapes presumably mirror the stubs of
# ``asyncio.Task`` for each version -- confirm against typeshed.
if sys.version_info < (3, 9):
    # Written as a string: the ``|`` union inside it is never evaluated.
    _TaskCompatibleCoro: TypeAlias = "Generator[_TaskYieldType, None, _T_co] | Awaitable[_T_co]"
elif sys.version_info < (3, 12):
    # 3.9 - 3.11: generator-based coroutines are accepted next to native ones.
    _TaskCompatibleCoro: TypeAlias = Union[
        Generator[_TaskYieldType, None, _T_co], Coroutine[Any, Any, _T_co]
    ]
else:
    # 3.12+: only ``Coroutine`` objects.
    _TaskCompatibleCoro: TypeAlias = Coroutine[Any, Any, _T_co]
|
|
|
|
|
|
class _Interceptor(
    Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
    Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
):
    """Proxy that drives a coroutine or generator inside a fixed context.

    ``send``, ``throw`` and ``close`` are forwarded to the wrapped object
    through ``context.run`` so that every step executes with *context*
    active.  Any attribute not defined here is looked up on the wrapped
    object.
    """

    def __init__(
        self,
        coro: (
            Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
            | Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
        ),
        context: contextvars.Context,
    ):
        """Wrap *coro* so that it is always resumed inside *context*."""
        # ``Task.get_coro`` reads this attribute through its mangled name
        # (``_Interceptor__coro``); do not rename it.
        self.__coro = coro
        self.__context = context

    def send(self, v: _SendT_contra_nd) -> _YieldT_co:
        """Resume the wrapped object with *v*, running inside the context."""
        return self.__context.run(self.__coro.send, v)

    def throw(self, *exc_info) -> _YieldT_co:
        """Raise an exception inside the wrapped object, within the context."""
        return self.__context.run(self.__coro.throw, *exc_info)

    def __getattr__(self, name):
        """Delegate lookups of unknown attributes to the wrapped object."""
        # ``__getattr__`` only runs when normal lookup failed.  If one of our
        # own private attributes is missing (instance created without
        # ``__init__``, e.g. by copy/pickle machinery), reading ``self.__coro``
        # below would re-enter this method forever; fail cleanly instead.
        if name.startswith("_Interceptor__"):
            raise AttributeError(name)
        return getattr(self.__coro, name)

    def __await__(self) -> Self:
        """Return ``self``; the proxy is its own await-iterator."""
        return self

    def close(self) -> None:
        """Close the wrapped object, running its ``close()`` in the context.

        The ``close`` inherited from the ABCs is built on
        ``throw(GeneratorExit)``, which raises ``RuntimeError`` for a
        coroutine that has already finished.  Calling the wrapped object's own
        ``close()`` keeps the native behaviour: closing a finished coroutine
        is a silent no-op.
        """
        self.__context.run(self.__coro.close)
|
|
|
|
|
|
# ``_Task`` is the base class of ``Task`` below and is subscripted there
# (``_Task[_T_co]``), so whatever is bound here must support ``[...]``.
if TYPE_CHECKING:
    # Static analysis only: a plain generic subclass of ``asyncio.Task``.
    class _Task(asyncio.Task[_T_co]):
        pass

elif sys.version_info < (3, 9):
    # Older interpreters: mix in ``Generic`` to make the class subscriptable.
    class _Task(asyncio.Task, Generic[_T_co]):
        pass

else:
    # 3.9+: ``asyncio.Task`` is used as-is.
    _Task = asyncio.Task
|
|
|
|
|
|
class Task(_Task[_T_co]):
    """Task that counts cancellation requests and accepts a ``context``.

    When a *context* is supplied, the coroutine is wrapped in an
    ``_Interceptor`` before being handed to the base class; ``get_coro``
    undoes that wrapping for callers.
    """

    def __init__(
        self, coro: _TaskCompatibleCoro[_T_co], *args, context=None, **kwargs
    ) -> None:
        """Create the task; *context* is consumed here, not passed on."""
        # Initialise the counter before the base constructor runs.
        self._num_cancels_requested = 0
        runnable = coro if context is None else _Interceptor(coro, context)
        super().__init__(runnable, *args, **kwargs)  # type: ignore

    def cancel(self, *args: Any, **kwargs: Any) -> bool:
        """Record the request (unless already done) and cancel as usual.

        Returns whatever the base class ``cancel`` returns.
        """
        still_pending = not self.done()
        if still_pending:
            self._num_cancels_requested += 1
        return super().cancel(*args, **kwargs)

    def cancelling(self) -> int:
        """Return the number of cancellation requests currently recorded."""
        return self._num_cancels_requested

    def uncancel(self) -> int:
        """Drop one recorded request (never below zero); return what is left."""
        remaining = self._num_cancels_requested
        if remaining > 0:
            remaining -= 1
            self._num_cancels_requested = remaining
        return remaining

    def get_coro(self) -> _TaskCompatibleCoro[_T_co] | None:
        """Return the coroutine given to the constructor, unwrapping any
        ``_Interceptor`` placed around it."""
        wrapped = super().get_coro()
        if not isinstance(wrapped, _Interceptor):
            return wrapped
        # Reach into the proxy's name-mangled private attribute.
        return wrapped._Interceptor__coro  # type: ignore
|
|
|
|
|
|
def task_factory(
    loop: asyncio.AbstractEventLoop, coro: _TaskCompatibleCoro[_T_co], **kwargs: Any
) -> Task[_T_co]:
    """Create a :class:`Task` for *coro* bound to *loop*.

    Extra keyword arguments are forwarded to the ``Task`` constructor.
    NOTE(review): the ``(loop, coro, **kwargs)`` shape presumably targets
    ``loop.set_task_factory`` -- confirm at the call site.
    """
    task_options = {"loop": loop, **kwargs}
    return Task(coro, **task_options)
|