Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions py/packages/genkit/src/genkit/ai/_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class while customizing it with any plugins.

from genkit.aio import Channel
from genkit.blocks.document import Document
from genkit.blocks.embedding import EmbedRequest, EmbedResponse
from genkit.blocks.embedding import EmbedderRef
from genkit.blocks.generate import (
StreamingCallback as ModelStreamingCallback,
generate_action,
Expand All @@ -39,6 +39,7 @@ class while customizing it with any plugins.
from genkit.blocks.prompt import PromptConfig, to_generate_action_options
from genkit.core.action import ActionRunContext
from genkit.core.action.types import ActionKind
from genkit.core.typing import EmbedRequest, EmbedResponse
from genkit.types import (
DocumentData,
GenerationCommonConfig,
Expand Down Expand Up @@ -295,10 +296,12 @@ def generate_stream(

async def embed(
self,
embedder: str | None = None,
embedder: str | EmbedderRef | None = None,
documents: list[Document] | None = None,
options: dict[str, Any] | None = None,
) -> EmbedResponse:
embedder_name: str
embedder_config: dict[str, Any] = {}
"""Calculates embeddings for documents.

Args:
Expand All @@ -309,9 +312,22 @@ async def embed(
Returns:
The generated response with embeddings.
"""
embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder)
if isinstance(embedder, EmbedderRef):
embedder_name = embedder.name
embedder_config = embedder.config or {}
if embedder.version:
embedder_config['version'] = embedder.version # Handle version from ref
elif isinstance(embedder, str):
embedder_name = embedder
else:
# Handle case where embedder is None
raise ValueError('Embedder must be specified as a string name or an EmbedderRef.')

return (await embed_action.arun(EmbedRequest(input=documents, options=options))).response
# Merge options passed to embed() with config from EmbedderRef
final_options = {**(embedder_config or {}), **(options or {})}
embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name)

return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response

async def retrieve(
self,
Expand All @@ -335,3 +351,30 @@ async def retrieve(
retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever)

return (await retrieve_action.arun(RetrieverRequest(query=query, options=options))).response

async def embed(
self,
embedder: str | EmbedderRef | None = None,
documents: list[Document] | None = None,
options: dict[str, Any] | None = None,
) -> EmbedResponse:
embedder_name: str
embedder_config: dict[str, Any] = {}

if isinstance(embedder, EmbedderRef):
embedder_name = embedder.name
embedder_config = embedder.config or {}
if embedder.version:
embedder_config['version'] = embedder.version # Handle version from ref
elif isinstance(embedder, str):
embedder_name = embedder
else:
# Handle case where embedder is None
raise ValueError('Embedder must be specified as a string name or an EmbedderRef.')

# Merge options passed to embed() with config from EmbedderRef
final_options = {**(embedder_config or {}), **(options or {})}

embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name)

return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response
20 changes: 13 additions & 7 deletions py/packages/genkit/src/genkit/ai/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import structlog
from pydantic import BaseModel

from genkit.blocks.embedding import EmbedderFn
from genkit.blocks.embedding import EmbedderFn, EmbedderOptions
from genkit.blocks.evaluator import BatchEvaluatorFn, EvaluatorFn
from genkit.blocks.formats.types import FormatDef
from genkit.blocks.model import ModelFn, ModelMiddleware
Expand Down Expand Up @@ -458,8 +458,7 @@ def define_embedder(
self,
name: str,
fn: EmbedderFn,
config_schema: BaseModel | dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
options: EmbedderOptions | None = None,
description: str | None = None,
) -> Action:
"""Define a custom embedder action.
Expand All @@ -471,13 +470,20 @@ def define_embedder(
metadata: Optional metadata for the model.
description: Optional description for the embedder.
"""
embedder_meta: dict[str, Any] = metadata if metadata else {}
embedder_meta: dict[str, Any] = {}
if options:
if options.label:
embedder_meta['embedder']['label'] = options.label
if options.dimensions:
embedder_meta['embedder']['dimensions'] = options.dimensions
if options.supports:
embedder_meta['embedder']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
if options.config_schema:
embedder_meta['embedder']['customOptions'] = to_json_schema(options.config_schema)

if 'embedder' not in embedder_meta:
embedder_meta['embedder'] = {}

if config_schema:
embedder_meta['embedder']['customOptions'] = to_json_schema(config_schema)

embedder_description = get_func_description(fn, description)
return self.registry.register_action(
name=name,
Expand Down
63 changes: 55 additions & 8 deletions py/packages/genkit/src/genkit/blocks/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,76 @@

"""Embedding actions."""

from collections.abc import Callable
from collections.abc import Awaitable, Callable
from typing import Any

from genkit.ai import ActionKind
from pydantic import BaseModel, ConfigDict, Field

from genkit.core.action import ActionMetadata
from genkit.core.action.types import ActionKind
from genkit.core.schema import to_json_schema
from genkit.core.typing import EmbedRequest, EmbedResponse

# type EmbedderFn = Callable[[EmbedRequest], EmbedResponse]

class EmbedderSupports(BaseModel):
"""Embedder capability support."""

model_config = ConfigDict(extra='forbid', populate_by_name=True)

input: list[str] | None = None
multilingual: bool | None = None


class EmbedderOptions(BaseModel):
"""Configuration options for an embedder."""

model_config = ConfigDict(extra='forbid', populate_by_name=True)

config_schema: dict[str, Any] | None = Field(None, alias='configSchema')
label: str | None = None
supports: EmbedderSupports | None = None
dimensions: int | None = None


class EmbedderRef(BaseModel):
"""Reference to an embedder with configuration."""

model_config = ConfigDict(extra='forbid', populate_by_name=True)

name: str
config: Any | None = None
version: str | None = None


EmbedderFn = Callable[[EmbedRequest], EmbedResponse]


def embedder_action_metadata(
name: str,
info: dict[str, Any] | None = None,
config_schema: Any | None = None,
options: EmbedderOptions | None = None,
) -> ActionMetadata:
"""Generates an ActionMetadata for embedders."""
info = info if info is not None else {}
options = options if options is not None else EmbedderOptions()
embedder_metadata_dict = {'embedder': {}}

if options.label:
embedder_metadata_dict['embedder']['label'] = options.label

embedder_metadata_dict['embedder']['dimensions'] = options.dimensions

if options.supports:
embedder_metadata_dict['embedder']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)

embedder_metadata_dict['embedder']['customOptions'] = options.config_schema if options.config_schema else None

return ActionMetadata(
kind=ActionKind.EMBEDDER,
name=name,
input_json_schema=to_json_schema(EmbedRequest),
output_json_schema=to_json_schema(EmbedResponse),
metadata={'embedder': {**info, 'customOptions': to_json_schema(config_schema) if config_schema else None}},
metadata=embedder_metadata_dict,
)


def create_embedder_ref(name: str, config: dict[str, Any] | None = None, version: str | None = None) -> EmbedderRef:
"""Creates an EmbedderRef instance."""
return EmbedderRef(name=name, config=config, version=version)
Loading
Loading