Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions kernel_gateway/notebook_http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,16 @@ def create_request_handlers(self):
handlers.append((path, tornado.web.StaticFileHandler, {"path": self.static_path}))

# Discover the notebook endpoints and their implementations
endpoints = self.api_parser.endpoints(self.parent.kernel_manager.seed_source)
response_sources = self.api_parser.endpoint_responses(
self.parent.kernel_manager.seed_source
)
endpoints = self.api_parser.endpoints()
response_sources = self.api_parser.endpoint_responses()
if len(endpoints) == 0:
raise RuntimeError(
"No endpoints were discovered. Check your notebook to make sure your cells are annotated correctly."
)

# Cycle through the (endpoint_path, source) tuples and register their handlers
for endpoint_path, verb_source_map in endpoints:
description = verb_source_map.pop("__description__", "")
parameterized_path = parameterize_path(endpoint_path)
parameterized_path = url_path_join("/", self.parent.base_url, parameterized_path)
self.log.info(
Expand All @@ -134,6 +133,7 @@ def create_request_handlers(self):
"kernel_pool": self.kernel_pool,
"kernel_name": self.parent.kernel_manager.seed_kernelspec,
"kernel_language": self.kernel_language or "",
"description": description,
}
handlers.append((parameterized_path, NotebookAPIHandler, handler_args))

Expand Down
49 changes: 36 additions & 13 deletions kernel_gateway/notebook_http/cell/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import sys

import markdown
from traitlets import Unicode
from traitlets.config.configurable import LoggingConfigurable

Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self, comment_prefix, notebook_cells=None, **kwargs):
self.kernelspec_api_response_indicator = re.compile(
self.api_response_indicator.format(comment_prefix)
)
self.notebook_cells = notebook_cells or []

def is_api_cell(self, cell_source):
"""Gets if the cell source is annotated as an API endpoint.
Expand Down Expand Up @@ -151,13 +153,29 @@ def get_path_content(self, cell_source):
"""
return {"responses": {200: {"description": "Success"}}}

def endpoints(self, source_cells, sort_func=first_path_param_index):
def render_markdown_cell(self, cell_source):
    """Renders a markdown cell as HTML.

    The leading API-annotation line (e.g. ``# GET /path``), when present,
    is stripped first so the annotation does not appear in the rendered
    output.

    Parameters
    ----------
    cell_source
        Source from a notebook cell

    Returns
    -------
    str
        HTML representation of the markdown cell
    """
    lines = cell_source.split("\n")
    if lines and self.is_api_cell(lines[0]):
        lines = lines[1:]
    return markdown.markdown("\n".join(lines))

def endpoints(self, sort_func=first_path_param_index):
    """Gets the list of all annotated endpoint HTTP paths and verbs.

    Parameters
    ----------
    sort_func
        Function by which to sort the endpoint list

    Returns
    -------
    list
        List of ``(endpoint_path, verb_map)`` tuples, sorted by
        ``sort_func`` in reverse order. Each ``verb_map`` maps an HTTP
        verb to a dict with ``source`` and ``cell_type`` keys.
    """
    endpoints = {}
    for cell in self.notebook_cells:
        if not self.is_api_cell(cell.source):
            continue
        matched = self.kernelspec_api_indicator.match(cell.source)
        uri = matched.group(2).strip()
        verb = matched.group(1)

        entry = endpoints.setdefault(uri, {}).setdefault(
            verb, {"source": "", "cell_type": "code"}
        )
        if cell.cell_type == "markdown":
            # Markdown cells are served directly as HTML, so the rendered
            # form replaces any prior source for this verb+path.
            entry["source"] = self.render_markdown_cell(cell.source)
            entry["cell_type"] = "markdown"
        else:
            # Concatenate code cells that share a verb+path so multi-cell
            # endpoints keep all of their source (a plain assignment here
            # would silently drop all but the last cell).
            entry["source"] += cell.source + "\n"
            entry["cell_type"] = "code"

    sorted_keys = sorted(endpoints, key=sort_func, reverse=True)
    return [(key, endpoints[key]) for key in sorted_keys]

def endpoint_responses(self, source_cells, sort_func=first_path_param_index):
def endpoint_responses(self, sort_func=first_path_param_index):
    """Gets a map of all annotated ResponseInfo HTTP paths and verbs.

    Parameters
    ----------
    sort_func
        Accepted for interface symmetry with ``endpoints``; the result is
        a dict, so no ordering is applied here.

    Returns
    -------
    dict
        Maps endpoint path to a dict of HTTP verb to the concatenated
        source of every ResponseInfo cell annotated with that verb+path.
    """
    responses = {}
    for cell in self.notebook_cells:
        if not self.is_api_response_cell(cell.source):
            continue
        match = self.kernelspec_api_response_indicator.match(cell.source)
        verb, uri = match.group(1), match.group(2).strip()
        responses.setdefault(uri, {}).setdefault(verb, "")
        responses[uri][verb] += cell.source + "\n"
    return responses

def get_default_api_spec(self):
Expand Down
53 changes: 50 additions & 3 deletions kernel_gateway/notebook_http/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,39 @@ class NotebookAPIHandler(
are identified, parsed, and associated with HTTP verbs and paths.
"""

def initialize(self, sources, response_sources, kernel_pool, kernel_name, kernel_language=""):
def initialize(self, sources, response_sources, kernel_pool, kernel_name, kernel_language="", description=""):
    """Stores per-endpoint configuration on the handler instance.

    Parameters
    ----------
    sources : dict
        Maps HTTP verbs to notebook cell source for this endpoint
    response_sources : dict
        Maps HTTP verbs to ResponseInfo cell source for this endpoint
    kernel_pool
        Pool from which kernels are taken to execute cell source
    kernel_name : str
        Name of the kernel spec to run the source against
    kernel_language : str, optional
        Language of the kernel; defaults to empty string
    description : str, optional
        Human-readable description of the endpoint; defaults to empty string
    """
    self.sources = sources
    self.response_sources = response_sources
    self.kernel_pool = kernel_pool
    self.kernel_name = kernel_name
    self.kernel_language = kernel_language
    self.description = description

def _accumulate_display(self, results):
    """Converts accumulated "display" message payloads into HTML fragments.

    PNG payloads become inline base64 ``<img>`` tags. Otherwise the HTML
    representation is preferred, falling back to plain text only when no
    HTML form is present. See
    https://ipython.org/ipython-doc/3/development/messaging.html#display-data
    for details on the display protocol.

    Parameters
    ----------
    results: list
        A list of display-data dicts (MIME type -> payload).

    Returns
    -------
    list
        One HTML/text fragment per result that had a renderable payload.
    """
    import html  # stdlib; local import so no module-level imports change

    out = []
    for result in results:
        if "image/png" in result:
            # Escape the alt text: the text/plain repr may contain quotes
            # or angle brackets that would break out of the attribute.
            out.append(
                '<img alt="%s" src="data:image/png;base64,%s" />'
                % (html.escape(result.get("text/plain", "")), result["image/png"])
            )
        elif "text/html" in result:
            # Prefer the richer HTML form. Previously a result carrying
            # both text/html and text/plain was emitted twice, duplicating
            # content in the response body.
            out.append(result["text/html"])
        elif "text/plain" in result:
            out.append(result["text/plain"])
    return out

def finish_future(self, future, result_accumulator):
"""Resolves the promise to respond to a HTTP request handled by a
Expand All @@ -84,6 +111,10 @@ def finish_future(self, future, result_accumulator):
"""
if result_accumulator["error"]:
future.set_exception(CodeExecutionError(result_accumulator["error"]))
elif len(result_accumulator["display"]) > 0:
future.set_result(
"\n".join(self._accumulate_display(result_accumulator["display"]))
)
elif len(result_accumulator["stream"]) > 0:
future.set_result("".join(result_accumulator["stream"]))
elif result_accumulator["result"]:
Expand Down Expand Up @@ -123,6 +154,9 @@ def on_recv(self, result_accumulator, future, parent_header, msg):
# Store the execute result
elif msg["header"]["msg_type"] == "execute_result":
result_accumulator["result"] = msg["content"]["data"]
# Accumulate display data
elif msg['header']['msg_type'] == 'display_data':
result_accumulator['display'].append(msg['content']['data'])
# Accumulate the stream messages
elif msg["header"]["msg_type"] == "stream":
# Only take stream output if it is on stdout or if the kernel
Expand Down Expand Up @@ -162,7 +196,12 @@ def execute_code(self, kernel_client, kernel_id, source_code):
If the kernel returns any error
"""
future = Future()
result_accumulator = {"stream": [], "error": None, "result": None}
result_accumulator = {
"display": [],
"stream": [],
"error": None,
"result": None,
}
parent_header = kernel_client.execute(source_code)
on_recv_func = partial(self.on_recv, result_accumulator, future, parent_header)
self.kernel_pool.on_recv(kernel_id, on_recv_func)
Expand All @@ -187,7 +226,15 @@ async def _handle_request(self):
self.set_status(200)

# Get the source to execute in response to this request
source_code = self.sources[self.request.method]
source_info = self.sources[self.request.method]
source_code = source_info['source']
cell_type = source_info['cell_type']

if cell_type == 'markdown':
self.set_header("Content-Type", "text/html")
self.write(source_code)
return

# Build the request dictionary
request = json.dumps(
{
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ dependencies = [
"requests>=2.31",
"tornado>=6.4",
"traitlets>=5.14.1",
"markdown>=3.3.4",
"ipykernel>=6.29.5",
]

[project.scripts]
Expand Down
128 changes: 20 additions & 108 deletions tests/notebook_http/cell/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,115 +1,27 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
"""Tests for notebook cell parsing."""

import sys

import unittest
from unittest.mock import Mock
from kernel_gateway.notebook_http.cell.parser import APICellParser

class TestAPICellParser(unittest.TestCase):
def test_markdown_cells_are_rendered(self):
mock_markdown_cell = Mock()
mock_markdown_cell.cell_type = "markdown"
mock_markdown_cell.source = "# GET /try\n\nThis is a markdown cell."

class TestAPICellParser:
"""Unit tests the APICellParser class."""

def test_is_api_cell(self):
"""Parser should correctly identify annotated API cells."""
parser = APICellParser(comment_prefix="#")
assert parser.is_api_cell("# GET /yes"), "API cell was not detected"
assert parser.is_api_cell("no") is False, "API cell was not detected"

def test_endpoint_sort_default_strategy(self):
"""Parser should sort duplicate endpoint paths."""
source_cells = [
"# POST /:foo",
"# POST /hello/:foo",
"# GET /hello/:foo",
"# PUT /hello/world",
]
parser = APICellParser(comment_prefix="#")
endpoints = parser.endpoints(source_cells)
expected_values = ["/hello/world", "/hello/:foo", "/:foo"]

for index in range(len(expected_values)):
endpoint, _ = endpoints[index]
assert expected_values[index] == endpoint, "Endpoint was not found in expected order"

def test_endpoint_sort_custom_strategy(self):
"""Parser should sort duplicate endpoint paths using a custom sort
strategy.
"""
source_cells = ["# POST /1", "# POST /+", "# GET /a"]

def custom_sort_fun(endpoint):
_ = sys.maxsize
if endpoint.find("1") >= 0:
return 0
elif endpoint.find("a") >= 0:
return 1
else:
return 2

parser = APICellParser(comment_prefix="#")
endpoints = parser.endpoints(source_cells, custom_sort_fun)
expected_values = ["/+", "/a", "/1"]

for index in range(len(expected_values)):
endpoint, _ = endpoints[index]
assert expected_values[index] == endpoint, "Endpoint was not found in expected order"

def test_get_cell_endpoint_and_verb(self):
"""Parser should extract API endpoint and verb from cell annotations."""
parser = APICellParser(comment_prefix="#")
endpoint, verb = parser.get_cell_endpoint_and_verb("# GET /foo")
assert endpoint, "/foo" == "Endpoint was not extracted correctly"
assert verb, "GET" == "Endpoint was not extracted correctly"
endpoint, verb = parser.get_cell_endpoint_and_verb("# POST /bar/quo")
assert endpoint, "/bar/quo" == "Endpoint was not extracted correctly"
assert verb, "POST" == "Endpoint was not extracted correctly"

endpoint, verb = parser.get_cell_endpoint_and_verb("some regular code")
assert endpoint is None, "Endpoint was not extracted correctly"
assert verb is None, "Endpoint was not extracted correctly"

def test_endpoint_concatenation(self):
"""Parser should concatenate multiple cells with the same verb+path."""
source_cells = [
"# POST /foo/:bar",
"# POST /foo/:bar",
"# POST /foo",
"ignored",
"# GET /foo/:bar",
]
parser = APICellParser(comment_prefix="#")
endpoints = parser.endpoints(source_cells)
assert len(endpoints) == 2
# for ease of testing
endpoints = dict(endpoints)
assert len(endpoints["/foo"]) == 1
assert len(endpoints["/foo/:bar"]) == 2
assert endpoints["/foo"]["POST"] == "# POST /foo\n"
assert endpoints["/foo/:bar"]["POST"] == "# POST /foo/:bar\n# POST /foo/:bar\n"
assert endpoints["/foo/:bar"]["GET"] == "# GET /foo/:bar\n"
mock_code_cell = Mock()
mock_code_cell.cell_type = "code"
mock_code_cell.source = "# GET /hello\nprint('hello')"

def test_endpoint_response_concatenation(self):
"""Parser should concatenate multiple response cells with the same
verb+path.
"""
source_cells = [
"# ResponseInfo POST /foo/:bar",
"# ResponseInfo POST /foo/:bar",
"# ResponseInfo POST /foo",
"ignored",
"# ResponseInfo GET /foo/:bar",
notebook_cells = [
mock_markdown_cell,
mock_code_cell
]
parser = APICellParser(comment_prefix="#")
endpoints = parser.endpoint_responses(source_cells)
assert len(endpoints) == 2
# for ease of testing
endpoints = dict(endpoints)
assert len(endpoints["/foo"]) == 1
assert len(endpoints["/foo/:bar"]) == 2
assert endpoints["/foo"]["POST"] == "# ResponseInfo POST /foo\n"
assert (
endpoints["/foo/:bar"]["POST"]
== "# ResponseInfo POST /foo/:bar\n# ResponseInfo POST /foo/:bar\n"
)
assert endpoints["/foo/:bar"]["GET"] == "# ResponseInfo GET /foo/:bar\n"
parser = APICellParser(comment_prefix="#", notebook_cells=notebook_cells)
endpoints = dict(parser.endpoints())
self.assertEqual(len(endpoints), 2)
self.assertIn("/hello", endpoints)
self.assertIn("/try", endpoints)
self.assertEqual(endpoints["/try"]["GET"]['cell_type'], "markdown")
self.assertEqual(endpoints["/try"]["GET"]['source'], "<p>This is a markdown cell.</p>")