Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept AsyncIterables being passed to Response #341

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/quart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Generator,
Iterable,
Iterator,
TYPE_CHECKING,
TypeVar,
)

from werkzeug.datastructures import Headers
Expand Down Expand Up @@ -66,12 +67,15 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:
return _wrapper


def run_sync_iterable(iterable: Generator[Any, None, None]) -> AsyncGenerator[Any, None]:
async def _gen_wrapper() -> AsyncGenerator[Any, None]:
T = TypeVar("T")


def run_sync_iterable(iterable: Iterator[T]) -> AsyncIterator[T]:
async def _gen_wrapper() -> AsyncIterator[T]:
# Wrap the generator such that each iteration runs
# in the executor. Then rationalise the raised
# errors so that it ends.
def _inner() -> Any:
def _inner() -> T:
# https://bugs.python.org/issue26221
# StopIteration errors are swallowed by the
# run_in_exector method
Expand Down
25 changes: 9 additions & 16 deletions src/quart/wrappers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from abc import ABC, abstractmethod
from hashlib import md5
from inspect import isasyncgen, isgenerator
from io import BytesIO
from os import PathLike
from types import TracebackType
Expand Down Expand Up @@ -102,27 +101,21 @@ async def __anext__(self) -> bytes:


class IterableBody(ResponseBody):
def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None:
self.iter: AsyncGenerator[bytes, None]
if isasyncgen(iterable):
self.iter = iterable
elif isgenerator(iterable):
self.iter = run_sync_iterable(iterable)
def __init__(self, iterable: AsyncIterable[Any] | Iterable[Any]) -> None:
self.iter: AsyncIterator[Any]
if isinstance(iterable, Iterable):
self.iter = run_sync_iterable(iter(iterable))
else:

async def _aiter() -> AsyncGenerator[bytes, None]:
for data in iterable: # type: ignore
yield data

self.iter = _aiter()
self.iter = iterable.__aiter__() # Can't use aiter() until 3.10

async def __aenter__(self) -> IterableBody:
return self

async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
await self.iter.aclose()
if hasattr(self.iter, "aclose"): # Is a generator?
await self.iter.aclose()

def __aiter__(self) -> AsyncIterator:
def __aiter__(self) -> AsyncIterator[Any]:
return self.iter


Expand Down Expand Up @@ -262,7 +255,7 @@ class Response(SansIOResponse):

def __init__(
self,
response: ResponseBody | AnyStr | Iterable | None = None,
response: ResponseBody | AnyStr | Iterable | AsyncIterable | None = None,
status: int | None = None,
headers: dict | Headers | None = None,
mimetype: str | None = None,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
g,
Quart,
render_template_string,
Response,
ResponseReturnValue,
session,
stream_template_string,
Expand Down Expand Up @@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue:
test_client = app.test_client()
response = await test_client.get("/")
assert (await response.data) == b"42"

@app.get("/2")
async def index2() -> ResponseReturnValue:
return Response(await stream_template_string("{{ config }}", config=43))

test_client = app.test_client()
response = await test_client.get("/2")
assert (await response.data) == b"43"