From 3195701f8367f69277e57255f58e16b6f56ab5e3 Mon Sep 17 00:00:00 2001 From: lina-lemon Date: Tue, 13 Jan 2026 10:08:29 +0900 Subject: [PATCH 1/9] feat: opt retrieve node --- server/graph_service/dto/retrieve.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/server/graph_service/dto/retrieve.py b/server/graph_service/dto/retrieve.py index b75c48c9f..e7a498244 100644 --- a/server/graph_service/dto/retrieve.py +++ b/server/graph_service/dto/retrieve.py @@ -4,6 +4,8 @@ from graph_service.dto.common import Message +from typing import Any, Dict, List, Optional, Tuple + class SearchQuery(BaseModel): group_ids: list[str] | None = Field( @@ -11,6 +13,7 @@ class SearchQuery(BaseModel): ) query: str max_facts: int = Field(default=10, description='The maximum number of facts to retrieve') + max_nodes: int = Field(default=10, description='The maximum number of nodes to retrieve') class FactResult(BaseModel): @@ -25,9 +28,18 @@ class FactResult(BaseModel): class Config: json_encoders = {datetime: lambda v: v.astimezone(timezone.utc).isoformat()} +class NodeResult(BaseModel): + uuid: str + name: str + labels: list[str] + summary: str + attributes: dict[str, Any] = Field(default={}) + created_at: datetime class SearchResults(BaseModel): - facts: list[FactResult] + facts: Optional[list[FactResult]] + nodes: Optional[list[NodeResult]] + community: str class GetMemoryRequest(BaseModel): @@ -43,3 +55,4 @@ class GetMemoryRequest(BaseModel): class GetMemoryResponse(BaseModel): facts: list[FactResult] = Field(..., description='The facts that were retrieved from the graph') + nodes: Optional[list[NodeResult]] = Field(..., description='The nodes that were retrieved from the graph') From 65fd9ec2e17f5f1946fd214730095e8f8bd52063 Mon Sep 17 00:00:00 2001 From: lina-lemon Date: Tue, 13 Jan 2026 10:11:41 +0900 Subject: [PATCH 2/9] feat: opt retrieve node --- graphiti_core/prompts/extract_nodes.py | 5 +- server/graph_service/dto/__init__.py | 31 +++- server/graph_service/dto/common.py | 137 ++++++++++++++++++ server/graph_service/dto/ingest.py | 11 +- server/graph_service/routers/ingest.py | 41 +++++- server/graph_service/routers/retrieve.py | 174 ++++++++++++++++++++++- server/graph_service/zep_graphiti.py | 13 +- 7 files changed, 398 insertions(+), 14 deletions(-) diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index 1fb55c1cc..b8cd2dcd9 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -285,14 +285,13 @@ def extract_summary(context: dict[str, Any]) -> list[Message]: return [ Message( role='system', - content='You are a helpful assistant that extracts entity summaries from the provided text.', + content='You are a helpful assistant that extracts entity summaries from the provided text. Always summarize in Korean.', ), Message( role='user', content=f""" Given the MESSAGES and the ENTITY, update the summary that combines relevant information about the entity - from the messages and relevant information from the existing summary. - + from the messages and relevant information from the existing summary. Always summarize in Korean.' 
{summary_instructions} diff --git a/server/graph_service/dto/__init__.py b/server/graph_service/dto/__init__.py index 375c9c432..76e412c0e 100644 --- a/server/graph_service/dto/__init__.py +++ b/server/graph_service/dto/__init__.py @@ -1,6 +1,6 @@ -from .common import Message, Result -from .ingest import AddEntityNodeRequest, AddMessagesRequest -from .retrieve import FactResult, GetMemoryRequest, GetMemoryResponse, SearchQuery, SearchResults +from .common import Message, Result, Text, Person, Organization, Product, Concept, Location, Datetime, Category, Image, Web, entity_type, edge_type, edge_type_maps, MemberOf, ManagedBy,Published,ScheduledOn,LocatedIn,References,Uses,BelongsTo,HasImage, HasLink +from .ingest import AddEntityNodeRequest, AddMessagesRequest, AddTextsRequest +from .retrieve import NodeResult, FactResult, GetMemoryRequest, GetMemoryResponse, SearchQuery, SearchResults __all__ = [ 'SearchQuery', @@ -9,7 +9,32 @@ 'AddEntityNodeRequest', 'SearchResults', 'FactResult', + 'NodeResult', 'Result', 'GetMemoryRequest', 'GetMemoryResponse', + 'AddTextsRequest', + 'Text', + 'Person', + 'Organization', + 'Product', + 'Concept', + 'Location', + 'Datetime', + 'Category', + 'Image', + 'Web', + 'entity_type', + 'edge_type', + 'edge_type_maps', + 'MemberOf', + 'ManagedBy', + 'Published', + 'ScheduledOn', + 'LocatedIn', + 'References', + 'Uses', + 'BelongsTo', + 'HasImage', + 'HasLink' ] diff --git a/server/graph_service/dto/common.py b/server/graph_service/dto/common.py index 5103470e6..434a06d3f 100644 --- a/server/graph_service/dto/common.py +++ b/server/graph_service/dto/common.py @@ -26,3 +26,140 @@ class Message(BaseModel): source_description: str = Field( default='', description='The description of the source of the message' ) + +class Text(BaseModel): + content: str = Field(..., description='The content of the text') + uuid: str | None = Field(default=None, description='The uuid of the text (optional)') + name: str = Field( + default='', description='The name of the episodic node for the text (optional)' + ) + timestamp: datetime = Field(default_factory=utc_now, description='The timestamp of the text') + source_description: str = Field( + default='', description='The description of the source of the text' + ) + +class Person(BaseModel): + """PEOPLE: 성함 및 직함 포함 (예: '홍길동 팀장')""" + full_name: str = Field(..., description="Full name and title") + +class Organization(BaseModel): + """ORGANIZATIONS: 정식 기관/업체 명칭 (예: '단지1 로비')""" + org_name: str = Field(..., description="Normalized legal name") + +class Asset(BaseModel): + """ASSETS: 공지사항, 게시물 등 주요 정보 객체 (가장 중심이 되는 노드)""" + asset_name: str = Field(..., description="Title of the notice or post") + web_url: str | None = Field(None, description="Related link URL") + image_url: str | None = Field(None, description="Main image URL") + +class Product(BaseModel): + """PRODUCTS: 구체적인 제품 이름""" + product_name: str = Field(..., description="Specific product name") + +class Service(BaseModel): + """SERVICES: 구체적인 서비스 명칭""" + service_name: str = Field(..., description="Specific service name") + +class Concept(BaseModel): + """CONCEPTS: 도메인 전문 용어 또는 개념""" + concept_label: str = Field(..., description="Domain-specific concept") + +class Location(BaseModel): + """LOCATIONS: 구체적인 장소 (예: '단지1 로비', '지하주차장')""" + location_name: str = Field(..., description="Specific physical location") + +class Datetime(BaseModel): + """TEMPORAL ENTITIES: YYYY-MM-DD HH:mm:ss 형식의 정규화된 날짜 노드""" + date_val: str = Field(..., description="Normalized ISO 8601 string") + +class 
Category(BaseModel): + """CATEGORY: 문서 분류""" + category_name: str = Field(..., description="Name of the category") + +class Image(BaseModel): + """IMAGE: 첨부 이미지 정보""" + url: str = Field(..., description="Direct image link") + alt_text: str | None = Field(None, description="Description of the image") + +class Web(BaseModel): + """WEB: 외부 링크 정보""" + url: str = Field(..., description="Direct web link") + +# --- 엔티티 타입 등록 --- +entity_type = { + "Person": Person, + "Organization": Organization, + "Asset": Asset, + "Product": Product, + "Service": Service, + "Concept": Concept, + "Location": Location, + "Datetime": Datetime, + "Category": Category, + "Image": Image, + "Web": Web, +} + +# --- 엣지(Edge) 클래스 정의 --- +class MemberOf(BaseModel): fact: str = Field(default="MEMBER_OF") +class ManagedBy(BaseModel): fact: str = Field(default="MANAGED_BY") +class Published(BaseModel): fact: str = Field(default="PUBLISHED") +class ScheduledOn(BaseModel): fact: str = Field(default="SCHEDULED_ON") +class LocatedIn(BaseModel): fact: str = Field(default="LOCATED_IN") +class References(BaseModel): fact: str = Field(default="REFERENCES") +class Uses(BaseModel): fact: str = Field(default="USES") +class BelongsTo(BaseModel): fact: str = Field(default="BELONGS_TO") +class HasImage(BaseModel): fact: str = Field(default="HAS_IMAGE") +class HasLink(BaseModel): fact: str = Field(default="HAS_LINK") + +edge_type = { + "MEMBER_OF": MemberOf, + "MANAGED_BY": ManagedBy, + "PUBLISHED": Published, + "SCHEDULED_ON": ScheduledOn, + "LOCATED_IN": LocatedIn, + "REFERENCES": References, + "USES": Uses, + "BELONGS_TO": BelongsTo, + "HAS_IMAGE": HasImage, + "HAS_LINK": HasLink +} + +edge_type_maps = { + ("Person", "Organization"): ["MEMBER_OF"], + ("Organization", "Category"): ["MEMBER_OF"], + + # 관리 및 발행 관계 + ("Asset", "Organization"): ["MANAGED_BY"], + ("Asset", "Person"): ["MANAGED_BY"], + ("Organization", "Asset"): ["PUBLISHED"], + ("Person", "Asset"): ["PUBLISHED"], + + # 시간 및 장소 (게시물 중심 연결) + ("Asset", "Datetime"): ["SCHEDULED_ON"], + ("Concept", "Datetime"): ["SCHEDULED_ON"], + ("Asset", "Location"): ["LOCATED_IN"], + ("Organization", "Location"): ["LOCATED_IN"], + + # 참조 및 분류 + ("Asset", "Asset"): ["REFERENCES"], + ("Asset", "Web"): ["HAS_LINK", "REFERENCES"], + ("Asset", "Category"): ["BELONGS_TO"], + + # 이미지 및 미디어 연결 (개선사항 반영) + ("Asset", "Image"): ["HAS_IMAGE"], + + # 도메인 지식 + ("Organization", "Concept"): ["USES"], + ("Person", "Concept"): ["USES"], + ("Asset", "Concept"): ["USES"], +} +# class Doc(BaseModel): +# uuid: str = Field(..., description='The uuid of the episode') +# group_id: str = Field(..., description='The group id of the episode') +# name: str = Field(..., description='The name of the episode') +# episode_body: str = Field(..., description='The body of the episode') +# reference_time: datetime = Field(default_factory=utc_now, description='The reference time of the episode') +# source: str = Field(..., description='The source of the episode') +# source_description: str = Field(...,) + diff --git a/server/graph_service/dto/ingest.py b/server/graph_service/dto/ingest.py index 9b0159c85..450690022 100644 --- a/server/graph_service/dto/ingest.py +++ b/server/graph_service/dto/ingest.py @@ -1,7 +1,6 @@ from pydantic import BaseModel, Field - -from graph_service.dto.common import Message - +from typing import List +from graph_service.dto.common import Message, Text class AddMessagesRequest(BaseModel): group_id: str = Field(..., description='The group id of the messages to add') @@ -13,3 +12,9 @@ class 
AddEntityNodeRequest(BaseModel): group_id: str = Field(..., description='The group id of the node to add') name: str = Field(..., description='The name of the node to add') summary: str = Field(default='', description='The summary of the node to add') + + +class AddTextsRequest(BaseModel): + group_id: str = Field(..., description='The group id of the texts to add') + texts: List[Text] = Field(..., description='The texts to add') + prompt: str = Field(default='', description='The custom extraction prompt for the text') diff --git a/server/graph_service/routers/ingest.py b/server/graph_service/routers/ingest.py index d03563105..3f6a1bdc7 100644 --- a/server/graph_service/routers/ingest.py +++ b/server/graph_service/routers/ingest.py @@ -6,7 +6,7 @@ from graphiti_core.nodes import EpisodeType # type: ignore from graphiti_core.utils.maintenance.graph_data_operations import clear_data # type: ignore -from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result +from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result, AddTextsRequest, Text, edge_type_maps, edge_type, entity_type from graph_service.zep_graphiti import ZepGraphitiDep @@ -70,6 +70,45 @@ async def add_messages_task(m: Message): return Result(message='Messages added to processing queue', success=True) +@router.post('/build-communities', status_code=status.HTTP_202_ACCEPTED) +async def build_communities(group_id: str, graphiti: ZepGraphitiDep): + """ + 특정 그룹의 노드들을 분석하여 커뮤니티(주제별 그룹)를 형성하고 요약합니다. + """ + async def build_task(): + await graphiti.build_communities(group_ids=[group_id]) + + await async_worker.queue.put(build_task) + + return Result(message='Community building started in background', success=True) + + +@router.post('/texts', status_code=status.HTTP_202_ACCEPTED) +async def add_texts( + request: AddTextsRequest, + graphiti: ZepGraphitiDep, +): + + async def add_texts_task(m: Text): + await graphiti.add_episode( + uuid=m.uuid, + group_id=request.group_id, + name=m.name, + episode_body=m.content, # specific strategy content + source_description=m.source_description, + reference_time=m.timestamp, + source=EpisodeType.text, + entity_types=entity_type, + edge_types=edge_type, + edge_type_map=edge_type_maps, + custom_extraction_instructions=request.prompt + ) + + for m in request.texts: + await async_worker.queue.put(partial(add_texts_task, m)) + + return Result(message='Texts added to processing queue', success=True) + @router.post('/entity-node', status_code=status.HTTP_201_CREATED) async def add_entity_node( request: AddEntityNodeRequest, diff --git a/server/graph_service/routers/retrieve.py b/server/graph_service/routers/retrieve.py index d42df2a33..ce1eaf559 100644 --- a/server/graph_service/routers/retrieve.py +++ b/server/graph_service/routers/retrieve.py @@ -9,7 +9,40 @@ SearchQuery, SearchResults, ) -from graph_service.zep_graphiti import ZepGraphitiDep, get_fact_result_from_edge +from graph_service.zep_graphiti import ZepGraphitiDep, get_fact_result_from_edge, get_node_result_from_entity +from graphiti_core.nodes import EntityNode +from graphiti_core.search.search_config import SearchConfig, NodeSearchConfig, CommunitySearchMethod, CommunitySearchConfig +from graphiti_core.search.search_filters import SearchFilters + + +import asyncio +import json +from datetime import datetime, timezone +from typing import Any, Dict, List + +from graphiti_core.embedder import EmbedderClient + +from graphiti_core.nodes import EpisodeType, EpisodicNode, EntityNode, CommunityNode 
+from graphiti_core.edges import EntityEdge, EpisodicEdge, CommunityEdge, create_entity_edge_embeddings + +from graphiti_core.search.search_config_recipes import ( + NODE_HYBRID_SEARCH_RRF, + COMBINED_HYBRID_SEARCH_RRF, + NODE_HYBRID_SEARCH_NODE_DISTANCE, + EDGE_HYBRID_SEARCH_CROSS_ENCODER, + NODE_HYBRID_SEARCH_EPISODE_MENTIONS, + EDGE_HYBRID_SEARCH_RRF, # 🔹 fact 전용 + COMMUNITY_HYBRID_SEARCH_RRF, +) +from graphiti_core.search.search_config import ( + EdgeSearchConfig, + EdgeSearchMethod, + EdgeReranker, + NodeSearchMethod, + NodeReranker, +) +from graphiti_core.search.search_filters import SearchFilters, DateFilter, ComparisonOperator +from graphiti_core.search.search_utils import node_bfs_search router = APIRouter() @@ -23,7 +56,7 @@ async def search(query: SearchQuery, graphiti: ZepGraphitiDep): ) facts = [get_fact_result_from_edge(edge) for edge in relevant_edges] return SearchResults( - facts=facts, + facts=facts, community='', nodes=None ) @@ -53,7 +86,7 @@ async def get_memory( num_results=request.max_facts, ) facts = [get_fact_result_from_edge(edge) for edge in result] - return GetMemoryResponse(facts=facts) + return GetMemoryResponse(facts=facts, nodes=None) def compose_query_from_messages(messages: list[Message]): @@ -61,3 +94,138 @@ def compose_query_from_messages(messages: list[Message]): for message in messages: combined_query += f'{message.role_type or ""}({message.role or ""}): {message.content}\n' return combined_query + + +# TODO - 세가지 검색 각각 실행 후 결과를 종합해야함. +@router.post('/search-fused', status_code=status.HTTP_200_OK) +async def search_fused(query: SearchQuery, graphiti: ZepGraphitiDep): + # --- [1. 검색 함수 정의부] --- + async def get_text_matches(): + config = SearchConfig( + node_config=NodeSearchConfig(search_methods=[NodeSearchMethod.bm25], reranker=NodeReranker.rrf), + limit=query.max_nodes + ) + res = await graphiti._search(query=query.query, config=config, group_ids=query.group_ids) + return res.nodes or [] + + async def get_vector_matches(): + config = SearchConfig( + node_config=NodeSearchConfig( + search_methods=[NodeSearchMethod.cosine_similarity], + reranker=NodeReranker.rrf, + sim_min_score=0.6 + ), + limit=query.max_nodes + ) + res = await graphiti._search(query=query.query, config=config, group_ids=query.group_ids) + return res.nodes or [] + + async def get_community_matches(): + config = SearchConfig( + community_config=CommunitySearchConfig( + search_methods=[CommunitySearchMethod.cosine_similarity], + ), + limit=query.max_nodes + ) + res = await graphiti._search(query=query.query, config=COMMUNITY_HYBRID_SEARCH_RRF, group_ids=query.group_ids) + return res.communities or [] + + async def get_graph_matches(seed_nodes): + """복합 관계를 찾기 위한 BFS 탐색""" + if not seed_nodes: return [] + seed_uuids = [n.uuid for n in seed_nodes[:9]] + config = SearchConfig( + node_config=NodeSearchConfig(search_methods=[NodeSearchMethod.bfs], bfs_max_depth=2), + limit=query.max_nodes + ) + res = await graphiti._search( + query=query.query, + config=config, + group_ids=query.group_ids, + bfs_origin_node_uuids=seed_uuids + ) + return res.nodes or [] + + # --- [2. 실행 로직] --- + # 먼저 텍스트, 벡터, 커뮤니티를 병렬로 가져옵니다. + text_nodes, vector_nodes, community_results = await asyncio.gather( + get_text_matches(), + get_vector_matches(), + get_community_matches() + ) + + # [복구된 부분] 벡터 결과를 기반으로 연관 노드들을 추가로 긁어옵니다. + graph_nodes = await get_graph_matches(vector_nodes) + + # --- [3. 
Fusion 및 정렬] --- + all_nodes: Dict[str, Dict[str, Any]] = {} + weights = {'text': 0.2, 'vector': 0.4, 'graph': 0.4} + + def process_matches(nodes, source_key): + weight = weights[source_key] + for i, node in enumerate(nodes): + rank_score = (1.0 / (i + 1)) * weight + if node.uuid not in all_nodes: + all_nodes[node.uuid] = {'node': node, 'total_score': 0.0, 'sources': []} + all_nodes[node.uuid]['total_score'] += rank_score + all_nodes[node.uuid]['sources'].append(source_key) + + process_matches(text_nodes, 'text') + process_matches(vector_nodes, 'vector') + process_matches(graph_nodes, 'graph') + + ranked_results = sorted(all_nodes.values(), key=lambda x: x['total_score'], reverse=True) + final_nodes = [get_node_result_from_entity(item['node']) for item in ranked_results[:query.max_nodes]] + + # --- [4. 결과 반환] --- + community_facts = [f"[테마 요약: {c.name}] {c.summary}" for c in community_results] + + debug_data = { + "query": query.query, + "counts": { + "text": len(text_nodes), + "vector": len(vector_nodes), + "graph": len(graph_nodes) + }, + "final_ranking": [ + {"name": n.name, "score": item['total_score'], "sources": item['sources']} + for item, n in zip(ranked_results, [item['node'] for item in ranked_results]) + ] + } + + with open('search-fused-debug.json', 'w', encoding='utf-8') as f: + json.dump(debug_data, f, ensure_ascii=False, indent=4) + + with open('search-fused.json', 'w', encoding='utf-8') as f: + json.dump({"community_facts": community_facts, "all_nodes": str(all_nodes), "reranked_results": str(ranked_results),"final_nodes": str(final_nodes)}, f, ensure_ascii=False, indent=4) + + + return SearchResults(nodes=final_nodes, community=str(community_facts), facts=None) + +@router.post('/search-node', status_code=status.HTTP_200_OK) +async def search_node(query: SearchQuery, graphiti: ZepGraphitiDep): + search_config= SearchConfig( + node_config=NodeSearchConfig( + search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], + reranker=NodeReranker.rrf, + mmr_lambda=1, + bfs_max_depth=1, + sim_min_score=0.6, + ), + limit=24, + ) + + search_filter = SearchFilters(node_labels=["Doc"]) + search_config.limit = query.max_nodes + relevant_nodes = await graphiti._search( + query=query.query, + config=search_config, + group_ids=query.group_ids, + search_filter=search_filter + ) + + nodes = [get_node_result_from_entity(node) for node in relevant_nodes.nodes] + + return SearchResults( + nodes=nodes, facts=None, community='' + ) \ No newline at end of file diff --git a/server/graph_service/zep_graphiti.py b/server/graph_service/zep_graphiti.py index 097c9f391..5f285efdc 100644 --- a/server/graph_service/zep_graphiti.py +++ b/server/graph_service/zep_graphiti.py @@ -9,7 +9,7 @@ from graphiti_core.nodes import EntityNode, EpisodicNode # type: ignore from graph_service.config import ZepEnvDep -from graph_service.dto import FactResult +from graph_service.dto import FactResult, NodeResult logger = logging.getLogger(__name__) @@ -110,5 +110,16 @@ def get_fact_result_from_edge(edge: EntityEdge): expired_at=edge.expired_at, ) +def get_node_result_from_entity(node: EntityNode): + return NodeResult( + uuid=node.uuid, + name=node.name, + labels=node.labels, + summary=node.summary, + attributes=node.attributes, + created_at=node.created_at, + ) + + ZepGraphitiDep = Annotated[ZepGraphiti, Depends(get_graphiti)] From 8d802bba4825c29b756b95f60c27f96845144fa5 Mon Sep 17 00:00:00 2001 From: lina-lemon Date: Thu, 15 Jan 2026 09:51:45 +0900 Subject: [PATCH 3/9] feat: add context aware normalize 
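
This patch introduces EntityResolver, which scores every pair of extracted mentions (fuzzy surface-form match, a penalty for mismatched entity types, a boost for shared context/features) and then clusters the pairwise similarity matrix with DBSCAN to pick canonical forms. A minimal, self-contained sketch of that clustering step is below; it assumes only a pairwise similarity(a, b) callable returning a value in [0, 1], and the function and variable names are illustrative, not part of this patch:

    import numpy as np
    from sklearn.cluster import DBSCAN

    def cluster_mentions(mentions, similarity, threshold=0.75):
        """Group mentions whose pairwise similarity clears `threshold` (illustrative helper)."""
        if not mentions:
            return []
        n = len(mentions)
        sim = np.eye(n)  # every mention is fully similar to itself
        for i in range(n):
            for j in range(i + 1, n):
                sim[i, j] = sim[j, i] = similarity(mentions[i], mentions[j])
        # DBSCAN expects distances, so invert the similarity; eps mirrors the merge threshold.
        labels = DBSCAN(eps=1 - threshold, min_samples=1, metric='precomputed').fit(1 - sim).labels_
        clusters = {}
        for idx, label in enumerate(labels):
            clusters.setdefault(int(label), []).append(mentions[idx])
        return list(clusters.values())

With min_samples=1 no mention is discarded as noise: a mention with no close neighbour simply becomes its own singleton cluster, which is the behaviour resolve_entities below relies on when it turns every cluster into a canonical entity.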
--- graphiti_core/graphiti.py | 58 +++++++++++ graphiti_core/graphiti_types.py | 3 + graphiti_core/normalize.py | 95 +++++++++++++++++++ graphiti_core/prompts/extract_nodes.py | 26 ++++- .../utils/maintenance/dedup_helpers.py | 54 +++++++++++ .../utils/maintenance/node_operations.py | 4 +- 6 files changed, 236 insertions(+), 4 deletions(-) create mode 100644 graphiti_core/normalize.py diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index af7d0b344..b5acbd4cf 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -98,6 +98,8 @@ ) from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types +from graphiti_core.normalize import EntityResolver + logger = logging.getLogger(__name__) load_dotenv() @@ -140,6 +142,7 @@ def __init__( max_coroutines: int | None = None, tracer: Tracer | None = None, trace_span_prefix: str = 'graphiti', + resolver: EntityResolver | None = None, ): """ Initialize a Graphiti instance. @@ -224,12 +227,19 @@ def __init__( # Set tracer on clients self.llm_client.set_tracer(self.tracer) + # Initialize resolver + if resolver: + self.resolver = resolver + else: + self.resolver = EntityResolver() + self.clients = GraphitiClients( driver=self.driver, llm_client=self.llm_client, embedder=self.embedder, cross_encoder=self.cross_encoder, tracer=self.tracer, + resolver=self.resolver ) # Capture telemetry event @@ -710,6 +720,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): with self.tracer.start_span('add_episode') as span: try: # Retrieve previous episodes for context + # STEP.1 맥락을 위해 이전 Episode를 검색함 previous_episodes = ( await self.retrieve_episodes( reference_time, @@ -722,6 +733,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): ) # Get or create episode + # STEP.2 저장할 EpisodeNode get OR create. episode = ( await EpisodicNode.get_by_uuid(self.driver, uuid) if uuid is not None @@ -738,6 +750,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): ) # Create default edge type map + # STEP.3 엣지 타입의 기본 맵을 생성. edge_type_map_default = ( {('Entity', 'Entity'): list(edge_types.keys())} if edge_types is not None @@ -745,6 +758,16 @@ async def add_episode_endpoint(episode_data: EpisodeData): ) # Extract and resolve nodes + # STEP.4 노드 추출 + # - extract_nodes: LLM을 사용하여 에피소드 텍스트에서 엔티티(노드)를 추출합니다. + # - entity type context의 name, descriptions 정리하기. + # 1. 엔티티 타입 컨텍스트 정의 (기본 'Entity' + 제공된 타입들) + # 2. LLM을 위한 컨텍스트 준비 (에피소드 내용, 타임스탬프, 이전 에피소드, 지침, 엔티티 타입 등) + # 3. 에피소드 소스 타입(message, text, json)에 따라 적절한 프롬프트로 LLM 호출 + # 4. LLM 응답을 ExtractedEntities 모델로 파싱 + # 5. 추출된 엔티티 필터링 (이름이 비어있는 경우 제거) + # 6. 추출된 데이터를 EntityNode 객체로 변환 (타입 매핑 및 제외 처리 포함) + # 7. 최종 추출된 노드 리스트 반환 extracted_nodes = await extract_nodes( self.clients, episode, @@ -754,6 +777,15 @@ async def add_episode_endpoint(episode_data: EpisodeData): custom_extraction_instructions, ) + # Resolve extracted nodes (deduplication) + # STEP.5 추출된 노드 중복 제거 및 해결 + # - resolve_extracted_nodes: 추출된 노드들을 기존 그래프의 노드들과 비교하여 중복을 제거하고 해결합니다. + # - _collect_candidate_nodes: 각 추출된 노드 이름으로 기존 그래프에서 후보 노드들을 검색(Hybrid Search)합니다. + # - _build_candidate_indexes: 후보 노드들에 대한 인덱스(MinHash 등)를 생성합니다. + # - _resolve_with_similarity: 텍스트 유사도(Jaccard)를 기반으로 결정적인 중복을 해결합니다. + # - _resolve_with_llm: 유사하지만 확실하지 않은 경우 LLM(dedupe_nodes.nodes)을 사용하여 중복 여부를 판단합니다. + # - filter_existing_duplicate_of_edges: 이미 'IS_DUPLICATE_OF' 엣지로 연결된 중복 관계를 확인합니다. 
+ nodes, uuid_map, _ = await resolve_extracted_nodes( self.clients, extracted_nodes, @@ -763,6 +795,15 @@ async def add_episode_endpoint(episode_data: EpisodeData): ) # Extract and resolve edges in parallel with attribute extraction + # STEP.6 엣지 추출 및 해결 (Layer 2) + # - extract_edges: LLM(extract_edges.edge)을 사용하여 노드 간의 관계를 추출합니다. + # - resolve_edge_pointers: 노드 중복 제거 결과(uuid_map)를 바탕으로 엣지의 소스/타겟 UUID를 갱신합니다. + # - resolve_extracted_edges: 추출된 엣지를 검증하고 기존 엣지와 병합하거나 모순을 해결합니다. + # - create_entity_edge_embeddings: 엣지의 사실(fact)에 대한 임베딩을 생성합니다. + # - EntityEdge.get_between_nodes: 두 노드 사이에 이미 존재하는 엣지를 DB에서 조회합니다. + # - search (related_edges): 추출된 엣지와 관련된 기존 엣지들을 검색합니다. + # - search (invalidation_candidates): 모순될 가능성이 있는 기존 엣지들을 검색합니다. + # - resolve_extracted_edge: LLM(dedupe_edges.resolve_edge)을 사용하여 중복 엣지 병합 및 모순된 엣지(invalidated)를 식별합니다. resolved_edges, invalidated_edges = await self._extract_and_resolve_edges( episode, extracted_nodes, @@ -776,6 +817,12 @@ async def add_episode_endpoint(episode_data: EpisodeData): ) # Extract node attributes + # STEP.7 노드 속성 추출 + # - extract_attributes_from_nodes: 해결된 노드들의 추가 속성과 요약을 추출합니다. + # - extract_attributes_from_node: 각 노드에 대해 속성 및 요약 추출을 수행합니다. + # - _extract_entity_attributes: LLM(extract_nodes.extract_attributes)을 사용하여 노드의 구조적 속성을 추출합니다. + # - _extract_entity_summary: LLM(extract_nodes.extract_summary)을 사용하여 노드에 대한 요약을 생성/갱신합니다. + # - create_entity_node_embeddings: 노드 이름에 대한 임베딩을 생성합니다. hydrated_nodes = await extract_attributes_from_nodes( self.clients, nodes, episode, previous_episodes, entity_types ) @@ -783,11 +830,22 @@ async def add_episode_endpoint(episode_data: EpisodeData): entity_edges = resolved_edges + invalidated_edges # Process and save episode data + # STEP.8 에피소드 데이터 처리 및 저장 + # - _process_episode_data: 에피소드 노드와 엣지, 그리고 처리된 엔티티 노드와 엣지들을 DB에 저장합니다. + # - build_episodic_edges: 에피소드 노드와 엔티티 노드들을 연결하는 'MENTIONS' 엣지를 생성합니다. + # - add_nodes_and_edges_bulk: 노드, 엣지, 에피소드 데이터를 일괄적으로 DB에 저장(Merge/Create)합니다. episodic_edges, episode = await self._process_episode_data( episode, hydrated_nodes, entity_edges, now ) # Update communities if requested + # STEP.9 커뮤니티 업데이트 (요청 시) + # - update_community: 변경된 노드들이 속한 커뮤니티 정보를 갱신합니다. + # - determine_entity_community: 해당 노드가 속할 커뮤니티를 찾거나 새로 생성해야 하는지 결정합니다. + # - summarize_pair: LLM(summarize_nodes.summarize_pair)을 사용하여 노드 요약과 기존 커뮤니티 요약을 통합합니다. + # - generate_summary_description: 통합된 요약을 바탕으로 커뮤니티의 새로운 이름/설명을 생성합니다. + # - build_community_edges: 노드와 커뮤니티를 연결하는 'HAS_MEMBER' 엣지를 생성합니다. + # - community.save: 갱신된 커뮤니티 정보를 DB에 저장합니다. 
communities = [] community_edges = [] if update_communities: diff --git a/graphiti_core/graphiti_types.py b/graphiti_core/graphiti_types.py index bb9ea4689..ba4b255c3 100644 --- a/graphiti_core/graphiti_types.py +++ b/graphiti_core/graphiti_types.py @@ -22,6 +22,7 @@ from graphiti_core.llm_client import LLMClient from graphiti_core.tracer import Tracer +from graphiti_core.normalize import EntityResolver class GraphitiClients(BaseModel): driver: GraphDriver @@ -31,3 +32,5 @@ class GraphitiClients(BaseModel): tracer: Tracer model_config = ConfigDict(arbitrary_types_allowed=True) + + resolver: EntityResolver \ No newline at end of file diff --git a/graphiti_core/normalize.py b/graphiti_core/normalize.py new file mode 100644 index 000000000..dc1f1e027 --- /dev/null +++ b/graphiti_core/normalize.py @@ -0,0 +1,95 @@ +import numpy as np +from sklearn.cluster import DBSCAN + +class EntityResolver: + """Production entity resolution with context-aware disambiguation""" + + def __init__(self): + self.entity_cache = {} + self.canonical_map = {} + + def compute_entity_similarity(self, entity1, entity2): + """Compute similarity considering both text and semantic context""" + + # Exact match gets high score + if entity1['surface_form'].lower() == entity2['surface_form'].lower(): + base_score = 0.9 + else: + # Fuzzy match on surface form + from difflib import SequenceMatcher + base_score = SequenceMatcher( + None, + entity1['surface_form'].lower(), + entity2['surface_form'].lower() + ).ratio() + + # Type mismatch penalty + if entity1['type'] != entity2['type']: + base_score *= 0.3 + + # Context similarity boost + if 'features' in entity1 and 'features' in entity2: + shared_features = set(entity1['features'].keys()) & set(entity2['features'].keys()) + if shared_features: + # Features match increases confidence + feature_match_score = sum( + 1 for k in shared_features + if entity1['features'][k] == entity2['features'][k] + ) / len(shared_features) + base_score = 0.7 * base_score + 0.3 * feature_match_score + + return base_score + def resolve_entities(self, all_entities, similarity_threshold=0.75): + """Cluster entities into canonical forms using DBSCAN""" + + n = len(all_entities) + if n == 0: + return {} + + # Build similarity matrix + similarity_matrix = np.zeros((n, n)) + for i in range(n): + for j in range(i+1, n): + sim = self.compute_entity_similarity(all_entities[i], all_entities[j]) + similarity_matrix[i,j] = sim + similarity_matrix[j,i] = sim + + # Convert similarity to distance for DBSCAN + distance_matrix = 1 - similarity_matrix + + # Cluster entities + clustering = DBSCAN( + eps=1-similarity_threshold, + min_samples=1, + metric='precomputed' + ).fit(distance_matrix) + + # Create canonical entities + canonical_entities = {} + for cluster_id in set(clustering.labels_): + cluster_members = [ + all_entities[i] for i, label in enumerate(clustering.labels_) + if label == cluster_id + ] + + # Most common surface form becomes canonical + surface_forms = [e['surface_form'] for e in cluster_members] + canonical_form = max(set(surface_forms), key=surface_forms.count) + + canonical_entities[canonical_form] = { + 'canonical_name': canonical_form, + 'type': cluster_members[0]['type'], + 'variant_forms': list(set(surface_forms)), + 'occurrences': len(cluster_members), + 'contexts': [e['context'] for e in cluster_members[:5]] # Sample contexts + } + + # Map all variants to canonical form + for variant in surface_forms: + self.canonical_map[variant] = canonical_form + + return canonical_entities + + def 
get_canonical_form(self, surface_form): + """Get canonical entity name for any surface form""" + return self.canonical_map.get(surface_form, surface_form) \ No newline at end of file diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index b8cd2dcd9..d8713814b 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -24,6 +24,13 @@ from .prompt_helpers import to_prompt_json from .snippets import summary_instructions +class EntityFeature(BaseModel): + # OpenAI Strict 모드를 위해 모든 필드에 additionalProperties: False가 적용되도록 함 + key: str = Field(..., description="특징의 이름 (예: 직업, 위치, 상태)") + value: str = Field(..., description="특징의 값 (예: 회장, 서울, 활성)") + + class Config: + extra = "forbid" # OpenAI가 요구하는 additionalProperties: false 설정 class ExtractedEntity(BaseModel): name: str = Field(..., description='Name of the extracted entity') @@ -31,11 +38,23 @@ class ExtractedEntity(BaseModel): description='ID of the classified entity type. ' 'Must be one of the provided entity_type_id integers.', ) + context: str | None = Field( + ..., description='Surrounding context (the sentence containing the entity)' + ) + # dict 대신 모델 리스트로 변경하여 스키마 충돌 방지 + features: list[EntityFeature] = Field( + ..., + description="엔티티 식별 특징 리스트. 없으면 빈 리스트 [] 반환" + ) + + class Config: + extra = "forbid" class ExtractedEntities(BaseModel): extracted_entities: list[ExtractedEntity] = Field(..., description='List of extracted entities') - + class Config: + extra = "forbid" class MissedEntities(BaseModel): missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted") @@ -285,13 +304,14 @@ def extract_summary(context: dict[str, Any]) -> list[Message]: return [ Message( role='system', - content='You are a helpful assistant that extracts entity summaries from the provided text. Always summarize in Korean.', + content='You are a helpful assistant that extracts entity summaries from the provided text. Always summary to korean.', ), Message( role='user', content=f""" Given the MESSAGES and the ENTITY, update the summary that combines relevant information about the entity - from the messages and relevant information from the existing summary. Always summarize in Korean.' + from the messages and relevant information from the existing summary. Always summary to korean. + {summary_instructions} diff --git a/graphiti_core/utils/maintenance/dedup_helpers.py b/graphiti_core/utils/maintenance/dedup_helpers.py index b8ce68b89..95463f953 100644 --- a/graphiti_core/utils/maintenance/dedup_helpers.py +++ b/graphiti_core/utils/maintenance/dedup_helpers.py @@ -25,6 +25,8 @@ from hashlib import blake2b from typing import TYPE_CHECKING +from graphiti_core.graphiti_types import GraphitiClients + if TYPE_CHECKING: from graphiti_core.nodes import EntityNode @@ -245,6 +247,57 @@ def _resolve_with_similarity( state.unresolved_indices.append(idx) +def _resolve_with_dbscan_and_similarity( + clients: GraphitiClients, + extracted_nodes: list[EntityNode], + indexes: DedupCandidateIndexes, + state: DedupResolutionState, +) -> None: + """ + DBSCAN(수학적 정규화)을 우선 실행하고, 실패 시 기존 유사도 알고리즘을 수행하는 통합 함수 + """ + + # 1. 
DBSCAN (EntityResolver) 로직 실행 + if hasattr(clients, 'resolver') and clients.resolver is not None: + all_candidate_pool = extracted_nodes + indexes.existing_nodes + + # Resolver용 데이터 정규화 준비 + resolver_input = [] + for node in all_candidate_pool: + meta = node.attributes.get('extraction_features', {}) + resolver_input.append({ + "surface_form": node.name, + "context": meta.get('context', ""), + "features": meta.get('features', {}), + "type": meta.get('type', "Entity") + }) + + # DBSCAN 실행 (클러스터링 수행) + clients.resolver.resolve_entities(resolver_input) + + # Canonical Form 기반 매칭 시도 + for idx, node in enumerate(extracted_nodes): + canonical_name = clients.resolver.get_canonical_form(node.name) + + # 1-1. 정규화된 이름이 기존 노드 중에 있는지 확인 + # (기존 indexes.normalized_existing 활용하여 속도 향상) + exact_normalized = _normalize_string_exact(canonical_name) + existing_matches = indexes.normalized_existing.get(exact_normalized, []) + + if existing_matches: + # 가장 적절한 기존 노드와 매칭 (첫 번째 노드 혹은 가장 신뢰도 높은 노드) + match = existing_matches[0] + state.resolved_nodes[idx] = match + state.uuid_map[node.uuid] = match.uuid + if match.uuid != node.uuid: + state.duplicate_pairs.append((node, match)) + else: + # 정규화는 되었으나 기존 노드 중 매칭이 없으면 다음 단계(유사도 혹은 LLM)를 위해 보류 + state.unresolved_indices.append(idx) + else: + # Resolver가 없는 경우에만 기존 Jaccard/MinHash 로직 실행 + _resolve_with_similarity(extracted_nodes, indexes, state) + __all__ = [ 'DedupCandidateIndexes', @@ -259,4 +312,5 @@ def _resolve_with_similarity( '_FUZZY_JACCARD_THRESHOLD', '_build_candidate_indexes', '_resolve_with_similarity', + '_resolve_with_dbscan_and_similarity' ] diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 1a75d70de..3b9347fc5 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -49,6 +49,7 @@ DedupResolutionState, _build_candidate_indexes, _resolve_with_similarity, + _resolve_with_dbscan_and_similarity ) from graphiti_core.utils.maintenance.edge_operations import ( filter_existing_duplicate_of_edges, @@ -96,6 +97,7 @@ async def extract_nodes( start = time() llm_client = clients.llm_client + # 1. 
엔티티 타입 컨텍스트 설정 entity_types_context = [ { 'entity_type_id': 0, @@ -396,7 +398,7 @@ async def resolve_extracted_nodes( unresolved_indices=[], ) - _resolve_with_similarity(extracted_nodes, indexes, state) + _resolve_with_dbscan_and_similarity(clients, extracted_nodes, indexes, state) await _resolve_with_llm( llm_client, From 74fd1d31fd709b0ebc12c39fb81b8bacd7db54e2 Mon Sep 17 00:00:00 2001 From: lina-lemon Date: Thu, 15 Jan 2026 09:52:24 +0900 Subject: [PATCH 4/9] feat: opt type of knowledge --- server/graph_service/dto/__init__.py | 38 +++-- server/graph_service/dto/common.py | 224 +++++++++++++++------------ server/graph_service/dto/retrieve.py | 4 + 3 files changed, 153 insertions(+), 113 deletions(-) diff --git a/server/graph_service/dto/__init__.py b/server/graph_service/dto/__init__.py index 76e412c0e..2c047f695 100644 --- a/server/graph_service/dto/__init__.py +++ b/server/graph_service/dto/__init__.py @@ -1,4 +1,4 @@ -from .common import Message, Result, Text, Person, Organization, Product, Concept, Location, Datetime, Category, Image, Web, entity_type, edge_type, edge_type_maps, MemberOf, ManagedBy,Published,ScheduledOn,LocatedIn,References,Uses,BelongsTo,HasImage, HasLink +from .common import Message, Result, Text, Actor, Author, Object, Procedure, Condition, Event, Location, Datetime, Year, Month, Day, Week, Concept, Image, Web, Category, ExecutedBy, ScheduledOn, PartOf, DependsOn, Triggers, LocatedIn, HasDetail, entity_type, edge_type, edge_type_maps from .ingest import AddEntityNodeRequest, AddMessagesRequest, AddTextsRequest from .retrieve import NodeResult, FactResult, GetMemoryRequest, GetMemoryResponse, SearchQuery, SearchResults @@ -15,26 +15,30 @@ 'GetMemoryResponse', 'AddTextsRequest', 'Text', - 'Person', - 'Organization', - 'Product', - 'Concept', + 'Actor', + 'Author', + 'Object', + 'Procedure', + 'Condition', + 'Event', 'Location', 'Datetime', - 'Category', + 'Year', + 'Month', + 'Day', + 'Week', + 'Concept', 'Image', 'Web', - 'entity_type', - 'edge_type', - 'edge_type_maps', - 'MemberOf', - 'ManagedBy', - 'Published', + 'Category', + 'ExecutedBy', 'ScheduledOn', + 'PartOf', + 'DependsOn', + 'Triggers', 'LocatedIn', - 'References', - 'Uses', - 'BelongsTo', - 'HasImage', - 'HasLink' + 'HasDetail', + 'entity_type', + 'edge_type', + 'edge_type_maps' ] diff --git a/server/graph_service/dto/common.py b/server/graph_service/dto/common.py index 434a06d3f..c0f259bf3 100644 --- a/server/graph_service/dto/common.py +++ b/server/graph_service/dto/common.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Literal +from typing import Literal, Optional from graphiti_core.utils.datetime_utils import utc_now from pydantic import BaseModel, Field @@ -38,128 +38,160 @@ class Text(BaseModel): default='', description='The description of the source of the text' ) -class Person(BaseModel): - """PEOPLE: 성함 및 직함 포함 (예: '홍길동 팀장')""" - full_name: str = Field(..., description="Full name and title") +class Actor(BaseModel): + """ACTOR: 행위 주체 및 책임자 (부서, 기관, 업체 등)""" + actor_name: str = Field(..., description="기관명, 부서명 또는 담당 주체") + description: str = Field(..., description="역할 및 책임에 대한 상세 설명") -class Organization(BaseModel): - """ORGANIZATIONS: 정식 기관/업체 명칭 (예: '단지1 로비')""" - org_name: str = Field(..., description="Normalized legal name") +class Author(BaseModel): + """AUTHOR: 문서의 작성자 (개인 성함 및 직함 포함)""" + author_name: str = Field(..., description="작성자 성함 및 직함") + description: Optional[str] = Field(None, description="소속 부서 등 추가 정보") -class Asset(BaseModel): - """ASSETS: 공지사항, 게시물 등 
주요 정보 객체 (가장 중심이 되는 노드)""" - asset_name: str = Field(..., description="Title of the notice or post") - web_url: str | None = Field(None, description="Related link URL") - image_url: str | None = Field(None, description="Main image URL") +class Object(BaseModel): + """OBJECT: 공지사항, 문서명, 자산, 식단(메뉴) 등 관리 대상""" + object_name: str = Field(..., description="대상 명칭 (예: '3월 소독 안내문', '제육볶음')") + description: str = Field(..., description="상세 내용 및 증거 문구") -class Product(BaseModel): - """PRODUCTS: 구체적인 제품 이름""" - product_name: str = Field(..., description="Specific product name") +class Procedure(BaseModel): + """PROCEDURE: 신청 방법, 업무 단계, 지침 등 행위의 절차""" + procedure_name: str = Field(..., description="절차 명칭") + description: str = Field(..., description="단계별 상세 가이드라인") -class Service(BaseModel): - """SERVICES: 구체적인 서비스 명칭""" - service_name: str = Field(..., description="Specific service name") +class Condition(BaseModel): + """CONDITION: 법령, 규정, 자격 조건 등 행위의 근거""" + condition_name: str = Field(..., description="규정 또는 조건 명칭") + description: str = Field(..., description="제약 사유 및 법적 근거") -class Concept(BaseModel): - """CONCEPTS: 도메인 전문 용어 또는 개념""" - concept_label: str = Field(..., description="Domain-specific concept") +class Event(BaseModel): + """EVENT: 점검, 사고, 행사 등 특정 시점에 발생하는 사건""" + event_name: str = Field(..., description="사건 명칭") + description: str = Field(..., description="사건의 상세 내용") class Location(BaseModel): - """LOCATIONS: 구체적인 장소 (예: '단지1 로비', '지하주차장')""" - location_name: str = Field(..., description="Specific physical location") + """LOCATION: 물리적 장소 또는 시스템 내 위치""" + location_name: str = Field(..., description="구체적인 장소명") class Datetime(BaseModel): - """TEMPORAL ENTITIES: YYYY-MM-DD HH:mm:ss 형식의 정규화된 날짜 노드""" - date_val: str = Field(..., description="Normalized ISO 8601 string") + """DATETIME: YYYY-MM-DD HH:mm:ss 형식의 정규화된 시점""" + datetime_name: str = Field(..., description="YYYY-MM-DD HH:mm:ss") -class Category(BaseModel): - """CATEGORY: 문서 분류""" - category_name: str = Field(..., description="Name of the category") +class Year(BaseModel): + """YEAR: 연도 노드 (시간 계층 최상위)""" + year_name: str = Field(..., description="예: '2025'") + +class Month(BaseModel): + """MONTH: 월 노드""" + month_name: str = Field(..., description="예: '3'") + +class Day(BaseModel): + """DAY: 일 노드""" + day_name: str = Field(..., description="예: '15'") + +class Week(BaseModel): + """WEEK: 주차 노드 (예: '3월 2주차')""" + week_name: str = Field(..., description="월 및 주차 정보") + +class Concept(BaseModel): + """CONCEPT: 전문 용어 또는 도메인 지식""" + concept_name: str = Field(..., description="개념명") + description: str = Field(..., description="개념의 정의") class Image(BaseModel): - """IMAGE: 첨부 이미지 정보""" - url: str = Field(..., description="Direct image link") - alt_text: str | None = Field(None, description="Description of the image") + """IMAGE: 첨부 이미지""" + image_name: str = Field(..., description="이미지 식별자 또는 파일명") + image_url: str = Field(..., description="이미지 링크 URL") class Web(BaseModel): - """WEB: 외부 링크 정보""" - url: str = Field(..., description="Direct web link") + """WEB: 외부 링크""" + web_name: str = Field(..., description="링크 제목") + web_url: str = Field(..., description="웹 사이트 URL") + + +class Category(BaseModel): + """CATEGORY: 카테고리 또는 게시판 분류""" + category_name: str = Field(..., description="분류 명칭 (예: '공지사항')") -# --- 엔티티 타입 등록 --- -entity_type = { - "Person": Person, - "Organization": Organization, - "Asset": Asset, - "Product": Product, - "Service": Service, - "Concept": Concept, - "Location": Location, - "Datetime": Datetime, - "Category": Category, 
- "Image": Image, - "Web": Web, -} # --- 엣지(Edge) 클래스 정의 --- -class MemberOf(BaseModel): fact: str = Field(default="MEMBER_OF") -class ManagedBy(BaseModel): fact: str = Field(default="MANAGED_BY") -class Published(BaseModel): fact: str = Field(default="PUBLISHED") -class ScheduledOn(BaseModel): fact: str = Field(default="SCHEDULED_ON") -class LocatedIn(BaseModel): fact: str = Field(default="LOCATED_IN") -class References(BaseModel): fact: str = Field(default="REFERENCES") -class Uses(BaseModel): fact: str = Field(default="USES") -class BelongsTo(BaseModel): fact: str = Field(default="BELONGS_TO") -class HasImage(BaseModel): fact: str = Field(default="HAS_IMAGE") -class HasLink(BaseModel): fact: str = Field(default="HAS_LINK") +class ExecutedBy(BaseModel): + fact: str = Field(default="EXECUTED_BY", description="작성자나 실행 주체 연결") +class ScheduledOn(BaseModel): + fact: str = Field(default="SCHEDULED_ON", description="특정 시점 또는 작성 일시 할당") + +class PartOf(BaseModel): + fact: str = Field(default="PART_OF", description="시간 계층(Day-Month-Year) 구조 형성") + +class DependsOn(BaseModel): + fact: str = Field(default="DEPENDS_ON", description="선행 요건 및 법적 근거 참조") + +class Triggers(BaseModel): + fact: str = Field(default="TRIGGERS", description="인과관계(A가 B를 유발함) 연결") + +class LocatedIn(BaseModel): + fact: str = Field(default="LOCATED_IN", description="물리적/디지털 장소 연결") + +class HasDetail(BaseModel): + fact: str = Field(default="HAS_DETAIL", description="이미지, 링크, 메타데이터 연결") + + +# Entity type +entity_type = { + "Actor": Actor, "Author": Author, "Object": Object, + "Procedure": Procedure, "Condition": Condition, "Event": Event, + "Location": Location, "Datetime": Datetime, "Year": Year, + "Month": Month, "Day": Day, "Week": Week, "Concept": Concept, + "Image": Image, "Web": Web, "Category": Category +} + +# Edge type edge_type = { - "MEMBER_OF": MemberOf, - "MANAGED_BY": ManagedBy, - "PUBLISHED": Published, + "EXECUTED_BY": ExecutedBy, "SCHEDULED_ON": ScheduledOn, + "PART_OF": PartOf, + "DEPENDS_ON": DependsOn, + "TRIGGERS": Triggers, "LOCATED_IN": LocatedIn, - "REFERENCES": References, - "USES": Uses, - "BELONGS_TO": BelongsTo, - "HAS_IMAGE": HasImage, - "HAS_LINK": HasLink + "HAS_DETAIL": HasDetail } +# Edge type map edge_type_maps = { - ("Person", "Organization"): ["MEMBER_OF"], - ("Organization", "Category"): ["MEMBER_OF"], + # 1. 주체 및 작성 (Who) + ("Object", "Author"): ["EXECUTED_BY"], + ("Event", "Author"): ["EXECUTED_BY"], + ("Procedure", "Author"): ["EXECUTED_BY"], + ("Object", "Actor"): ["EXECUTED_BY"], - # 관리 및 발행 관계 - ("Asset", "Organization"): ["MANAGED_BY"], - ("Asset", "Person"): ["MANAGED_BY"], - ("Organization", "Asset"): ["PUBLISHED"], - ("Person", "Asset"): ["PUBLISHED"], + # 2. 시간 계층 (When - PART_OF) + ("Datetime", "Day"): ["PART_OF"], + ("Day", "Week"): ["PART_OF"], + ("Day", "Month"): ["PART_OF"], + ("Week", "Month"): ["PART_OF"], + ("Month", "Year"): ["PART_OF"], - # 시간 및 장소 (게시물 중심 연결) - ("Asset", "Datetime"): ["SCHEDULED_ON"], - ("Concept", "Datetime"): ["SCHEDULED_ON"], - ("Asset", "Location"): ["LOCATED_IN"], - ("Organization", "Location"): ["LOCATED_IN"], + # 3. 일정 할당 (When - SCHEDULE_ON) + ("Object", "Datetime"): ["SCHEDULED_ON"], + ("Event", "Datetime"): ["SCHEDULED_ON"], + ("Procedure", "Datetime"): ["SCHEDULED_ON"], - # 참조 및 분류 - ("Asset", "Asset"): ["REFERENCES"], - ("Asset", "Web"): ["HAS_LINK", "REFERENCES"], - ("Asset", "Category"): ["BELONGS_TO"], + # 4. 
인과 관계 및 근거 (Why/How) + ("Object", "Condition"): ["DEPENDS_ON"], + ("Procedure", "Condition"): ["DEPENDS_ON"], + ("Condition", "Actor"): ["DEPENDS_ON"], + ("Event", "Procedure"): ["TRIGGERS"], + ("Event", "Event"): ["TRIGGERS"], - # 이미지 및 미디어 연결 (개선사항 반영) - ("Asset", "Image"): ["HAS_IMAGE"], + # 5. 장소 (Where) + ("Actor", "Location"): ["LOCATED_IN"], + ("Object", "Location"): ["LOCATED_IN"], + ("Event", "Location"): ["LOCATED_IN"], - # 도메인 지식 - ("Organization", "Concept"): ["USES"], - ("Person", "Concept"): ["USES"], - ("Asset", "Concept"): ["USES"], + # 6. 상세 정보 및 미디어 (Metadata) + ("Object", "Image"): ["HAS_DETAIL"], + ("Object", "Web"): ["HAS_DETAIL"], + ("Object", "Concept"): ["HAS_DETAIL"], + ("Object", "Category"): ["HAS_DETAIL"], + ("Event", "Category"): ["HAS_DETAIL"], } -# class Doc(BaseModel): -# uuid: str = Field(..., description='The uuid of the episode') -# group_id: str = Field(..., description='The group id of the episode') -# name: str = Field(..., description='The name of the episode') -# episode_body: str = Field(..., description='The body of the episode') -# reference_time: datetime = Field(default_factory=utc_now, description='The reference time of the episode') -# source: str = Field(..., description='The source of the episode') -# source_description: str = Field(...,) - diff --git a/server/graph_service/dto/retrieve.py b/server/graph_service/dto/retrieve.py index e7a498244..484af806e 100644 --- a/server/graph_service/dto/retrieve.py +++ b/server/graph_service/dto/retrieve.py @@ -36,6 +36,10 @@ class NodeResult(BaseModel): attributes: dict[str, Any] = Field(default={}) created_at: datetime + class Config: + json_encoders = {datetime: lambda v: v.astimezone(timezone.utc).isoformat()} + + class SearchResults(BaseModel): facts: Optional[list[FactResult]] nodes: Optional[list[NodeResult]] From eee5e45158f50a724e1573af29b60b50f98d0e29 Mon Sep 17 00:00:00 2001 From: lina-lemon Date: Wed, 28 Jan 2026 17:41:08 +0900 Subject: [PATCH 5/9] feat: opt normalize node --- graphiti_core/normalize.py | 165 ++++++++++++++++++++++++------------- 1 file changed, 107 insertions(+), 58 deletions(-) diff --git a/graphiti_core/normalize.py b/graphiti_core/normalize.py index dc1f1e027..0b2e24d96 100644 --- a/graphiti_core/normalize.py +++ b/graphiti_core/normalize.py @@ -1,5 +1,12 @@ import numpy as np from sklearn.cluster import DBSCAN +import json +from collections import Counter + +from graphiti_core.nodes import EntityNode +from graphiti_core.utils.datetime_utils import utc_now +from difflib import SequenceMatcher #* 문자열 유사도 계산 +from sklearn.cluster import DBSCAN class EntityResolver: """Production entity resolution with context-aware disambiguation""" @@ -8,88 +15,130 @@ def __init__(self): self.entity_cache = {} self.canonical_map = {} - def compute_entity_similarity(self, entity1, entity2): + def compute_entity_similarity(self, entity1: EntityNode, entity2: EntityNode): """Compute similarity considering both text and semantic context""" # Exact match gets high score - if entity1['surface_form'].lower() == entity2['surface_form'].lower(): + if entity1.name.lower() == entity2.name.lower(): base_score = 0.9 else: # Fuzzy match on surface form - from difflib import SequenceMatcher base_score = SequenceMatcher( None, - entity1['surface_form'].lower(), - entity2['surface_form'].lower() + entity1.name.lower(), + entity2.name.lower() ).ratio() - + + feat1 = entity1.attributes + feat2 = entity2.attributes + # Type mismatch penalty - if entity1['type'] != entity2['type']: + if feat1['type'] != 
feat2['type']: base_score *= 0.3 + - # Context similarity boost - if 'features' in entity1 and 'features' in entity2: - shared_features = set(entity1['features'].keys()) & set(entity2['features'].keys()) - if shared_features: - # Features match increases confidence - feature_match_score = sum( - 1 for k in shared_features - if entity1['features'][k] == entity2['features'][k] - ) / len(shared_features) - base_score = 0.7 * base_score + 0.3 * feature_match_score + # Context similarity boost + # - 같은 텍스트여도 type에 따라 동일성 판단 보정. + ctx1 = feat1['context'].lower() + ctx2 = feat2['context'].lower() + + if ctx1 and ctx2: + if ctx1 == ctx2: + # 추출된 문장이 완정 동일 + feature_match_score = 1.0 + else: + # 문장이 조금 다른 경우 + feature_match_score = SequenceMatcher(None, ctx1, ctx2).ratio() + base_score = 0.7 * base_score + 0.3 * feature_match_score return base_score - def resolve_entities(self, all_entities, similarity_threshold=0.75): - """Cluster entities into canonical forms using DBSCAN""" - - n = len(all_entities) - if n == 0: - return {} + + + + def normalize_extracted_nodes(self, extracted_nodes: list[EntityNode], similarity_threshold: float = 0.75) -> list[EntityNode]: + """Normalize extracted nodes""" + + if not extracted_nodes: + return [] + + n = len(extracted_nodes) - # Build similarity matrix - similarity_matrix = np.zeros((n, n)) + # 1. 유사도 행렬 구성 (자기 자신은 1.0) + similarity_matrix = np.eye(n) for i in range(n): - for j in range(i+1, n): - sim = self.compute_entity_similarity(all_entities[i], all_entities[j]) - similarity_matrix[i,j] = sim - similarity_matrix[j,i] = sim - - # Convert similarity to distance for DBSCAN + for j in range(i + 1, n): + sim = self.compute_entity_similarity(extracted_nodes[i], extracted_nodes[j]) + similarity_matrix[i, j] = sim + similarity_matrix[j, i] = sim + + # 2. DBSCAN 클러스터링 distance_matrix = 1 - similarity_matrix - - # Cluster entities clustering = DBSCAN( - eps=1-similarity_threshold, - min_samples=1, + eps=1 - similarity_threshold, + min_samples=1, metric='precomputed' ).fit(distance_matrix) + normalized_nodes = [] + - # Create canonical entities - canonical_entities = {} + # 3. 클러스터별 병합 수행 for cluster_id in set(clustering.labels_): cluster_members = [ - all_entities[i] for i, label in enumerate(clustering.labels_) + extracted_nodes[i] for i, label in enumerate(clustering.labels_) if label == cluster_id ] + + if not cluster_members: + continue + + # A. 대표 이름 결정 (빈도수가 가장 높은 이름) + names = [node.name for node in cluster_members] + canonical_name = Counter(names).most_common(1)[0][0] + + # B. 타입(Labels) 합치기 (Set으로 중복 제거) + all_labels = set() + for node in cluster_members: + all_labels.update(node.labels) - # Most common surface form becomes canonical - surface_forms = [e['surface_form'] for e in cluster_members] - canonical_form = max(set(surface_forms), key=surface_forms.count) - - canonical_entities[canonical_form] = { - 'canonical_name': canonical_form, - 'type': cluster_members[0]['type'], - 'variant_forms': list(set(surface_forms)), - 'occurrences': len(cluster_members), - 'contexts': [e['context'] for e in cluster_members[:5]] # Sample contexts - } + # C. 
속성 및 특징 누적 (Persistence) + merged_attributes = {} + all_features = [] + for node in cluster_members: + # 기존 extraction_features 수집 + features = node.attributes + if isinstance(features, list): + all_features.extend(features) + else: + all_features.append(features) - # Map all variants to canonical form - for variant in surface_forms: - self.canonical_map[variant] = canonical_form - - return canonical_entities - - def get_canonical_form(self, surface_form): - """Get canonical entity name for any surface form""" - return self.canonical_map.get(surface_form, surface_form) \ No newline at end of file + unique_features = [] + seen_json = set() + + for feature in all_features: + # Pydantic 모델인 경우 dict로 변환, 아니면 그대로 사용 + feature_dict = feature.model_dump() if hasattr(feature, 'model_dump') else feature + + # 중복 체크를 위한 직렬화 (default=str은 datetime 대비) + feature_json = json.dumps(feature_dict, sort_keys=True, default=str) + + if feature_json not in seen_json: + seen_json.add(feature_json) + unique_features.append(feature_dict) + + merged_attributes["extraction_features"] = json.dumps(unique_features, ensure_ascii=False) + merged_attributes["occurrence_count"] = len(cluster_members) + + + # D. 새로운 EntityNode 생성 (병합본) + new_node = EntityNode( + name=canonical_name, + group_id=cluster_members[0].group_id, # 같은 에피소드 내이므로 동일 + labels=list(all_labels), + summary='', + created_at=utc_now(), + attributes=merged_attributes + ) + normalized_nodes.append(new_node) + + return normalized_nodes + From b71be4a844c82a8495376f5878c133ba96864f34 Mon Sep 17 00:00:00 2001 From: lina-lemon Date: Tue, 3 Feb 2026 17:56:20 +0900 Subject: [PATCH 6/9] feat: opt normalize node --- graphiti_core/normalize.py | 5 +- tests/test_normalize.py | 137 +++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 4 deletions(-) create mode 100644 tests/test_normalize.py diff --git a/graphiti_core/normalize.py b/graphiti_core/normalize.py index 0b2e24d96..bababaa92 100644 --- a/graphiti_core/normalize.py +++ b/graphiti_core/normalize.py @@ -2,18 +2,15 @@ from sklearn.cluster import DBSCAN import json from collections import Counter - from graphiti_core.nodes import EntityNode from graphiti_core.utils.datetime_utils import utc_now from difflib import SequenceMatcher #* 문자열 유사도 계산 -from sklearn.cluster import DBSCAN class EntityResolver: """Production entity resolution with context-aware disambiguation""" def __init__(self): - self.entity_cache = {} - self.canonical_map = {} + self.similarity_threshold = 0.75 def compute_entity_similarity(self, entity1: EntityNode, entity2: EntityNode): """Compute similarity considering both text and semantic context""" diff --git a/tests/test_normalize.py b/tests/test_normalize.py new file mode 100644 index 000000000..bb2a09154 --- /dev/null +++ b/tests/test_normalize.py @@ -0,0 +1,137 @@ +import unittest +import sys +import os +import json +from datetime import datetime, timezone +from difflib import SequenceMatcher +from unittest.mock import Mock +from graphiti_core.nodes import EntityNode +from graphiti_core.normalize import EntityResolver + +class EntityResolverTests(unittest.TestCase): + ''' + test entity resolver + + ## runs + $ python3.12 -m venv venv + $ source venv/bin/activate + $ pip3 install graphiti-core scikit-learn + $ python3 -m unittest tests.test_normalize + ''' + def setUp(self): + self.resolver = EntityResolver() + self.base_time = datetime.now() + self.group_id = 'site:1000039' + + def create_node(self, extracted_node={}, uuid=None): + + node_args = { + 'uuid': uuid, + 
'name': extracted_node['name'], + 'labels': extracted_node['labels'], + 'group_id': self.group_id, + 'created_at': datetime(2026, 1, 28, 5, 4, 39, 360956, tzinfo=timezone.utc), + 'name_embedding': None, + 'summary': '', + 'attributes': extracted_node['attribute'], + } + if uuid is None: + del node_args['uuid'] + + return EntityNode(**node_args) + + def test_compute_entity_similarity(self): + '''Similarity calculation: Same name but different type (penalty applied)''' + node1 = self.create_node({"name": "생활지원센터", "labels": ["AUTHOR"], "attribute": {"type": "AUTHOR", "context": "[공지사항 생활지원센터] [정기소독] 2월 정기소독 실시안내"}}) + node2 = self.create_node({"name": "생활지원센터", "labels": ["ACTOR"], "attribute": {"type": "ACTOR", "context": "[공지사항 생활지원센터] [정기소독] 2월 정기소독 실시안내"}}) + + # Name match -> base_score = 0.9 + # Type mismatch penalty -> base_score *= 0.3 (0.27) + ctx_sim = SequenceMatcher(None, node1.attributes['context'].lower(), node2.attributes['context'].lower()).ratio() + expected_sim = 0.7 * (0.9 * 0.3) + 0.3 * ctx_sim + sim1 = self.resolver.compute_entity_similarity(node1, node2) + self.assertEqual(sim1, 0.489) + self.assertAlmostEqual(sim1, expected_sim, places=2) + self.assertLess(sim1, 0.75, "Type mismatch should prevent merging") + + # Similarity calculation: Slightly different names (Fuzzy match) + node3 = self.create_node({"name": "정기소독", "labels": ["EVENT"], "attribute": {"type": "EVENT", "context": "[정기소독] 2월 정기소독 실시안내"}}) + node4 = self.create_node({"name": "정기소독 실시", "labels": ["EVENT"], "attribute": {"type": "EVENT", "context": "[정기소독] 2월 정기소독 실시안내"}}) + + name_sim = SequenceMatcher(None, node3.name.lower(), node4.name.lower()).ratio() + ctx_sim2 = SequenceMatcher(None, node3.attributes['context'].lower(), node4.attributes['context'].lower()).ratio() + expected_sim2 = 0.7 * name_sim + 0.3 * ctx_sim2 + + sim2 = self.resolver.compute_entity_similarity(node3, node4) + self.assertEqual(sim2, 0.8090909090909091) + self.assertAlmostEqual(sim2, expected_sim2, places=2) + + + def test_normalize_nodes(self): + '''Test normalization of extracted nodes using DBSCAN clustering.''' + + # 1. 
Prepare test data + extracted_data = { + "nodes": [ + {"name": "공지사항", "labels": ["CATEGORY"], "attribute": {"type": "CATEGORY", "context": "[공지사항 생활지원센터] [정기소독] 2월 정기소독 실시안내"}}, + {"name": "생활지원센터", "labels": ["AUTHOR"], "attribute": {"type": "AUTHOR", "context": "[공지사항 생활지원센터] [정기소독] 2월 정기소독 실시안내"}}, + {"name": "생활지원센터", "labels": ["ACTOR"], "attribute": {"type": "ACTOR", "context": "[공지사항 생활지원센터] [정기소독] 2월 정기소독 실시안내"}}, + {"name": "정기소독", "labels": ["EVENT"], "attribute": {"type": "EVENT", "context": "[정기소독] 2월 정기소독 실시안내"}}, + {"name": "정기소독 실시", "labels": ["EVENT"], "attribute": {"type": "EVENT", "context": "[정기소독] 2월 정기소독 실시안내"}}, + {"name": "2", "labels": ["MONTH"], "attribute": {"type": "MONTH", "context": "2월 정기소독 실시안내"}}, + {"name": "정기소독 실시안내", "labels": ["OBJECT"], "attribute": {"type": "OBJECT", "context": "[정기소독] 2월 정기소독 실시안내"}}, + {"name": "notice", "labels": ["CATEGORY"], "attribute": {"type": "CATEGORY", "context": "카테고리 notice"}}, + {"name": "2025-02-19 16:21:22", "labels": ["DATETIME"], "attribute": {"type": "DATETIME", "context": "작성일 2025-02-19 16:21:22"}}, + {"name": "2025", "labels": ["YEAR"], "attribute": {"type": "YEAR", "context": "작성일 2025-02-19 16:21:22"}}, + {"name": "2", "labels": ["MONTH"], "attribute": {"type": "MONTH", "context": "작성일 2025-02-19 16:21:22"}}, + {"name": "19", "labels": ["DAY"], "attribute": {"type": "DAY", "context": "작성일 2025-02-19 16:21:22"}}, + {"name": "이미지1 (https://image.teset)", "labels": ["IMAGE"], "attribute": {"type": "IMAGE", "context": "첨부 이미지 - [이미지1](https://image.test)"}}, + {"name": "/post/11/board/22", "labels": ["WEB"], "attribute": {"type": "WEB", "context": "링크 /post/11/board/22"}}, + ] + } + + # Verify default threshold + self.assertEqual(self.resolver.similarity_threshold, 0.75) + + + # Sanity check: create_node helper + test_uuid = '5e9f662a-7a65-4580-b9b8-9384a0c731d6' + node1 = self.create_node({"name": "생활지원센터", "labels": ["AUTHOR"], "attribute": {"type": "AUTHOR", "context": "[공지사항 생활지원센터] [정기소독] 2월 정기소독 실시안내"}}, uuid=test_uuid) + self.assertEqual(node1, EntityNode( + uuid=test_uuid, + name='생활지원센터', + labels=["AUTHOR"], + group_id=self.group_id, + created_at=datetime(2026, 1, 28, 5, 4, 39, 360956, tzinfo=timezone.utc), + name_embedding=None, + summary='', + attributes={"type": "AUTHOR", "context": "[공지사항 생활지원센터] [정기소독] 2월 정기소독 실시안내"}, + )) + + # 2. Create initial nodes from sample data + initial_nodes = [ + self.create_node(n) for n in extracted_data['nodes'] + ] + + # 3. Run normalization + # "정기소독" and "정기소독 실시" should merge because their similarity (0.809) > threshold (0.75) + normalized_nodes = self.resolver.normalize_extracted_nodes(initial_nodes, 0.75) + + # 4. Verify results + # Expected count: 14 initial nodes. "정기소독" and "정기소독 실시" merge -> 13 nodes. 
+ self.assertEqual(len(normalized_nodes), 13) + + # Verify '정기소독' (EVENT) was merged (occurrence_count should be 2) + event_node = next((n for n in normalized_nodes if n.name == "정기소독")) + self.assertIsNotNone(event_node) + self.assertEqual(event_node.attributes.get('occurrence_count'), 2) + + # Verify '생활지원센터' (AUTHOR) was NOT merged with '생활지원센터' (ACTOR) due to type mismatch + author_nodes = [n for n in normalized_nodes if n.name == '생활지원센터'] + self.assertEqual(len(author_nodes), 2) + self.assertEqual(author_nodes[0].attributes.get('occurrence_count'), 1) + self.assertEqual(author_nodes[1].attributes.get('occurrence_count'), 1) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 70766d40fcf439ccdb06150a8d89ddcbac948e5f Mon Sep 17 00:00:00 2001 From: lina-lemon Date: Thu, 5 Feb 2026 10:15:03 +0900 Subject: [PATCH 7/9] feat: opt normalize entity w/ episode v2 --- graphiti_core/errors.py | 7 + graphiti_core/graphiti.py | 182 +++++++++++++++++- graphiti_core/normalize.py | 38 ++-- .../utils/maintenance/edge_operations.py | 77 ++++++++ .../utils/maintenance/node_operations.py | 117 +++++++++++ 5 files changed, 404 insertions(+), 17 deletions(-) diff --git a/graphiti_core/errors.py b/graphiti_core/errors.py index 3bbb9a94f..05439827b 100644 --- a/graphiti_core/errors.py +++ b/graphiti_core/errors.py @@ -81,3 +81,10 @@ class GroupIdValidationError(GraphitiError): def __init__(self, group_id: str): self.message = f'group_id "{group_id}" must contain only alphanumeric characters, dashes, or underscores' super().__init__(self.message) + +class GroupIdNotFoundError(GraphitiError): + """Raised when a group_id is not found.""" + + def __init__(self): + self.message = 'group_id not found' + super().__init__(self.message) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index b5acbd4cf..bffec6189 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -35,7 +35,7 @@ create_entity_edge_embeddings, ) from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder -from graphiti_core.errors import NodeNotFoundError +from graphiti_core.errors import GroupIdNotFoundError, NodeNotFoundError from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import ( get_default_group_id, @@ -86,6 +86,7 @@ extract_edges, resolve_extracted_edge, resolve_extracted_edges, + resolve_extracted_edges_v2, ) from graphiti_core.utils.maintenance.graph_data_operations import ( EPISODE_WINDOW_LEN, @@ -95,6 +96,7 @@ extract_attributes_from_nodes, extract_nodes, resolve_extracted_nodes, + resolve_extracted_nodes_v2, ) from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types @@ -893,6 +895,184 @@ async def add_episode_endpoint(episode_data: EpisodeData): span.set_status('error', str(e)) span.record_exception(e) raise e + + def create_node(self, node, group_id: str): + labels: list[str] = list({'Entity', str(node['labels'][0])}) + + return EntityNode( + name=node['name'], + group_id=group_id, + labels=labels, + summary='', + created_at=utc_now(), + name_embedding=node.get('name_embedding'), + attributes=node.get('attribute') + ) + + def create_summary(self, node, nodes: list[EntityNode], group_id: str): + source_node_idx = node['id'] + source_node = nodes[source_node_idx] + + return EntityNode( + uuid=source_node.uuid, + name=source_node.name, + group_id=source_node.group_id, + labels=source_node.labels, + summary=node['summary'] if node.get('summary') else '', + created_at=utc_now(), + 
name_embedding=source_node.name_embedding, + attributes=source_node.attributes + ) + + def create_edge(self, edge, nodes: list[EntityNode], group_id: str, episode: EpisodicNode): + # 안전하게 인덱스 추출 + s_idx = int(edge.get('sourceNodeId', 0)) + t_idx = int(edge.get('targetNodeId', 0)) + + # 인덱스 범위 초과 방지 + source_uuid = nodes[s_idx].uuid if s_idx < len(nodes) else "UNKNOWN" + target_uuid = nodes[t_idx].uuid if t_idx < len(nodes) else "UNKNOWN" + + return EntityEdge( + source_node_uuid=source_uuid, + target_node_uuid=target_uuid, + name=edge.get('type', 'RELATES_TO'), + group_id=group_id, + fact=edge.get('relation', ''), + episodes=[episode.uuid], + created_at=utc_now(), + fact_embedding=edge.get('fact_embedding'), + valid_at=episode.valid_at or utc_now() + ) + + + async def add_episode_v2( + self, + group_id: str, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + extract_nodes: list[object], + extract_edges: list[object], + summary_nodes: list[object], + source: EpisodeType = EpisodeType.text, + entity_types: dict[str, type[BaseModel]] | None = None, + excluded_entity_types: list[str] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, + edge_type_map: dict[tuple[str, str], list[str]] | None = None, + ) -> AddEpisodeResults: + """ + [node] 노드 추출; A + [graphiti] 노드 정규화 w/ A + [node] 엣지 추출 w/ A + [node] 노드 요약 w/ A + [graphiti] 저장 + """ + start = time() + now = utc_now() + + # STEP.0-1 valid parameters + if group_id: + validate_group_id(group_id) + if group_id != self.driver._database: + self.driver = self.driver.clone(database=group_id) + self.clients.driver = self.driver + else: + raise GroupIdNotFoundError + + + with self.tracer.start_span('add_episode_v2') as span: + try: + + # Get or create episode + episode = EpisodicNode( + name=name, + group_id=group_id, + labels=[], + source=source, + content=episode_body, + source_description=source_description, + created_at=now, + valid_at=reference_time, + + ) + + edge_type_map_default = ( + {('Entity', 'Entity'): list(edge_types.keys())} + if edge_types is not None + else {('Entity', 'Entity'): []} + ) + + # STEP. 노드 추출 + extracted_nodes: list[EntityNode] = [ self.create_node(node, group_id) for node in extract_nodes] + + + # STEP. 노드 정규화 + nodes, uuid_map, _ = await resolve_extracted_nodes_v2( + self.clients, + extracted_nodes, + episode, + entity_types = entity_types, + ) + + # STEP. 엣지 추출 + extracted_edges: list[EntityEdge] = [ self.create_edge(edge, nodes, group_id, episode) for edge in extract_edges] + edges = resolve_edge_pointers(extracted_edges, uuid_map) + resolved_edges, invalidated_edges = await resolve_extracted_edges_v2( + self.clients, + edges, + episode, + nodes, + edge_types or {}, + edge_type_map or edge_type_map_default, + ) + + entity_edges = resolved_edges + invalidated_edges + + # STEP. 노드 속성 추출 (Summary) + hydrated_nodes: list[EntityNode] = [ self.create_summary(node, nodes, group_id) for node in summary_nodes] + + # STEP. Process and Save Episode + # - _process_episode_data: 에피소드 노드와 엣지, 그리고 처리된 엔티티 노드와 엣지들을 DB에 저장합니다. + # - build_episodic_edges: 에피소드 노드와 엔티티 노드들을 연결하는 'MENTIONS' 엣지를 생성합니다. + # - add_nodes_and_edges_bulk: 노드, 엣지, 에피소드 데이터를 일괄적으로 DB에 저장(Merge/Create)합니다. + episodic_edges, episode = await self._process_episode_data( + episode, hydrated_nodes, entity_edges, now + ) + end = time() + + # STEP. 
Add span attributes (logging)
+                span.add_attributes(
+                    {
+                        'episode.uuid': episode.uuid,
+                        'episode.source': source.value,
+                        'episode.reference_time': reference_time.isoformat(),
+                        'group_id': group_id,
+                        'node.count': len(hydrated_nodes),
+                        'edge.count': len(entity_edges),
+                        'edge.invalidated_count': len(invalidated_edges),
+                        'entity_types.count': len(entity_types) if entity_types else 0,
+                        'edge_types.count': len(edge_types) if edge_types else 0,
+                        'duration_ms': (end - start) * 1000,
+                    }
+                )
+
+                # STEP. Return Saved Episode
+                return AddEpisodeResults(
+                    episode=episode,
+                    episodic_edges=episodic_edges,
+                    nodes=hydrated_nodes,
+                    edges=entity_edges,
+                    communities=[],
+                    community_edges=[],
+                )
+
+            except Exception as e:
+                span.set_status('error', str(e))
+                span.record_exception(e)
+                raise e
+
 
     async def add_episode_bulk(
         self,
diff --git a/graphiti_core/normalize.py b/graphiti_core/normalize.py
index bababaa92..2e49f1607 100644
--- a/graphiti_core/normalize.py
+++ b/graphiti_core/normalize.py
@@ -1,10 +1,13 @@
+from typing import Any
 import numpy as np
 from sklearn.cluster import DBSCAN
 import json
 from collections import Counter
+
 from graphiti_core.nodes import EntityNode
 from graphiti_core.utils.datetime_utils import utc_now
 from difflib import SequenceMatcher #* 문자열 유사도 계산
+from sklearn.cluster import DBSCAN
 
 class EntityResolver:
     """Production entity resolution with context-aware disambiguation"""
@@ -30,14 +33,14 @@ def compute_entity_similarity(self, entity1: EntityNode, entity2: EntityNode):
         feat2 = entity2.attributes
 
         # Type mismatch penalty
-        if feat1['type'] != feat2['type']:
+        if feat1.get('type') != feat2.get('type'):
             base_score *= 0.3
 
         # Context similarity boost
         # - 같은 텍스트여도 type에 따라 동일성 판단 보정.
-        ctx1 = feat1['context'].lower()
-        ctx2 = feat2['context'].lower()
+        ctx1 = feat1.get('context', '').lower()
+        ctx2 = feat2.get('context', '').lower()
 
 
         if ctx1 and ctx2:
@@ -51,8 +54,9 @@ def compute_entity_similarity(self, entity1: EntityNode, entity2: EntityNode):
 
         return base_score
 
+    def normalize_extracted_nodes(self, extracted_nodes: list[EntityNode], similarity_threshold: float = 0.75):
+        self.similarity_threshold = similarity_threshold or self.similarity_threshold
 
-    def normalize_extracted_nodes(self, extracted_nodes: list[EntityNode], similarity_threshold: float = 0.75) -> list[EntityNode]:
         """Normalize extracted nodes"""
 
         if not extracted_nodes:
@@ -75,18 +79,16 @@ def normalize_extracted_nodes(self, extracted_nodes: list[EntityNode], similarit
             min_samples=1,
             metric='precomputed'
         ).fit(distance_matrix)
-        normalized_nodes = []
-
+        normalized_nodes: list[EntityNode] = []
+        node_to_group_map = {}  # {원본_index: 최종_병합_노드}
+
         # 3. 클러스터별 병합 수행
         for cluster_id in set(clustering.labels_):
-            cluster_members = [
-                extracted_nodes[i] for i, label in enumerate(clustering.labels_)
-                if label == cluster_id
-            ]
-            if not cluster_members:
-                continue
+            # 해당 클러스터에 속한 원본 인덱스들 추출
+            member_indices = [i for i, label in enumerate(clustering.labels_) if label == cluster_id]
+            cluster_members = [extracted_nodes[i] for i in member_indices]
 
             # A. 대표 이름 결정 (빈도수가 가장 높은 이름)
             names = [node.name for node in cluster_members]
@@ -112,10 +114,8 @@ def normalize_extracted_nodes(self, extracted_nodes: list[EntityNode], similarit
             seen_json = set()
 
             for feature in all_features:
-                # Pydantic 모델인 경우 dict로 변환, 아니면 그대로 사용
                 feature_dict = feature.model_dump() if hasattr(feature, 'model_dump') else feature
 
-                # 중복 체크를 위한 직렬화 (default=str은 datetime 대비)
                 feature_json = json.dumps(feature_dict, sort_keys=True, default=str)
 
                 if feature_json not in seen_json:
@@ -123,13 +123,14 @@ def normalize_extracted_nodes(self, extracted_nodes: list[EntityNode], similarit
                     unique_features.append(feature_dict)
 
             merged_attributes["extraction_features"] = json.dumps(unique_features, ensure_ascii=False)
+
             merged_attributes["occurrence_count"] = len(cluster_members)
 
 
             # D. 새로운 EntityNode 생성 (병합본)
             new_node = EntityNode(
                 name=canonical_name,
-                group_id=cluster_members[0].group_id, # 같은 에피소드 내이므로 동일
+                group_id=cluster_members[0].group_id,
                 labels=list(all_labels),
                 summary='',
                 created_at=utc_now(),
@@ -137,5 +138,10 @@ def normalize_extracted_nodes(self, extracted_nodes: list[EntityNode], similarit
             )
             normalized_nodes.append(new_node)
 
-        return normalized_nodes
+            # 핵심: 이 클러스터에 속했던 모든 원본 인덱스가 이 new_node를 바라보게 함
+            for idx in member_indices:
+                node_to_group_map[idx] = new_node
+
+        return normalized_nodes, node_to_group_map
+
 
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 41ce31d1d..fcda30301 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -18,6 +18,7 @@
 from datetime import datetime
 from time import time
 
+import numpy as np
 from pydantic import BaseModel
 from typing_extensions import LiteralString
 
@@ -392,6 +393,82 @@ async def resolve_extracted_edges(
 
     return resolved_edges, invalidated_edges
 
+async def resolve_extracted_edges_v2(
+    clients: GraphitiClients,
+    extracted_edges: list[EntityEdge],
+    episode: EpisodicNode,
+    entities: list[EntityNode],
+    edge_types: dict[str, type[BaseModel]],
+    edge_type_map: dict[tuple[str, str], list[str]],
+) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    if not extracted_edges:
+        return [], []
+
+    driver = clients.driver
+    embedder = clients.embedder
+
+    await create_entity_edge_embeddings(embedder, extracted_edges)
+
+    # 2. 동일 노드 쌍 사이의 기존 엣지들 조회
+    valid_edges_list = await semaphore_gather(
+        *[
+            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
+            for edge in extracted_edges
+        ]
+    )
+
+    # 3. 기존 엣지들의 임베딩(숫자 지문) 로드 (병렬 처리)
+    all_existing_edges: list[EntityEdge] = [edge for sublist in valid_edges_list for edge in sublist]
+    if all_existing_edges:
+        # load_fact_embedding을 통해 DB에 저장된 벡터를 가져옴
+        await semaphore_gather(
+            *[edge.load_fact_embedding(driver) for edge in all_existing_edges]
+        )
+
+    resolved_edges: list[EntityEdge] = []
+    invalidated_edges: list[EntityEdge] = []
+
+    # 4. 비교 로직 (Name 일치 + Fact 유사도)
+    for ext_edge, existing_edges in zip(extracted_edges, valid_edges_list):
+        is_duplicate = False
+        print(f"New Edge Embedding: {ext_edge.fact_embedding[:5] if ext_edge.fact_embedding else 'MISSING'}")
+
+        for ex_edge in existing_edges:
+            # 2. DB에서 가져온 기존 엣지의 임베딩 확인
+            print(f"Existing Edge Embedding: {ex_edge.fact_embedding[:5] if ex_edge.fact_embedding else 'MISSING'}")
+            # 4-1. Name(타입)이 다르면 비교할 가치 없음
+            if ext_edge.name != ex_edge.name:
+                continue
+
+            # 4-2. 
Fact(문장)의 의미적 유사도 계산 + similarity = 0.0 + if ext_edge.fact_embedding and ex_edge.fact_embedding: + similarity = np.dot(ext_edge.fact_embedding, ex_edge.fact_embedding) + + if similarity > 0.7: + if episode.uuid not in ex_edge.episodes: + ex_edge.episodes.append(episode.uuid) + + # 기존 엣지 정보를 유지하고 새 에피소드만 연결 + resolved_edges.append(ex_edge) + is_duplicate = True + break + + + # 중복이 아니면 새로운 관계(엣지)로 추가 + if not is_duplicate: + resolved_edges.append(ext_edge) + + # 5. UUID 기반 최종 중복 제거 + seen_uuids = set() + final_resolved = [] + for e in resolved_edges: + if e.uuid not in seen_uuids: + final_resolved.append(e) + seen_uuids.add(e.uuid) + + return final_resolved, invalidated_edges + def resolve_edge_contradictions( resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge] diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 3b9347fc5..c694c7f32 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -373,6 +373,123 @@ async def _resolve_with_llm( state.duplicate_pairs.append((extracted_node, resolved_node)) +async def _resolve_without_llm( + clients: GraphitiClients, + extracted_nodes: list[EntityNode], + indexes: DedupCandidateIndexes, + state: DedupResolutionState, +) -> None: + if not state.unresolved_indices: + return + + # 1. Resolver 준비 (clients에 있으면 쓰고 없으면 새로 생성) + resolver = clients.resolver + + # 2. 미해결 노드들 리스트업 + unresolved_nodes = [extracted_nodes[i] for i in state.unresolved_indices] + + # 3. 내부 정규화 및 매핑 정보 획득 + normalized_new_nodes, node_to_group_map = resolver.normalize_extracted_nodes(unresolved_nodes) + + # 4. 정규화된 각 그룹(Cluster)을 기존 DB 노드와 비교 + group_to_existing_map = {} + for norm_node in normalized_new_nodes: + best_match_node = None + max_sim = 0 + + # 기존 DB 노드(Candidate)와 비교 + for candidate in indexes.existing_nodes: + sim = resolver.compute_entity_similarity(norm_node, candidate) + print("> norm_node name", norm_node.name) + print("> existing_nodes name", candidate.name) + + print("> last sim", sim) + if sim > max_sim and sim >= 0.99999: + max_sim = sim + best_match_node = candidate + + # 기존 노드가 있으면 그것을, 없으면 병합된 새 노드 그대로 사용 + group_to_existing_map[norm_node] = best_match_node if best_match_node else norm_node + + # 5. 
최종 UUID 및 Resolved Nodes 업데이트 + for i, original_idx in enumerate(state.unresolved_indices): + orig_node = extracted_nodes[original_idx] + target_group_node = node_to_group_map[i] + final_resolved_node = group_to_existing_map[target_group_node] + + state.resolved_nodes[original_idx] = final_resolved_node + state.uuid_map[orig_node.uuid] = final_resolved_node.uuid + + # CASE 1: 기존 DB 노드와 매칭된 경우 (Global Dedupe) + if final_resolved_node in indexes.existing_nodes: + # DB 노드에 새로운 속성을 병합하는 로직 필요 + continue + + # CASE 2: DB에는 없지만, 새 노드들끼리 그룹화된 경우 (Local Dedupe) + if final_resolved_node.uuid != orig_node.uuid: + state.duplicate_pairs.append((orig_node, final_resolved_node)) + + # 6 + state.unresolved_indices = [] + logger.info(f"Resolution done: {len(normalized_new_nodes)} clusters matched to DB.") + + +async def resolve_extracted_nodes_v2( + clients: GraphitiClients, + extracted_nodes: list[EntityNode], + episode: EpisodicNode | None = None, + previous_episodes: list[EpisodicNode] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, + existing_nodes_override: list[EntityNode] | None = None, +) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]: + """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the Normalize Node Function.""" + # llm_client = clients.llm_client + driver = clients.driver + existing_nodes = await _collect_candidate_nodes( # 검색된 기존재 노드들의 unique by uuid한 list[EntityNode]. + clients, + extracted_nodes, + existing_nodes_override, + ) + + indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes) # 중복 제거 후보 인덱스 + + state = DedupResolutionState( + resolved_nodes=[None] * len(extracted_nodes), + uuid_map={}, + unresolved_indices=[], + duplicate_pairs=[] + ) + + _resolve_with_similarity(extracted_nodes, indexes, state) + + await _resolve_without_llm( + clients, + extracted_nodes, + indexes, + state + ) + + for idx, node in enumerate(extracted_nodes): + if state.resolved_nodes[idx] is None: + state.resolved_nodes[idx] = node + state.uuid_map[node.uuid] = node.uuid + + logger.debug( + 'Resolved nodes: %s', + [(node.name, node.uuid) for node in state.resolved_nodes if node is not None], + ) + + new_node_duplicates: list[ + tuple[EntityNode, EntityNode] + ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs) + + return ( + [node for node in state.resolved_nodes if node is not None], + state.uuid_map, + new_node_duplicates, + ) + + async def resolve_extracted_nodes( clients: GraphitiClients, extracted_nodes: list[EntityNode], From 0d7c684b7b9f472d2a71562dbcb41b00f15bb9a1 Mon Sep 17 00:00:00 2001 From: lina-lemon Date: Tue, 10 Feb 2026 20:28:52 +0900 Subject: [PATCH 8/9] feat: opt graph ingest API --- graphiti_core/graphiti.py | 138 ++++++++++--------------- server/graph_service/dto/__init__.py | 6 +- server/graph_service/dto/ingest.py | 15 ++- server/graph_service/routers/ingest.py | 38 ++++++- 4 files changed, 109 insertions(+), 88 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index bffec6189..3b82d2a45 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -17,6 +17,7 @@ import logging from datetime import datetime from time import time +from typing import Any from dotenv import load_dotenv from pydantic import BaseModel @@ -130,6 +131,10 @@ class AddTripletResults(BaseModel): edges: list[EntityEdge] +class NormalizeNodeResults(BaseModel): + nodes: list[EntityNode] + + class Graphiti: def __init__( self, @@ -896,55 
+901,65 @@ async def add_episode_endpoint(episode_data: EpisodeData): span.record_exception(e) raise e - def create_node(self, node, group_id: str): - labels: list[str] = list({'Entity', str(node['labels'][0])}) + + def create_node(self, node: dict[str, Any], group_id: str): return EntityNode( name=node['name'], group_id=group_id, - labels=labels, + labels=node['labels'], summary='', created_at=utc_now(), - name_embedding=node.get('name_embedding'), - attributes=node.get('attribute') - ) - - def create_summary(self, node, nodes: list[EntityNode], group_id: str): - source_node_idx = node['id'] - source_node = nodes[source_node_idx] - - return EntityNode( - uuid=source_node.uuid, - name=source_node.name, - group_id=source_node.group_id, - labels=source_node.labels, - summary=node['summary'] if node.get('summary') else '', - created_at=utc_now(), - name_embedding=source_node.name_embedding, - attributes=source_node.attributes + name_embedding=node['nameEmbed'], + attributes=node['attributes'] ) - def create_edge(self, edge, nodes: list[EntityNode], group_id: str, episode: EpisodicNode): - # 안전하게 인덱스 추출 + def create_edge(self, edge: dict[str, Any], nodes: list[EntityNode], group_id: str, episode: EpisodicNode): s_idx = int(edge.get('sourceNodeId', 0)) t_idx = int(edge.get('targetNodeId', 0)) - # 인덱스 범위 초과 방지 source_uuid = nodes[s_idx].uuid if s_idx < len(nodes) else "UNKNOWN" target_uuid = nodes[t_idx].uuid if t_idx < len(nodes) else "UNKNOWN" return EntityEdge( source_node_uuid=source_uuid, target_node_uuid=target_uuid, - name=edge.get('type', 'RELATES_TO'), + name=edge['type'], group_id=group_id, - fact=edge.get('relation', ''), + fact=edge['relation'], episodes=[episode.uuid], created_at=utc_now(), - fact_embedding=edge.get('fact_embedding'), + fact_embedding=edge['relationEmbed'], valid_at=episode.valid_at or utc_now() ) + + async def normalize_node_v2( + self, + group_id: str, + extracted_nodes: list[dict[str, Any]], + ) -> NormalizeNodeResults: + print('> group id', group_id) + + # valid parameters + if group_id: + validate_group_id(group_id) + if group_id != self.driver._database: + self.driver = self.driver.clone(database=group_id) + self.clients.driver = self.driver + else: + raise GroupIdNotFoundError + entity_nodes = [self.create_node(node, group_id) for node in extracted_nodes] + + # 노드 정규화 + nodes, uuid_map, _ = await resolve_extracted_nodes_v2( + self.clients, + entity_nodes, + ) + + return NormalizeNodeResults( + nodes=nodes + ) async def add_episode_v2( self, @@ -953,26 +968,14 @@ async def add_episode_v2( episode_body: str, source_description: str, reference_time: datetime, - extract_nodes: list[object], - extract_edges: list[object], - summary_nodes: list[object], + extract_nodes: list[dict[str, Any]], + extract_edges: list[dict[str, Any]], source: EpisodeType = EpisodeType.text, - entity_types: dict[str, type[BaseModel]] | None = None, - excluded_entity_types: list[str] | None = None, - edge_types: dict[str, type[BaseModel]] | None = None, - edge_type_map: dict[tuple[str, str], list[str]] | None = None, ) -> AddEpisodeResults: - """ - [node] 노드 추출; A - [graphiti] 노드 정규화 w/ A - [node] 엣지 추출 w/ A - [node] 노드 요약 w/ A - [graphiti] 저장 - """ start = time() now = utc_now() - # STEP.0-1 valid parameters + # STEP.0 valid parameters if group_id: validate_group_id(group_id) if group_id != self.driver._database: @@ -985,7 +988,7 @@ async def add_episode_v2( with self.tracer.start_span('add_episode_v2') as span: try: - # Get or create episode + # STEP.1 Get or create episode episode = 
EpisodicNode( name=name, group_id=group_id, @@ -998,47 +1001,16 @@ async def add_episode_v2( ) - edge_type_map_default = ( - {('Entity', 'Entity'): list(edge_types.keys())} - if edge_types is not None - else {('Entity', 'Entity'): []} - ) - - # STEP. 노드 추출 - extracted_nodes: list[EntityNode] = [ self.create_node(node, group_id) for node in extract_nodes] + # STEP.2 create nodes and edges + entity_nodes = [self.create_node(node, group_id) for node in extract_nodes] + entity_edges = [self.create_edge(edge, entity_nodes, group_id, episode) for edge in extract_edges] - - # STEP. 노드 정규화 - nodes, uuid_map, _ = await resolve_extracted_nodes_v2( - self.clients, - extracted_nodes, - episode, - entity_types = entity_types, - ) - - # STEP. 엣지 추출 - extracted_edges: list[EntityEdge] = [ self.create_edge(edge, nodes, group_id, episode) for edge in extract_edges] - edges = resolve_edge_pointers(extracted_edges, uuid_map) - resolved_edges, invalidated_edges = await resolve_extracted_edges_v2( - self.clients, - edges, - episode, - nodes, - edge_types or {}, - edge_type_map or edge_type_map_default, - ) - - entity_edges = resolved_edges + invalidated_edges - - # STEP. 노드 속성 추출 (Summary) - hydrated_nodes: list[EntityNode] = [ self.create_summary(node, nodes, group_id) for node in summary_nodes] - - # STEP. Process and Save Episode + # STEP.3 Process and Save Episode # - _process_episode_data: 에피소드 노드와 엣지, 그리고 처리된 엔티티 노드와 엣지들을 DB에 저장합니다. # - build_episodic_edges: 에피소드 노드와 엔티티 노드들을 연결하는 'MENTIONS' 엣지를 생성합니다. # - add_nodes_and_edges_bulk: 노드, 엣지, 에피소드 데이터를 일괄적으로 DB에 저장(Merge/Create)합니다. episodic_edges, episode = await self._process_episode_data( - episode, hydrated_nodes, entity_edges, now + episode, entity_nodes, entity_edges, now ) end = time() @@ -1049,11 +1021,8 @@ async def add_episode_v2( 'episode.source': source.value, 'episode.reference_time': reference_time.isoformat(), 'group_id': group_id, - 'node.count': len(hydrated_nodes), + 'node.count': len(entity_nodes), 'edge.count': len(entity_edges), - 'edge.invalidated_count': len(invalidated_edges), - 'entity_types.count': len(entity_types) if entity_types else 0, - 'edge_types.count': len(edge_types) if edge_types else 0, 'duration_ms': (end - start) * 1000, } ) @@ -1062,7 +1031,7 @@ async def add_episode_v2( return AddEpisodeResults( episode=episode, episodic_edges=episodic_edges, - nodes=hydrated_nodes, + nodes=entity_nodes, edges=entity_edges, communities=[], community_edges=[], @@ -1072,7 +1041,8 @@ async def add_episode_v2( span.set_status('error', str(e)) span.record_exception(e) raise e - + + async def add_episode_bulk( self, diff --git a/server/graph_service/dto/__init__.py b/server/graph_service/dto/__init__.py index 2c047f695..7f770498d 100644 --- a/server/graph_service/dto/__init__.py +++ b/server/graph_service/dto/__init__.py @@ -1,5 +1,5 @@ from .common import Message, Result, Text, Actor, Author, Object, Procedure, Condition, Event, Location, Datetime, Year, Month, Day, Week, Concept, Image, Web, Category, ExecutedBy, ScheduledOn, PartOf, DependsOn, Triggers, LocatedIn, HasDetail, entity_type, edge_type, edge_type_maps -from .ingest import AddEntityNodeRequest, AddMessagesRequest, AddTextsRequest +from .ingest import AddEntityNodeRequest, AddMessagesRequest, AddTextsRequest, NormalizeNodeRequest, SaveEpisodeRequest from .retrieve import NodeResult, FactResult, GetMemoryRequest, GetMemoryResponse, SearchQuery, SearchResults __all__ = [ @@ -40,5 +40,7 @@ 'HasDetail', 'entity_type', 'edge_type', - 'edge_type_maps' + 'edge_type_maps', + 
'NormalizeNodeRequest', + 'SaveEpisodeRequest' ] diff --git a/server/graph_service/dto/ingest.py b/server/graph_service/dto/ingest.py index 450690022..883e3f1b8 100644 --- a/server/graph_service/dto/ingest.py +++ b/server/graph_service/dto/ingest.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import List +from typing import Any, List from graph_service.dto.common import Message, Text class AddMessagesRequest(BaseModel): @@ -18,3 +18,16 @@ class AddTextsRequest(BaseModel): group_id: str = Field(..., description='The group id of the texts to add') texts: List[Text] = Field(..., description='The texts to add') prompt: str = Field(default='', description='The custom extraction prompt for the text') + +# Node -> Graphiti. +class NormalizeNodeRequest(BaseModel): + group_id: str = Field(..., description='The group id of the texts to add') + nodes: List[dict[str, Any]] = Field(..., description='The nodes of extracted by LLM') + +# Node -> Graphiti. +class SaveEpisodeRequest(BaseModel): + group_id: str = Field(..., description='The group id of the texts to add') + name: str = Field(..., description='The title of the episode to add') + content: str = Field(..., description='The content of the episode to add') + nodes: List[dict[str, Any]] = Field(..., description='The nodes of extracted by LLM') + edges: List[dict[str, Any]] = Field(..., description='The edges of extracted by LLM') \ No newline at end of file diff --git a/server/graph_service/routers/ingest.py b/server/graph_service/routers/ingest.py index 3f6a1bdc7..c33dfc3a1 100644 --- a/server/graph_service/routers/ingest.py +++ b/server/graph_service/routers/ingest.py @@ -4,9 +4,10 @@ from fastapi import APIRouter, FastAPI, status from graphiti_core.nodes import EpisodeType # type: ignore +from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.maintenance.graph_data_operations import clear_data # type: ignore -from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result, AddTextsRequest, Text, edge_type_maps, edge_type, entity_type +from graph_service.dto import NormalizeNodeRequest, SaveEpisodeRequest, AddEntityNodeRequest, AddMessagesRequest, Message, Result, AddTextsRequest, Text, edge_type_maps, edge_type, entity_type from graph_service.zep_graphiti import ZepGraphitiDep @@ -83,6 +84,41 @@ async def build_task(): return Result(message='Community building started in background', success=True) +@router.post('/normalize-node', status_code=status.HTTP_200_OK) +async def normalize_node(request: NormalizeNodeRequest, graphiti: ZepGraphitiDep): + normalized_node = await graphiti.normalize_node_v2(group_id=request.group_id, extracted_nodes=request.nodes) + normalized= [ + { + "name": node.name, + "labels": node.labels, + "summary": node.summary, + "attribute": node.attributes, + "nameEmbed": node.name_embedding + } for node in normalized_node.nodes] + + return normalized + + + +@router.post('/save-episode', status_code=status.HTTP_202_ACCEPTED) +async def save_episode( + request: SaveEpisodeRequest, + graphiti: ZepGraphitiDep, +): + save_episode = await graphiti.add_episode_v2( + group_id=request.group_id, + name=request.name, + episode_body=request.content, # specific strategy content + source_description=request.name, + reference_time=utc_now(), + extract_nodes=request.nodes, + extract_edges=request.edges, + source=EpisodeType.text, + ) + + return save_episode + + @router.post('/texts', status_code=status.HTTP_202_ACCEPTED) async def add_texts( request: AddTextsRequest, 
From b2b24346841da8d254cc564c9b102c481e9cf1d7 Mon Sep 17 00:00:00 2001
From: lina-lemon
Date: Thu, 12 Feb 2026 10:14:51 +0900
Subject: [PATCH 9/9] feat: docker file to graph rag setup

---
 Dockerfile         | 81 ++++----------------------------------
 docker-compose.yml | 97 +++++++++++++---------------------------------
 2 files changed, 33 insertions(+), 145 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index b07e53a00..7c640c761 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,78 +1,11 @@
-# syntax=docker/dockerfile:1.9
-FROM python:3.12-slim
+FROM zepai/graphiti:latest
 
-# Inherit build arguments for labels
-ARG GRAPHITI_VERSION
-ARG BUILD_DATE
-ARG VCS_REF
+USER root
 
-# OCI image annotations
-LABEL org.opencontainers.image.title="Graphiti FastAPI Server"
-LABEL org.opencontainers.image.description="FastAPI server for Graphiti temporal knowledge graphs"
-LABEL org.opencontainers.image.version="${GRAPHITI_VERSION}"
-LABEL org.opencontainers.image.created="${BUILD_DATE}"
-LABEL org.opencontainers.image.revision="${VCS_REF}"
-LABEL org.opencontainers.image.vendor="Zep AI"
-LABEL org.opencontainers.image.source="https://github.com/getzep/graphiti"
-LABEL org.opencontainers.image.documentation="https://github.com/getzep/graphiti/tree/main/server"
-LABEL io.graphiti.core.version="${GRAPHITI_VERSION}"
+# scikit-learn 설치 (uv 활용)
+RUN uv pip install --no-cache-dir scikit-learn
 
-# Install uv using the installer script
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    curl \
-    ca-certificates \
-    && rm -rf /var/lib/apt/lists/*
+# 로그 저장을 위한 디렉토리 생성 및 권한 설정
+RUN mkdir -p /app/logs && chown -R app:app /app/logs
 
-ADD https://astral.sh/uv/install.sh /uv-installer.sh
-RUN sh /uv-installer.sh && rm /uv-installer.sh
-ENV PATH="/root/.local/bin:$PATH"
-
-# Configure uv for runtime
-ENV UV_COMPILE_BYTECODE=1 \
-    UV_LINK_MODE=copy \
-    UV_PYTHON_DOWNLOADS=never
-
-# Create non-root user
-RUN groupadd -r app && useradd -r -d /app -g app app
-
-# Set up the server application first
-WORKDIR /app
-COPY ./server/pyproject.toml ./server/README.md ./server/uv.lock ./
-COPY ./server/graph_service ./graph_service
-
-# Install server dependencies (without graphiti-core from lockfile)
-# Then install graphiti-core from PyPI at the desired version
-# This prevents the stale lockfile from pinning an old graphiti-core version
-ARG INSTALL_FALKORDB=false
-RUN --mount=type=cache,target=/root/.cache/uv \
-    uv sync --frozen --no-dev && \
-    if [ -n "$GRAPHITI_VERSION" ]; then \
-        if [ "$INSTALL_FALKORDB" = "true" ]; then \
-            uv pip install --system --upgrade "graphiti-core[falkordb]==$GRAPHITI_VERSION"; \
-        else \
-            uv pip install --system --upgrade "graphiti-core==$GRAPHITI_VERSION"; \
-        fi; \
-    else \
-        if [ "$INSTALL_FALKORDB" = "true" ]; then \
-            uv pip install --system --upgrade "graphiti-core[falkordb]"; \
-        else \
-            uv pip install --system --upgrade graphiti-core; \
-        fi; \
-    fi
-
-# Change ownership to app user
-RUN chown -R app:app /app
-
-# Set environment variables
-ENV PYTHONUNBUFFERED=1 \
-    PATH="/app/.venv/bin:$PATH"
-
-# Switch to non-root user
-USER app
-
-# Set port
-ENV PORT=8000
-EXPOSE $PORT
-
-# Use uv run for execution
-CMD ["uv", "run", "uvicorn", "graph_service.main:app", "--host", "0.0.0.0", "--port", "8000"]
+USER app
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
index 1b5ba06df..ae2258759 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,32 +1,24 @@
+version: '3.8'
+
 services:
-  graph:
-    profiles: [""]
-    build:
-      context: .
- ports: - - "8000:8000" - healthcheck: - test: - [ - "CMD", - "python", - "-c", - "import urllib.request; urllib.request.urlopen('http://localhost:8000/healthcheck')", - ] - interval: 10s - timeout: 5s - retries: 3 - depends_on: - neo4j: - condition: service_healthy - environment: - - OPENAI_API_KEY=${OPENAI_API_KEY} - - NEO4J_URI=bolt://neo4j:${NEO4J_PORT:-7687} - - NEO4J_USER=${NEO4J_USER:-neo4j} - - NEO4J_PASSWORD=${NEO4J_PASSWORD:-password} - - PORT=8000 - - db_backend=neo4j - neo4j: + graph: + build: . + ports: + - "8000:8000" + volumes: + # 1. 소스 코드 마운트 (실시간 반영) + - ./server/graph_service:/app/graph_service + # 2. 라이브러리 코드 주입 (개발용) + - ./graphiti_core:/app/.venv/lib/python3.12/site-packages/graphiti_core + # 3. 로그 저장 + - ./logs/graph:/app/logs + environment: + # - OPENAI_API_KEY=${OPENAI_API_KEY} + - NEO4J_URI=bolt://neo4j:7687 + - NEO4J_USER=${NEO4J_USER} + - NEO4J_PASSWORD=${NEO4J_PASSWORD} + + neo4j: image: neo4j:5.26.2 profiles: [""] healthcheck: @@ -44,49 +36,12 @@ services: - "${NEO4J_PORT:-7687}:${NEO4J_PORT:-7687}" # Bolt volumes: - neo4j_data:/data + - ./logs/neo4j:/logs environment: - - NEO4J_AUTH=${NEO4J_USER:-neo4j}/${NEO4J_PASSWORD:-password} - - falkordb: - image: falkordb/falkordb:latest - profiles: ["falkordb"] - ports: - - "6379:6379" - volumes: - - falkordb_data:/data - environment: - - FALKORDB_ARGS=--port 6379 --cluster-enabled no - healthcheck: - test: ["CMD", "redis-cli", "-p", "6379", "ping"] - interval: 1s - timeout: 10s - retries: 10 - start_period: 3s - graph-falkordb: - build: - args: - INSTALL_FALKORDB: "true" - context: . - profiles: ["falkordb"] - ports: - - "8001:8001" - depends_on: - falkordb: - condition: service_healthy - healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8001/healthcheck')"] - interval: 10s - timeout: 5s - retries: 3 - environment: - - OPENAI_API_KEY=${OPENAI_API_KEY} - - FALKORDB_HOST=falkordb - - FALKORDB_PORT=6379 - - FALKORDB_DATABASE=default_db - - GRAPHITI_BACKEND=falkordb - - PORT=8001 - - db_backend=falkordb + - NEO4J_AUTH=${NEO4J_USER:-neo4j}/${NEO4J_PASSWORD:-user-dev} + - NEO4J_server_memory_heap_initial__size=512m + - NEO4J_server_memory_heap_max__size=1G + - NEO4J_server_memory_pagecache_size=512m volumes: - neo4j_data: - falkordb_data: + neo4j_data: \ No newline at end of file
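
Usage note, not part of the patches above: a minimal sketch of how a client might call the /normalize-node and /save-episode routes introduced in PATCH 8/9, assuming the compose stack above is running and the graph service is reachable at http://localhost:8000. The group id, node and edge payloads, and embedding values below are illustrative placeholders; the field names mirror NormalizeNodeRequest and SaveEpisodeRequest, and the node/edge dict keys mirror what Graphiti.create_node and Graphiti.create_edge read ('name', 'labels', 'nameEmbed', 'attributes' and 'sourceNodeId', 'targetNodeId', 'type', 'relation', 'relationEmbed').

import requests

BASE_URL = 'http://localhost:8000'  # assumed host/port from docker-compose above
GROUP_ID = 'site1000039'  # placeholder; must be a valid group id / existing database

# Example extracted nodes, shaped the way create_node expects them.
nodes = [
    {
        'name': '생활지원센터',
        'labels': ['Entity', 'AUTHOR'],
        'nameEmbed': [0.01, 0.02, 0.03],  # placeholder embedding vector
        'attributes': {'type': 'AUTHOR', 'context': '[공지사항 생활지원센터] [정기소독] 2월 정기소독 실시안내'},
    },
    {
        'name': '정기소독',
        'labels': ['Entity', 'EVENT'],
        'nameEmbed': [0.02, 0.01, 0.04],  # placeholder embedding vector
        'attributes': {'type': 'EVENT', 'context': '[정기소독] 2월 정기소독 실시안내'},
    },
]

# 1. Deduplicate/normalize the extracted nodes against the graph.
normalized = requests.post(
    f'{BASE_URL}/normalize-node',
    json={'group_id': GROUP_ID, 'nodes': nodes},
).json()

# Example extracted edge; indices refer to positions in the nodes list above.
edges = [
    {
        'sourceNodeId': 0,
        'targetNodeId': 1,
        'type': 'ScheduledOn',
        'relation': '생활지원센터 announces the February disinfection schedule',
        'relationEmbed': [0.03, 0.01, 0.02],  # placeholder embedding vector
    },
]

# 2. Persist the episode together with its nodes and edges.
resp = requests.post(
    f'{BASE_URL}/save-episode',
    json={
        'group_id': GROUP_ID,
        'name': '[정기소독] 2월 정기소독 실시안내',
        'content': '2월 정기소독 실시안내 본문 ...',
        'nodes': nodes,
        'edges': edges,
    },
)
print(resp.status_code, len(normalized))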