Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable progressive proxy via flag #512

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions jupyter_server_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,27 @@ def cats_only(response, path):
""",
).tag(config=True)

progressive = Union(
[Bool(), Callable()],
default_value=None,
allow_none=True,
help="""
Makes the proxy progressive, meaning it won't buffer any requests from the server.
Useful for applications streaming their data, where the buffering of requests can lead
to a lagging, e.g. in video streams.

Must be either None (default), a bool, or a function. Setting it to a boolean will enable/disable
progressive requests for all requests. Setting to None, jupyter-server-proxy will only enable progressive
for somespecial types, like videos, images and binary data. A function must be taking the "Accept" header of
the request from the client as input and returning a bool, whether this request should be made progressive.

Note: `progressive` and `rewrite_response` are mutually exclusive on the same request. When rewrite_response
is given and progressive is None, the proxying will never be progressive. If progressive is a function,
rewrite_response will only be called on requests where it returns False. Progressive takes precedence over
rewrite_response when both are given!
""",
).tag(config=True)

update_last_activity = Bool(
True, help="Will cause the proxy to report activity back to jupyter server."
).tag(config=True)
Expand Down Expand Up @@ -304,6 +325,7 @@ def __init__(self, *args, **kwargs):
self.unix_socket = sp.unix_socket
self.mappath = sp.mappath
self.rewrite_response = sp.rewrite_response
self.progressive = sp.progressive
self.update_last_activity = sp.update_last_activity

def get_request_headers_override(self):
Expand Down
53 changes: 47 additions & 6 deletions jupyter_server_proxy/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from traitlets.traitlets import HasTraits

from .unixsock import UnixResolver
from .utils import call_with_asked_args
from .utils import call_with_asked_args, mime_types_match
from .websocket import WebSocketHandlerMixin, pingable_ws_connect


Expand Down Expand Up @@ -95,6 +95,15 @@ def get(self, *args):
self.redirect(urlunparse(dest))


COMMON_BINARY_MIME_TYPES = [
"image/*",
"audio/*",
"video/*",
"application/*",
"text/event-stream",
]


class ProxyHandler(WebSocketHandlerMixin, JupyterHandler):
"""
A tornado request handler that proxies HTTP and websockets from
Expand All @@ -117,10 +126,41 @@ def __init__(self, *args, **kwargs):
"rewrite_response",
tuple(),
)
self.progressive = kwargs.pop("progressive", None)
self._requested_subprotocols = None
self.update_last_activity = kwargs.pop("update_last_activity", True)
super().__init__(*args, **kwargs)

@property
def progressive(self):
accept_header = self.request.headers.get("Accept")

if self._progressive is not None:
if callable(self._progressive):
return self._progressive(accept_header)
else:
return self._progressive

# Progressive and RewritableResponse are mutually exclusive
if self.rewrite_response:
return False

if accept_header is None:
return False

# If the client can accept multiple types, we will not make the request progressive
if "," in accept_header:
return False

return any(
mime_types_match(pattern, accept_header)
for pattern in COMMON_BINARY_MIME_TYPES
)

@progressive.setter
def progressive(self, value):
self._progressive = value

# Support/use jupyter_server config arguments allow_origin and allow_origin_pat
# to enable cross origin requests propagated by e.g. inverting proxies.

Expand Down Expand Up @@ -376,16 +416,16 @@ async def proxy(self, host, port, proxied_path):
)
else:
client = httpclient.AsyncHTTPClient(force_instance=True)
# check if the request is stream request
accept_header = self.request.headers.get("Accept")
if accept_header == "text/event-stream":

if self.progressive:
return await self._proxy_progressive(host, port, proxied_path, body, client)
else:
return await self._proxy_buffered(host, port, proxied_path, body, client)

async def _proxy_progressive(self, host, port, proxied_path, body, client):
# Proxy in progressive flush mode, whenever chunks are received. Potentially slower but get results quicker for voila
# Set up handlers so we can progressively flush result
self.log.debug(f"Request to '{proxied_path}' will be proxied progressive")

headers_raw = []

Expand Down Expand Up @@ -466,9 +506,10 @@ def streaming_callback(chunk):
self.write(response.body)

async def _proxy_buffered(self, host, port, proxied_path, body, client):
req = self._build_proxy_request(host, port, proxied_path, body)
self.log.debug(f"Request to '{proxied_path}' will be proxied buffered")

self.log.debug(f"Proxying request to {req.url}")
req = self._build_proxy_request(host, port, proxied_path, body)
self.log.debug(f"Proxy request URL: {req.url}")

try:
# Here, "response" is a tornado.httpclient.HTTPResponse object.
Expand Down
17 changes: 17 additions & 0 deletions jupyter_server_proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,20 @@ def call_with_asked_args(callback, args):
)
)
return callback(*asked_arg_values)


def mime_types_match(pattern: str, value: str) -> bool:
"""
Compare a MIME type pattern, possibly with wildcards, and a value
"""
value = value.split(";")[0] # Remove optional details
if pattern == value:
return True

type, subtype = value.split("/")
pattern = pattern.split("/")

if pattern[0] == "*" or (pattern[0] == type and pattern[1] == "*"):
return True

return False
25 changes: 25 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,28 @@ def _test_func(a, b):
return c

assert utils.call_with_asked_args(_test_func, {"a": 5, "b": 4, "c": 8}) == 20


def test_mime_types_match():
# Exact match
assert utils.mime_types_match("text/plain", "text/plain")
assert not utils.mime_types_match("text/plain", "text/html")

# With optional parameters
assert utils.mime_types_match("text/plain", "text/plain;charset=UTF-8")
assert not utils.mime_types_match("text/plain", "text/html;charset=UTF-8")

# With a single widcard
assert utils.mime_types_match("*", "text/plain")
assert utils.mime_types_match("*", "text/plain;charset=UTF-8")

# With both components wildcard
assert utils.mime_types_match("*/*", "text/plain")
assert utils.mime_types_match("*/*", "text/plain;charset=UTF-8")

# With a subtype wildcard
assert utils.mime_types_match("text/*", "text/plain")
assert not utils.mime_types_match("image/*", "text/plain")

assert utils.mime_types_match("text/*", "text/plain;charset=UTF-8")
assert not utils.mime_types_match("image/*", "text/plain;charset=UTF-8")