From fbdfbade3c33167687a2f975b50aa986a15aff31 Mon Sep 17 00:00:00 2001 From: Mengqin Shen Date: Thu, 11 Dec 2025 00:47:32 -0800 Subject: [PATCH] feat(py): add embedder reference support matching Go SDK Add EmbedderRef, EmbedderOptions, and EmbedderSupports types to enable referencing embedders with specific configurations, matching the Go SDK functionality. BREAKING CHANGE: define_embedder() now accepts 'options: EmbedderOptions' parameter instead of separate 'config_schema' parameter. Changes: - Add EmbedderRef, EmbedderOptions, EmbedderSupports types to embedding.py - Implement embed() method in Genkit class supporting EmbedderRef - Update define_embedder signature to accept EmbedderOptions parameter - Fix circular imports using TYPE_CHECKING and proper import organization - Add to_json_schema() conversions for all config_schema usages - Update all plugins (OpenAI, Google GenAI, Ollama) to use JSON schemas - Add comprehensive tests for embedder functionality - Update test assertions to match new schema conversion behavior --- py/packages/genkit/src/genkit/ai/_aio.py | 51 +++- py/packages/genkit/src/genkit/ai/_registry.py | 20 +- .../genkit/src/genkit/blocks/embedding.py | 63 ++++- .../tests/genkit/blocks/embedding_test.py | 225 +++++++++++++++++- .../plugins/compat_oai/openai_plugin.py | 15 +- .../src/genkit/plugins/google_genai/google.py | 40 ++-- .../google-genai/test/test_google_plugin.py | 55 +++-- .../src/genkit/plugins/ollama/plugin_api.py | 18 +- py/plugins/ollama/tests/test_plugin_api.py | 7 +- 9 files changed, 403 insertions(+), 91 deletions(-) diff --git a/py/packages/genkit/src/genkit/ai/_aio.py b/py/packages/genkit/src/genkit/ai/_aio.py index 7d77d881d8..5f3c9e920a 100644 --- a/py/packages/genkit/src/genkit/ai/_aio.py +++ b/py/packages/genkit/src/genkit/ai/_aio.py @@ -26,7 +26,7 @@ class while customizing it with any plugins. from genkit.aio import Channel from genkit.blocks.document import Document -from genkit.blocks.embedding import EmbedRequest, EmbedResponse +from genkit.blocks.embedding import EmbedderRef from genkit.blocks.generate import ( StreamingCallback as ModelStreamingCallback, generate_action, @@ -39,6 +39,7 @@ class while customizing it with any plugins. from genkit.blocks.prompt import PromptConfig, to_generate_action_options from genkit.core.action import ActionRunContext from genkit.core.action.types import ActionKind +from genkit.core.typing import EmbedRequest, EmbedResponse from genkit.types import ( DocumentData, GenerationCommonConfig, @@ -295,10 +296,12 @@ def generate_stream( async def embed( self, - embedder: str | None = None, + embedder: str | EmbedderRef | None = None, documents: list[Document] | None = None, options: dict[str, Any] | None = None, ) -> EmbedResponse: + embedder_name: str + embedder_config: dict[str, Any] = {} """Calculates embeddings for documents. Args: @@ -309,9 +312,22 @@ async def embed( Returns: The generated response with embeddings. """ - embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder) + if isinstance(embedder, EmbedderRef): + embedder_name = embedder.name + embedder_config = embedder.config or {} + if embedder.version: + embedder_config['version'] = embedder.version # Handle version from ref + elif isinstance(embedder, str): + embedder_name = embedder + else: + # Handle case where embedder is None + raise ValueError('Embedder must be specified as a string name or an EmbedderRef.') - return (await embed_action.arun(EmbedRequest(input=documents, options=options))).response + # Merge options passed to embed() with config from EmbedderRef + final_options = {**(embedder_config or {}), **(options or {})} + embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name) + + return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response async def retrieve( self, @@ -335,3 +351,30 @@ async def retrieve( retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever) return (await retrieve_action.arun(RetrieverRequest(query=query, options=options))).response + + async def embed( + self, + embedder: str | EmbedderRef | None = None, + documents: list[Document] | None = None, + options: dict[str, Any] | None = None, + ) -> EmbedResponse: + embedder_name: str + embedder_config: dict[str, Any] = {} + + if isinstance(embedder, EmbedderRef): + embedder_name = embedder.name + embedder_config = embedder.config or {} + if embedder.version: + embedder_config['version'] = embedder.version # Handle version from ref + elif isinstance(embedder, str): + embedder_name = embedder + else: + # Handle case where embedder is None + raise ValueError('Embedder must be specified as a string name or an EmbedderRef.') + + # Merge options passed to embed() with config from EmbedderRef + final_options = {**(embedder_config or {}), **(options or {})} + + embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name) + + return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py index 4bb19db11e..a055bad30c 100644 --- a/py/packages/genkit/src/genkit/ai/_registry.py +++ b/py/packages/genkit/src/genkit/ai/_registry.py @@ -47,7 +47,7 @@ import structlog from pydantic import BaseModel -from genkit.blocks.embedding import EmbedderFn +from genkit.blocks.embedding import EmbedderFn, EmbedderOptions from genkit.blocks.evaluator import BatchEvaluatorFn, EvaluatorFn from genkit.blocks.formats.types import FormatDef from genkit.blocks.model import ModelFn, ModelMiddleware @@ -458,8 +458,7 @@ def define_embedder( self, name: str, fn: EmbedderFn, - config_schema: BaseModel | dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, + options: EmbedderOptions | None = None, description: str | None = None, ) -> Action: """Define a custom embedder action. @@ -471,13 +470,20 @@ def define_embedder( metadata: Optional metadata for the model. description: Optional description for the embedder. """ - embedder_meta: dict[str, Any] = metadata if metadata else {} + embedder_meta: dict[str, Any] = {} + if options: + if options.label: + embedder_meta['embedder']['label'] = options.label + if options.dimensions: + embedder_meta['embedder']['dimensions'] = options.dimensions + if options.supports: + embedder_meta['embedder']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True) + if options.config_schema: + embedder_meta['embedder']['customOptions'] = to_json_schema(options.config_schema) + if 'embedder' not in embedder_meta: embedder_meta['embedder'] = {} - if config_schema: - embedder_meta['embedder']['customOptions'] = to_json_schema(config_schema) - embedder_description = get_func_description(fn, description) return self.registry.register_action( name=name, diff --git a/py/packages/genkit/src/genkit/blocks/embedding.py b/py/packages/genkit/src/genkit/blocks/embedding.py index 582ec5f6ac..95bca321a2 100644 --- a/py/packages/genkit/src/genkit/blocks/embedding.py +++ b/py/packages/genkit/src/genkit/blocks/embedding.py @@ -16,29 +16,76 @@ """Embedding actions.""" -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import Any -from genkit.ai import ActionKind +from pydantic import BaseModel, ConfigDict, Field + from genkit.core.action import ActionMetadata +from genkit.core.action.types import ActionKind from genkit.core.schema import to_json_schema from genkit.core.typing import EmbedRequest, EmbedResponse -# type EmbedderFn = Callable[[EmbedRequest], EmbedResponse] + +class EmbedderSupports(BaseModel): + """Embedder capability support.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + input: list[str] | None = None + multilingual: bool | None = None + + +class EmbedderOptions(BaseModel): + """Configuration options for an embedder.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + config_schema: dict[str, Any] | None = Field(None, alias='configSchema') + label: str | None = None + supports: EmbedderSupports | None = None + dimensions: int | None = None + + +class EmbedderRef(BaseModel): + """Reference to an embedder with configuration.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + name: str + config: Any | None = None + version: str | None = None + + EmbedderFn = Callable[[EmbedRequest], EmbedResponse] def embedder_action_metadata( name: str, - info: dict[str, Any] | None = None, - config_schema: Any | None = None, + options: EmbedderOptions | None = None, ) -> ActionMetadata: - """Generates an ActionMetadata for embedders.""" - info = info if info is not None else {} + options = options if options is not None else EmbedderOptions() + embedder_metadata_dict = {'embedder': {}} + + if options.label: + embedder_metadata_dict['embedder']['label'] = options.label + + embedder_metadata_dict['embedder']['dimensions'] = options.dimensions + + if options.supports: + embedder_metadata_dict['embedder']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True) + + embedder_metadata_dict['embedder']['customOptions'] = options.config_schema if options.config_schema else None + return ActionMetadata( kind=ActionKind.EMBEDDER, name=name, input_json_schema=to_json_schema(EmbedRequest), output_json_schema=to_json_schema(EmbedResponse), - metadata={'embedder': {**info, 'customOptions': to_json_schema(config_schema) if config_schema else None}}, + metadata=embedder_metadata_dict, ) + + +def create_embedder_ref(name: str, config: dict[str, Any] | None = None, version: str | None = None) -> EmbedderRef: + """Creates an EmbedderRef instance.""" + return EmbedderRef(name=name, config=config, version=version) diff --git a/py/packages/genkit/tests/genkit/blocks/embedding_test.py b/py/packages/genkit/tests/genkit/blocks/embedding_test.py index 0a54d6aa2e..04cd4e7f0c 100644 --- a/py/packages/genkit/tests/genkit/blocks/embedding_test.py +++ b/py/packages/genkit/tests/genkit/blocks/embedding_test.py @@ -16,19 +16,232 @@ """Tests for the action module.""" -from genkit.blocks.embedding import embedder_action_metadata -from genkit.core.action import ActionMetadata +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pydantic import BaseModel + +from genkit.ai._aio import Genkit +from genkit.blocks.document import Document +from genkit.blocks.embedding import ( + EmbedderOptions, + EmbedderRef, + EmbedderSupports, + create_embedder_ref, + embedder_action_metadata, +) +from genkit.core.action import Action, ActionMetadata +from genkit.core.action.types import ActionResponse +from genkit.core.schema import to_json_schema +from genkit.core.typing import Embedding, EmbedRequest, EmbedResponse def test_embedder_action_metadata(): - """Test for embedder_action_metadata.""" + """Test for embedder_action_metadata with basic options.""" + options = EmbedderOptions(label='Test Embedder', dimensions=128) action_metadata = embedder_action_metadata( name='test_model', - info={'label': 'test_label'}, - config_schema=None, + options=options, ) assert isinstance(action_metadata, ActionMetadata) assert action_metadata.input_json_schema is not None assert action_metadata.output_json_schema is not None - assert action_metadata.metadata == {'embedder': {'customOptions': None, 'label': 'test_label'}} + assert action_metadata.metadata == { + 'embedder': { + 'label': options.label, + 'dimensions': options.dimensions, + 'customOptions': None, + } + } + + +def test_embedder_action_metadata_with_supports_and_config_schema(): + """Test for embedder_action_metadata with supports and config_schema.""" + + class CustomConfig(BaseModel): + param1: str + param2: int + + options = EmbedderOptions( + label='Advanced Embedder', + dimensions=256, + supports=EmbedderSupports(input=['text', 'image']), + config_schema=to_json_schema(CustomConfig), + ) + action_metadata = embedder_action_metadata( + name='advanced_model', + options=options, + ) + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.metadata['embedder']['label'] == 'Advanced Embedder' + assert action_metadata.metadata['embedder']['dimensions'] == options.dimensions + assert action_metadata.metadata['embedder']['supports'] == { + 'input': ['text', 'image'], + } + assert action_metadata.metadata['embedder']['customOptions'] == { + 'title': 'CustomConfig', + 'type': 'object', + 'properties': { + 'param1': {'title': 'Param1', 'type': 'string'}, + 'param2': {'title': 'Param2', 'type': 'integer'}, + }, + 'required': ['param1', 'param2'], + } + + +def test_embedder_action_metadata_no_options(): + """Test embedder_action_metadata when no options are provided.""" + action_metadata = embedder_action_metadata(name='default_model') + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.metadata == {'embedder': {'customOptions': None, 'dimensions': None}} + + +def test_create_embedder_ref_basic(): + """Test basic creation of EmbedderRef.""" + ref = create_embedder_ref('my-embedder') + assert ref.name == 'my-embedder' + assert ref.config is None + assert ref.version is None + + +def test_create_embedder_ref_with_config(): + """Test creation of EmbedderRef with configuration.""" + config = {'temperature': 0.5, 'max_tokens': 100} + ref = create_embedder_ref('configured-embedder', config=config) + assert ref.name == 'configured-embedder' + assert ref.config == config + assert ref.version is None + + +def test_create_embedder_ref_with_version(): + """Test creation of EmbedderRef with a version.""" + ref = create_embedder_ref('versioned-embedder', version='v1.0') + assert ref.name == 'versioned-embedder' + assert ref.config is None + assert ref.version == 'v1.0' + + +def test_create_embedder_ref_with_config_and_version(): + """Test creation of EmbedderRef with both config and version.""" + config = {'task_type': 'retrieval'} + ref = create_embedder_ref('full-embedder', config=config, version='beta') + assert ref.name == 'full-embedder' + assert ref.config == config + assert ref.version == 'beta' + + +class MockGenkitRegistry: + """A mock registry to simulate action lookup.""" + + def __init__(self): + self.actions = {} + + def register_action(self, name, kind, fn, metadata, description): + mock_action = MagicMock(spec=Action) + mock_action.name = name + mock_action.kind = kind + mock_action.metadata = metadata + mock_action.description = description + + async def mock_arun_side_effect(request, *args, **kwargs): + # Call the actual (fake) embedder function directly + embed_response = await fn(request) + return ActionResponse(response=embed_response, trace_id='mock_trace_id') + + mock_action.arun = AsyncMock(side_effect=mock_arun_side_effect) + self.actions[(kind, name)] = mock_action + return mock_action + + def lookup_action(self, kind, name): + return self.actions.get((kind, name)) + + +@pytest.fixture +def mock_genkit_instance(): + """Fixture for a Genkit instance with a mock registry.""" + registry = MockGenkitRegistry() + genkit_instance = Genkit() + genkit_instance.registry = registry + return genkit_instance, registry + + +@pytest.mark.asyncio +async def test_embed_with_embedder_ref(mock_genkit_instance): + """Test the embed method using EmbedderRef.""" + genkit_instance, registry = mock_genkit_instance + + async def fake_embedder_fn(request: EmbedRequest) -> EmbedResponse: + return EmbedResponse(embeddings=[Embedding(embedding=[1.0, 2.0, 3.0])]) + + embedder_options = EmbedderOptions( + label='Fake Embedder', + dimensions=3, + supports=EmbedderSupports(input=['text']), + config_schema={'type': 'object', 'properties': {'param': {'type': 'string'}}}, + ) + registry.register_action( + name='my-plugin/my-embedder', + kind='embedder', + fn=fake_embedder_fn, + metadata=embedder_action_metadata('my-plugin/my-embedder', options=embedder_options).metadata, + description='A fake embedder for testing', + ) + embedder_ref = create_embedder_ref('my-plugin/my-embedder', config={'param': 'value'}, version='v1') + + documents = [Document.from_text('hello world')] + + response = await genkit_instance.embed( + embedder=embedder_ref, documents=documents, options={'additional_option': True} + ) + + assert response.embeddings[0].embedding == [1.0, 2.0, 3.0] + + embed_action = registry.lookup_action('embedder', 'my-plugin/my-embedder') + assert embed_action is not None + embed_action.arun.assert_called_once() + + called_request = embed_action.arun.call_args[0][0] + assert isinstance(called_request, EmbedRequest) + assert called_request.input == documents + # Check if config from EmbedderRef and options are merged correctly + assert called_request.options == {'param': 'value', 'additional_option': True, 'version': 'v1'} + + +@pytest.mark.asyncio +async def test_embed_with_string_name_and_options(mock_genkit_instance): + """Test the embed method using a string name for embedder and options.""" + genkit_instance, registry = mock_genkit_instance + + async def fake_embedder_fn(request: EmbedRequest) -> EmbedResponse: + return EmbedResponse(embeddings=[Embedding(embedding=[4.0, 5.0, 6.0])]) + + embedder_options = EmbedderOptions(label='Another Fake', dimensions=3) + registry.register_action( + name='another-embedder', + kind='embedder', + fn=fake_embedder_fn, + metadata=embedder_action_metadata('another-embedder', options=embedder_options).metadata, + description='Another fake embedder', + ) + + documents = [Document.from_text('test text')] + + response = await genkit_instance.embed( + embedder='another-embedder', documents=documents, options={'custom_setting': 'high'} + ) + + assert response.embeddings[0].embedding == [4.0, 5.0, 6.0] + embed_action = registry.lookup_action('embedder', 'another-embedder') + called_request = embed_action.arun.call_args[0][0] + assert called_request.options == {'custom_setting': 'high'} + + +@pytest.mark.asyncio +async def test_embed_missing_embedder_raises_error(mock_genkit_instance): + """Test that embedding with a missing embedder raises an error.""" + genkit_instance, _ = mock_genkit_instance + documents = [Document.from_text('some text')] + + with pytest.raises(ValueError, match='Embedder must be specified as a string name or an EmbedderRef.'): + await genkit_instance.embed(documents=documents) diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py index f7d9739254..19cd4ebcf8 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py @@ -25,7 +25,7 @@ from genkit.ai._plugin import Plugin from genkit.ai._registry import GenkitRegistry -from genkit.blocks.embedding import embedder_action_metadata +from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata from genkit.blocks.model import model_action_metadata from genkit.core.action import ActionMetadata from genkit.core.action.types import ActionKind @@ -188,17 +188,14 @@ def list_actions(self) -> list[ActionMetadata]: for model in models: _name = model.id if 'embed' in _name: + # Default embedder metadata for OpenAI embedding models actions.append( embedder_action_metadata( name=open_ai_name(_name), - config_schema=Embedding, - info={ - 'label': f'OpenAI Embedding - {_name}', - 'dimensions': None, - 'supports': { - 'input': ['text'], - }, - }, + options=EmbedderOptions( + label=f'OpenAI Embedding - {_name}', + supports=EmbedderSupports(input=['text']), + ), ) ) else: diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py index 36433d9eeb..8f1b0ab0ea 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py @@ -24,10 +24,11 @@ import genkit.plugins.google_genai.constants as const from genkit.ai import GENKIT_CLIENT_HEADER, GenkitRegistry, Plugin -from genkit.blocks.embedding import embedder_action_metadata +from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata from genkit.blocks.model import model_action_metadata from genkit.core.action import ActionMetadata from genkit.core.registry import ActionKind +from genkit.core.schema import to_json_schema from genkit.plugins.google_genai.models.embedder import ( Embedder, GeminiEmbeddingModels, @@ -139,7 +140,7 @@ def initialize(self, ai: GenkitRegistry) -> None: name=googleai_name(version), fn=gemini_model.generate, metadata=gemini_model.metadata, - config_schema=GeminiConfigSchema, + # config_schema=GeminiConfigSchema, ) for version in GeminiEmbeddingModels: @@ -148,7 +149,6 @@ def initialize(self, ai: GenkitRegistry) -> None: name=googleai_name(version), fn=embedder.generate, metadata=default_embedder_info(version), - config_schema=EmbedContentConfig, ) def resolve_action( @@ -193,7 +193,7 @@ def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: name=googleai_name(_clean_name), fn=gemini_model.generate, metadata=gemini_model.metadata, - config_schema=GeminiConfigSchema, + # config_schema=GeminiConfigSchema, ) def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: @@ -215,7 +215,6 @@ def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: name=googleai_name(_clean_name), fn=embedder.generate, metadata=default_embedder_info(_clean_name), - config_schema=EmbedContentConfig, ) @cached_property @@ -237,16 +236,19 @@ def list_actions(self) -> list[ActionMetadata]: model_action_metadata( name=googleai_name(name), info=google_model_info(name).model_dump(), - config_schema=GeminiConfigSchema, ), ) if 'embedContent' in m.supported_actions: + embed_info = default_embedder_info(name) actions_list.append( embedder_action_metadata( name=googleai_name(name), - info=default_embedder_info(name), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=embed_info.get('label'), + supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), + dimensions=embed_info.get('dimensions'), + ), ) ) @@ -318,7 +320,7 @@ def initialize(self, ai: GenkitRegistry) -> None: name=vertexai_name(version), fn=gemini_model.generate, metadata=gemini_model.metadata, - config_schema=GeminiConfigSchema, + # config_schema=GeminiConfigSchema, ) for version in VertexEmbeddingModels: @@ -327,7 +329,7 @@ def initialize(self, ai: GenkitRegistry) -> None: name=vertexai_name(version), fn=embedder.generate, metadata=default_embedder_info(version), - config_schema=EmbedContentConfig, + # config_schema=to_json_schema(EmbedContentConfig), ) for version in ImagenVersion: @@ -336,7 +338,6 @@ def initialize(self, ai: GenkitRegistry) -> None: name=vertexai_name(version), fn=imagen_model.generate, metadata=imagen_model.metadata, - config_schema=GenerateImagesConfigOrDict, ) def resolve_action( @@ -377,18 +378,18 @@ def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: model_ref = vertexai_image_model_info(_clean_name) model = ImagenModel(_clean_name, self._client) IMAGE_SUPPORTED_MODELS[_clean_name] = model_ref - config_schema = GenerateImagesConfigOrDict + # config_schema = GenerateImagesConfigOrDict else: model_ref = google_model_info(_clean_name) model = GeminiModel(_clean_name, self._client, ai) SUPPORTED_MODELS[_clean_name] = model_ref - config_schema = GeminiConfigSchema + # config_schema = GeminiConfigSchema ai.define_model( name=vertexai_name(_clean_name), fn=model.generate, metadata=model.metadata, - config_schema=config_schema, + # config_schema=config_schema, ) def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: @@ -410,7 +411,6 @@ def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: name=vertexai_name(_clean_name), fn=embedder.generate, metadata=default_embedder_info(_clean_name), - config_schema=EmbedContentConfig, ) @cached_property @@ -428,11 +428,15 @@ def list_actions(self) -> list[ActionMetadata]: for m in self._client.models.list(): name = m.name.replace('publishers/google/models/', '') if 'embed' in name.lower(): + embed_info = default_embedder_info(name) actions_list.append( embedder_action_metadata( name=vertexai_name(name), - info=default_embedder_info(name), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=embed_info.get('label'), + supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), + dimensions=embed_info.get('dimensions'), + ), ) ) # List all the vertexai models for generate actions @@ -440,7 +444,7 @@ def list_actions(self) -> list[ActionMetadata]: model_action_metadata( name=vertexai_name(name), info=google_model_info(name).model_dump(), - config_schema=GeminiConfigSchema, + # config_schema=GeminiConfigSchema, ), ) diff --git a/py/plugins/google-genai/test/test_google_plugin.py b/py/plugins/google-genai/test/test_google_plugin.py index d621f422e6..57b96d9c73 100644 --- a/py/plugins/google-genai/test/test_google_plugin.py +++ b/py/plugins/google-genai/test/test_google_plugin.py @@ -28,9 +28,10 @@ import pytest from genkit.ai import Genkit, GENKIT_CLIENT_HEADER -from genkit.blocks.embedding import embedder_action_metadata +from genkit.blocks.embedding import embedder_action_metadata, EmbedderOptions, EmbedderSupports from genkit.blocks.model import model_action_metadata from genkit.core.registry import ActionKind +from genkit.core.schema import to_json_schema from genkit.plugins.google_genai import ( GoogleAI, VertexAI, @@ -146,7 +147,6 @@ def test_googleai_initialize(): name=googleai_name(version), fn=ANY, metadata=ANY, - config_schema=GeminiConfigSchema, ) for version in GeminiEmbeddingModels: @@ -154,7 +154,6 @@ def test_googleai_initialize(): name=googleai_name(version), fn=ANY, metadata=ANY, - config_schema=EmbedContentConfig, ) @@ -219,7 +218,6 @@ def test_googleai__resolve_model( name=expected_model_name, fn=ANY, metadata=ANY, - config_schema=GeminiConfigSchema, ) assert key in SUPPORTED_MODELS @@ -247,7 +245,7 @@ def test_googleai__resolve_embedder( ) ai_mock.define_embedder.assert_called_once_with( - name=expected_model_name, fn=ANY, config_schema=EmbedContentConfig, metadata=default_embedder_info(clean_name) + name=expected_model_name, fn=ANY, metadata=default_embedder_info(clean_name) ) @@ -275,22 +273,26 @@ class MockModel(BaseModel): model_action_metadata( name=googleai_name('model1'), info=google_model_info('model1').model_dump(), - config_schema=GeminiConfigSchema, ), embedder_action_metadata( name=googleai_name('model2'), - info=default_embedder_info('model2'), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=default_embedder_info('model2').get('label'), + supports=EmbedderSupports(input=default_embedder_info('model2').get('supports', {}).get('input')), + dimensions=default_embedder_info('model2').get('dimensions'), + ), ), model_action_metadata( name=googleai_name('model3'), info=google_model_info('model3').model_dump(), - config_schema=GeminiConfigSchema, ), embedder_action_metadata( name=googleai_name('model3'), - info=default_embedder_info('model3'), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=default_embedder_info('model3').get('label'), + supports=EmbedderSupports(input=default_embedder_info('model3').get('supports', {}).get('input')), + dimensions=default_embedder_info('model3').get('dimensions'), + ), ), ] @@ -496,20 +498,16 @@ def test_vertexai_initialize(vertexai_plugin_instance): name=vertexai_name(version), fn=ANY, metadata=ANY, - config_schema=GeminiConfigSchema, ) for version in ImagenVersion: - ai_mock.define_model.assert_any_call( - name=vertexai_name(version), fn=ANY, metadata=ANY, config_schema=GenerateImagesConfigOrDict - ) + ai_mock.define_model.assert_any_call(name=vertexai_name(version), fn=ANY, metadata=ANY) for version in VertexEmbeddingModels: ai_mock.define_embedder.assert_any_call( name=vertexai_name(version), fn=ANY, metadata=ANY, - config_schema=EmbedContentConfig, ) @@ -609,7 +607,6 @@ def test_vertexai__resolve_model( name=expected_model_name, fn=ANY, metadata=ANY, - config_schema=GenerateImagesConfigOrDict, ) assert key in IMAGE_SUPPORTED_MODELS else: @@ -617,7 +614,6 @@ def test_vertexai__resolve_model( name=expected_model_name, fn=ANY, metadata=ANY, - config_schema=GeminiConfigSchema, ) assert key in SUPPORTED_MODELS @@ -653,7 +649,7 @@ def test_vertexai__resolve_embedder( ) ai_mock.define_embedder.assert_called_once_with( - name=expected_model_name, fn=ANY, config_schema=EmbedContentConfig, metadata=default_embedder_info(clean_name) + name=expected_model_name, fn=ANY, metadata=default_embedder_info(clean_name) ) @@ -680,26 +676,33 @@ class MockModel(BaseModel): model_action_metadata( name=vertexai_name('model1'), info=google_model_info('model1').model_dump(), - config_schema=GeminiConfigSchema, ), embedder_action_metadata( name=vertexai_name('model2_embeddings'), - info=default_embedder_info('model2_embeddings'), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=default_embedder_info('model2_embeddings').get('label'), + supports=EmbedderSupports( + input=default_embedder_info('model2_embeddings').get('supports', {}).get('input') + ), + dimensions=default_embedder_info('model2_embeddings').get('dimensions'), + ), ), model_action_metadata( name=vertexai_name('model2_embeddings'), info=google_model_info('model2_embeddings').model_dump(), - config_schema=GeminiConfigSchema, ), embedder_action_metadata( name=vertexai_name('model3_embedder'), - info=default_embedder_info('model3_embedder'), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=default_embedder_info('model3_embedder').get('label'), + supports=EmbedderSupports( + input=default_embedder_info('model3_embedder').get('supports', {}).get('input') + ), + dimensions=default_embedder_info('model3_embedder').get('dimensions'), + ), ), model_action_metadata( name=vertexai_name('model3_embedder'), info=google_model_info('model3_embedder').model_dump(), - config_schema=GeminiConfigSchema, ), ] diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py index 472d1c0a7b..8e31ffce5d 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py @@ -23,9 +23,10 @@ import ollama as ollama_api from genkit.ai import GenkitRegistry, Plugin -from genkit.blocks.embedding import embedder_action_metadata +from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata from genkit.blocks.model import model_action_metadata from genkit.core.registry import ActionKind +from genkit.core.schema import to_json_schema from genkit.plugins.ollama.constants import ( DEFAULT_OLLAMA_SERVER_URL, OllamaAPITypes, @@ -197,7 +198,7 @@ def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDef ai.define_embedder( name=ollama_name(embedder_ref.name), fn=embedder.embed, - config_schema=ollama_api.Options, + config_schema=to_json_schema(ollama_api.Options), metadata={ 'label': f'Ollama Embedding - {_clean_name}', 'dimensions': embedder_ref.dimensions, @@ -234,14 +235,11 @@ def list_actions(self) -> list[dict[str, str]]: actions.append( embedder_action_metadata( name=ollama_name(_name), - config_schema=ollama_api.Options, - info={ - 'label': f'Ollama Embedding - {_name}', - 'dimensions': None, - 'supports': { - 'input': ['text'], - }, - }, + options=EmbedderOptions( + config_schema=to_json_schema(ollama_api.Options), + label=f'Ollama Embedding - {_name}', + supports=EmbedderSupports(input=['text']), + ), ) ) else: diff --git a/py/plugins/ollama/tests/test_plugin_api.py b/py/plugins/ollama/tests/test_plugin_api.py index de2fd9d42a..295c39843f 100644 --- a/py/plugins/ollama/tests/test_plugin_api.py +++ b/py/plugins/ollama/tests/test_plugin_api.py @@ -24,6 +24,7 @@ from pydantic import BaseModel from genkit.ai import ActionKind, Genkit +from genkit.core.schema import to_json_schema from genkit.plugins.ollama import Ollama, ollama_name from genkit.plugins.ollama.embedders import EmbeddingDefinition from genkit.plugins.ollama.models import ModelDefinition @@ -126,7 +127,7 @@ def test__initialize_embedders(ollama_plugin_instance): ai_mock.define_embedder.assert_called_once_with( name=ollama_name(name), fn=ANY, - config_schema=ollama_api.Options, + config_schema=to_json_schema(ollama_api.Options), metadata={ 'label': f'Ollama Embedding - {name}', 'dimensions': 1024, @@ -165,7 +166,7 @@ def test_resolve_action(kind, name, ollama_plugin_instance): ai_mock.define_embedder.assert_called_once_with( name=ollama_name(name), fn=ANY, - config_schema=ollama_api.Options, + config_schema=to_json_schema(ollama_api.Options), metadata={ 'label': f'Ollama Embedding - {name}', 'dimensions': None, @@ -218,7 +219,7 @@ def test_define_ollama_embedder(name, expected_name, clean_name, ollama_plugin_i ai_mock.define_embedder.assert_called_once_with( name=expected_name, fn=ANY, - config_schema=ollama_api.Options, + config_schema=to_json_schema(ollama_api.Options), metadata={ 'label': f'Ollama Embedding - {clean_name}', 'dimensions': 1024,