diff --git a/src/databricks/labs/mcp/servers/unity_catalog/cli.py b/src/databricks/labs/mcp/servers/unity_catalog/cli.py index 96a8c83..2d0b067 100644 --- a/src/databricks/labs/mcp/servers/unity_catalog/cli.py +++ b/src/databricks/labs/mcp/servers/unity_catalog/cli.py @@ -35,7 +35,9 @@ class CliSettings(BaseSettings): vector_search_num_results: int = Field( default=5, description="Number of results to return from vector search queries", - validation_alias=AliasChoices("vn", "vector_search_num_results", "vector_num_results"), + validation_alias=AliasChoices( + "vn", "vector_search_num_results", "vector_num_results" + ), ) def get_catalog_name(self): diff --git a/src/databricks/labs/mcp/servers/unity_catalog/tools/vector_search.py b/src/databricks/labs/mcp/servers/unity_catalog/tools/vector_search.py index 6b9062f..c3c7ba7 100644 --- a/src/databricks/labs/mcp/servers/unity_catalog/tools/vector_search.py +++ b/src/databricks/labs/mcp/servers/unity_catalog/tools/vector_search.py @@ -15,7 +15,14 @@ class QueryInput(BaseModel): class VectorSearchTool(BaseTool): - def __init__(self, endpoint_name: str, index_name: str, tool_name: str, columns: list[str], num_results: int = 5): + def __init__( + self, + endpoint_name: str, + index_name: str, + tool_name: str, + columns: list[str], + num_results: int = 5, + ): self.endpoint_name = endpoint_name self.index_name = index_name self.tool_name = tool_name @@ -46,17 +53,20 @@ def execute(self, **kwargs): return [TextContent(type="text", text=json.dumps(docs, indent=2))] -def get_table_columns(workspace_client: WorkspaceClient, full_table_name: str) -> list[str]: +def get_table_columns( + workspace_client: WorkspaceClient, full_table_name: str +) -> list[str]: table_info = workspace_client.tables.get(full_table_name) return [ - col.name - for col in table_info.columns - if col.name != CONTENT_VECTOR_COLUMN_NAME + col.name for col in table_info.columns if col.name != CONTENT_VECTOR_COLUMN_NAME ] def _list_vector_search_tools( - workspace_client: WorkspaceClient, catalog_name: str, schema_name: str, vector_search_num_results: int + workspace_client: WorkspaceClient, + catalog_name: str, + schema_name: str, + vector_search_num_results: int, ) -> list[VectorSearchTool]: tools = [] for table in workspace_client.tables.list( @@ -71,7 +81,11 @@ def _list_vector_search_tools( columns = get_table_columns(workspace_client, index_name) - tools.append(VectorSearchTool(endpoint, index_name, tool_name, columns, vector_search_num_results)) + tools.append( + VectorSearchTool( + endpoint, index_name, tool_name, columns, vector_search_num_results + ) + ) return tools @@ -79,4 +93,6 @@ def _list_vector_search_tools( def list_vector_search_tools(settings: CliSettings) -> list[VectorSearchTool]: workspace_client = WorkspaceClient() catalog_name, schema_name = settings.schema_full_name.split(".") - return _list_vector_search_tools(workspace_client, catalog_name, schema_name, settings.vector_search_num_results) \ No newline at end of file + return _list_vector_search_tools( + workspace_client, catalog_name, schema_name, settings.vector_search_num_results + ) diff --git a/tests/test_vector_search.py b/tests/test_vector_search.py index 7467f0a..0428d9b 100644 --- a/tests/test_vector_search.py +++ b/tests/test_vector_search.py @@ -29,7 +29,11 @@ def __init__(self, name): self.name = name class DummyTableInfo: - columns = [DummyColumn("col1"), DummyColumn("col2"), DummyColumn("__db_content_vector")] + columns = [ + DummyColumn("col1"), + DummyColumn("col2"), + DummyColumn("__db_content_vector"), + ] return DummyTableInfo() @@ -41,6 +45,7 @@ def __init__(self): class DummySettings: schema_full_name = "cat.sch" + vector_search_num_results = 5 @mock.patch( @@ -59,14 +64,16 @@ def test_list_vector_search_tools_filters_and_returns_expected(): def test_internal_list_vector_search_tools_direct(): client = DummyWorkspaceClient() - tools = _list_vector_search_tools(client, "cat", "sch") + tools = _list_vector_search_tools(client, "cat", "sch", vector_search_num_results=5) assert len(tools) == 1 assert isinstance(tools[0], VectorSearchTool) assert tools[0].index_name == "cat.sch.tbl1" assert tools[0].columns == ["col1", "col2"] -@mock.patch("databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchClient") +@mock.patch( + "databricks.labs.mcp.servers.unity_catalog.tools.vector_search.VectorSearchClient" +) def test_vector_search_tool_execute(MockVectorSearchClient): mock_index = mock.Mock() mock_index.similarity_search.return_value = { @@ -87,4 +94,4 @@ def test_vector_search_tool_execute(MockVectorSearchClient): assert isinstance(result, list) assert result[0].text.strip().startswith("[") # It should be JSON string - assert "score" in result[0].text \ No newline at end of file + assert "score" in result[0].text