Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/databricks/labs/mcp/servers/unity_catalog/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ class CliSettings(BaseSettings):
vector_search_num_results: int = Field(
default=5,
description="Number of results to return from vector search queries",
validation_alias=AliasChoices("vn", "vector_search_num_results", "vector_num_results"),
validation_alias=AliasChoices(
"vn", "vector_search_num_results", "vector_num_results"
),
)

def get_catalog_name(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ class QueryInput(BaseModel):


class VectorSearchTool(BaseTool):
def __init__(self, endpoint_name: str, index_name: str, tool_name: str, columns: list[str], num_results: int = 5):
def __init__(
self,
endpoint_name: str,
index_name: str,
tool_name: str,
columns: list[str],
num_results: int = 5,
):
self.endpoint_name = endpoint_name
self.index_name = index_name
self.tool_name = tool_name
Expand Down Expand Up @@ -46,17 +53,20 @@ def execute(self, **kwargs):
return [TextContent(type="text", text=json.dumps(docs, indent=2))]


def get_table_columns(workspace_client: WorkspaceClient, full_table_name: str) -> list[str]:
def get_table_columns(
workspace_client: WorkspaceClient, full_table_name: str
) -> list[str]:
table_info = workspace_client.tables.get(full_table_name)
return [
col.name
for col in table_info.columns
if col.name != CONTENT_VECTOR_COLUMN_NAME
col.name for col in table_info.columns if col.name != CONTENT_VECTOR_COLUMN_NAME
]


def _list_vector_search_tools(
workspace_client: WorkspaceClient, catalog_name: str, schema_name: str, vector_search_num_results: int
workspace_client: WorkspaceClient,
catalog_name: str,
schema_name: str,
vector_search_num_results: int,
) -> list[VectorSearchTool]:
tools = []
for table in workspace_client.tables.list(
Expand All @@ -71,12 +81,18 @@ def _list_vector_search_tools(

columns = get_table_columns(workspace_client, index_name)

tools.append(VectorSearchTool(endpoint, index_name, tool_name, columns, vector_search_num_results))
tools.append(
VectorSearchTool(
endpoint, index_name, tool_name, columns, vector_search_num_results
)
)

return tools


def list_vector_search_tools(settings: CliSettings) -> list[VectorSearchTool]:
workspace_client = WorkspaceClient()
catalog_name, schema_name = settings.schema_full_name.split(".")
return _list_vector_search_tools(workspace_client, catalog_name, schema_name, settings.vector_search_num_results)
return _list_vector_search_tools(
workspace_client, catalog_name, schema_name, settings.vector_search_num_results
)
15 changes: 11 additions & 4 deletions tests/test_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def __init__(self, name):
self.name = name

class DummyTableInfo:
columns = [DummyColumn("col1"), DummyColumn("col2"), DummyColumn("__db_content_vector")]
columns = [
DummyColumn("col1"),
DummyColumn("col2"),
DummyColumn("__db_content_vector"),
]

return DummyTableInfo()

Expand All @@ -41,6 +45,7 @@ def __init__(self):

class DummySettings:
schema_full_name = "cat.sch"
vector_search_num_results = 5


@mock.patch(
Expand All @@ -59,14 +64,16 @@ def test_list_vector_search_tools_filters_and_returns_expected():

def test_internal_list_vector_search_tools_direct():
client = DummyWorkspaceClient()
tools = _list_vector_search_tools(client, "cat", "sch")
tools = _list_vector_search_tools(client, "cat", "sch", vector_search_num_results=5)
assert len(tools) == 1
assert isinstance(tools[0], VectorSearchTool)
assert tools[0].index_name == "cat.sch.tbl1"
assert tools[0].columns == ["col1", "col2"]


@mock.patch("databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchClient")
@mock.patch(
"databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchClient"
)
def test_vector_search_tool_execute(MockVectorSearchClient):
mock_index = mock.Mock()
mock_index.similarity_search.return_value = {
Expand All @@ -87,4 +94,4 @@ def test_vector_search_tool_execute(MockVectorSearchClient):

assert isinstance(result, list)
assert result[0].text.strip().startswith("[") # It should be JSON string
assert "score" in result[0].text
assert "score" in result[0].text