Skip to content

Commit 8c81a2e

Browse files
authored
Merge pull request #201 from capgoing/perf/#200
⚡️ [Perf] GraphRAG 성능 개선 - 컨텍스트가 더 잘 추출되도록 수정
2 parents 04a32f7 + fc8635a commit 8c81a2e

File tree

7 files changed

+125
-68
lines changed

7 files changed

+125
-68
lines changed

src/main/java/com/going/server/domain/chatbot/dto/CreateChatbotResponseDto.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,25 @@ public class CreateChatbotResponseDto {
1515
private String graphId;
1616
@JsonFormat(pattern = "yyyy-MM-dd'T'HH:mm")
1717
private LocalDateTime createdAt;
18+
private List<String> contextChunks; //LLM에 넘긴 context 문장들 (이름은 `augmentedSentences` 등으로 변경 권장)
1819
private List<String> retrievedTriples; //관계 중심의 3요소 표현 ("물 -상태변화→ 응고")
1920
private List<String> sourceNodes; //질의에 사용된 핵심 노드들 ("물", "응고" 등)
20-
private List<String> 증강할때쓴자료; //LLM에 넘긴 context 문장들 (이름은 `augmentedSentences` 등으로 변경 권장)
21-
22-
private Map<String, String> ragMeta; //(ex: 사용한 쿼리문 등)
2321

2422
public static CreateChatbotResponseDto of(
2523
String chatContent,
2624
String graphId,
2725
LocalDateTime createdAt,
28-
List<String> retrievedChunks,
26+
List<String> contextChunks,
27+
List<String> retrievedTriples,
2928
List<String> sourceNodes
3029
) {
3130
return CreateChatbotResponseDto.builder()
3231
.chatContent(chatContent)
3332
.graphId(graphId)
3433
.createdAt(createdAt)
35-
.retrievedTriples(retrievedChunks)
34+
.contextChunks(contextChunks)
35+
.retrievedTriples(retrievedTriples)
3636
.sourceNodes(sourceNodes)
37-
.ragMeta(Map.of("chunkCount", String.valueOf(retrievedChunks.size())))
3837
.build();
3938
}
4039
}

src/main/java/com/going/server/domain/rag/dto/GraphQueryResult.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,19 @@
66
@Getter
77
@AllArgsConstructor
88
public class GraphQueryResult {
9-
private String sentence;
10-
private String nodeLabel;
11-
}
9+
private String sentence; // 예: "물은 응고되어 얼음이 된다."
10+
private String sourceLabel; // 예: "물"
11+
private String relationLabel; // 예: "응고"
12+
private String targetLabel; // 예: "얼음"
13+
private String nodeLabel; // 예: "물" (질의어에 가까운 노드)
14+
15+
public String toTripleString() {
16+
if (sourceLabel == null || relationLabel == null || targetLabel == null) return null;
17+
18+
// 혹시 내부 문자열이 "null"로 들어오는 것도 막기
19+
if ("null".equals(sourceLabel) || "null".equals(relationLabel) || "null".equals(targetLabel)) return null;
20+
21+
return String.format("(%s)-[:RELATED {label: '%s'}]->(%s)", sourceLabel, relationLabel, targetLabel);
22+
}
23+
24+
}
Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,49 @@
11
package com.going.server.domain.rag.service;
22

33
import com.going.server.domain.openai.service.OpenAIService;
4-
import com.theokanning.openai.OpenAiService;
54
import com.theokanning.openai.completion.chat.ChatMessage;
65
import lombok.RequiredArgsConstructor;
76
import org.springframework.stereotype.Component;
87

98
import java.util.List;
109

11-
// 1. 질문 → Cypher 쿼리 생성 (LLM)
1210
@Component
1311
@RequiredArgsConstructor
1412
public class CypherQueryGenerator {
1513
private final OpenAIService openAIService;
1614

1715
public String generate(String userQuestion) {
1816
String prompt = """
19-
당신은 Neo4j용 Cypher 쿼리를 생성하는 AI입니다.
20-
주어진 질문에 대해 Cypher 쿼리만 반환하세요. 코드블록, 설명 없이 오직 쿼리만 출력해야 합니다.
17+
당신은 Neo4j 그래프 데이터베이스에서 정보를 추출하기 위한 Cypher 쿼리를 생성하는 AI입니다.
2118
22-
예:
23-
질문: "고래와 관련된 개념들을 알려줘"
24-
→ MATCH (n:GraphNode)-[r]->(m:GraphNode)\s
25-
WHERE n.label = '고래'\s
26-
RETURN m.label AS nodeLabel, m.includeSentence AS sentence\s
27-
LIMIT 10
28-
29-
질문: "${userQuestion}"
19+
- 주어진 질문에서 핵심 개념과 연관된 개념들을 찾아야 합니다.
20+
- 질문에 포함된 키워드와 의미적으로 밀접한 노드 쌍 간의 관계(triple)를 추출해야 합니다.
21+
- 반드시 관계 중심 구조 (시작 노드, 관계 라벨, 도착 노드)를 반환하는 Cypher 쿼리를 작성하세요.
22+
- 관계나 노드에 포함된 설명 문장 중 하나를 함께 반환하세요. (r.sentence → 없으면 a.includeSentence → 없으면 b.includeSentence 순으로)
23+
- 반환 항목은 다음과 같아야 합니다:
24+
sourceLabel, relationLabel, targetLabel, sentence, nodeLabel
25+
- 코드는 반드시 Cypher 쿼리 한 줄만 출력하며, 코드블록이나 설명은 포함하지 마세요.
26+
27+
예시:
28+
질문: "고래와 관련된 개념들을 알려줘"
29+
30+
MATCH (a:GraphNode)-[r:RELATED]-(b:GraphNode)
31+
WHERE toLower(a.label) CONTAINS toLower('고래') OR toLower(b.label) CONTAINS toLower('고래')
32+
RETURN
33+
a.label AS sourceLabel,
34+
r.label AS relationLabel,
35+
b.label AS targetLabel,
36+
COALESCE(r.sentence, a.includeSentence, b.includeSentence, "") AS sentence,
37+
a.label AS nodeLabel
38+
LIMIT 15
39+
40+
질문: "%s"
3041
31-
""".formatted(userQuestion);
42+
""".formatted(userQuestion);
3243

3344
return openAIService.getCompletionResponse(
3445
List.of(new ChatMessage("user", prompt)),
35-
"gpt-4-0125-preview", 0.2, 1000
46+
"gpt-4o", 0.2, 500
3647
);
3748
}
38-
}
49+
}

src/main/java/com/going/server/domain/rag/service/GraphQueryExecutor.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,45 @@
1414
@RequiredArgsConstructor
1515
public class GraphQueryExecutor {
1616

17-
private final Driver neo4jDriver; // Neo4j Java Driver
17+
private final Driver neo4jDriver;
1818

1919
public List<GraphQueryResult> runQuery(Long graphId, String cypherQuery) {
2020
List<GraphQueryResult> results = new ArrayList<>();
2121

2222
try (Session session = neo4jDriver.session()) {
2323
Result result = session.run(cypherQuery);
24+
2425
while (result.hasNext()) {
2526
Record record = result.next();
2627

27-
// 필드 이름은 Cypher 쿼리 결과와 일치해야 함
28-
String sentence = record.get("sentence").asString("");
29-
String nodeLabel = record.get("nodeLabel").asString("");
28+
String sentence = getSafeString(record, "sentence");
29+
String nodeLabel = getSafeString(record, "nodeLabel");
30+
31+
String sourceLabel = getSafeString(record, "sourceLabel");
32+
String relationLabel = getSafeString(record, "relationLabel");
33+
String targetLabel = getSafeString(record, "targetLabel");
3034

31-
results.add(new GraphQueryResult(sentence, nodeLabel));
35+
results.add(new GraphQueryResult(
36+
sentence,
37+
nodeLabel,
38+
sourceLabel,
39+
relationLabel,
40+
targetLabel
41+
));
3242
}
43+
3344
} catch (Exception e) {
45+
System.err.println("[GraphRAG] Cypher 쿼리 실행 중 오류 발생:");
3446
e.printStackTrace();
3547
}
3648

3749
return results;
3850
}
51+
52+
// 안전한 String 추출 (null-safe)
53+
private String getSafeString(Record record, String key) {
54+
return record.containsKey(key) && !record.get(key).isNull()
55+
? record.get(key).asString()
56+
: null;
57+
}
3958
}

src/main/java/com/going/server/domain/rag/service/GraphRAGService.java

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import org.springframework.stereotype.Service;
1414

1515
import java.util.List;
16+
import java.util.regex.Matcher;
17+
import java.util.regex.Pattern;
1618

1719
@Service
1820
@RequiredArgsConstructor
@@ -30,13 +32,7 @@ public class GraphRAGService {
3032

3133
/**
3234
* 사용자 질문에 대해 Cypher 쿼리 → 그래프 정보 검색 → 프롬프트 생성 → LLM 응답 생성
33-
* 본 메서드는 LangChain 없이 구현한 Spring 기반 GraphRAG의 핵심 흐름입니다.
34-
*
35-
* private LocalDateTime createdAt;
36-
* private List<String> retrievedTriples; //관계 중심의 3요소 표현 ("물 -상태변화→ 응고")
37-
* private List<String> sourceNodes; //질의에 사용된 핵심 노드들 ("물", "응고" 등)
38-
* private List<String> 증강할때쓴자료; //LLM에 넘긴 context 문장들 (이름은 `augmentedSentences` 등으로 변경 권장)
39-
* -> 이렇게 결과 나오도록 정리
35+
* LangChain 없이 구현한 Spring 기반 GraphRAG의 핵심 흐름
4036
*/
4137
public CreateChatbotResponseDto createAnswerWithGraphRAG(
4238
Long dbId,
@@ -47,52 +43,60 @@ public CreateChatbotResponseDto createAnswerWithGraphRAG(
4743
log.info("[GraphRAG] dbId: {}, question: {}", dbId, userQuestion);
4844

4945
// 1. 질문 → Cypher 쿼리 생성
50-
String cypherQuery = cypherQueryGenerator.generate(userQuestion).trim()
51-
.replaceAll("(?s)```cypher.*?```", "") // 마크다운 제거
52-
.replaceAll("```", "") // 남은 ``` 제거
53-
.trim();
54-
log.info("[GraphRAG] Generated Cypher Query:\n{}", cypherQuery);
46+
String rawQuery = cypherQueryGenerator.generate(userQuestion);
47+
// ```cypher ~ ``` 블록 제거
48+
Matcher m = Pattern.compile("(?s)```cypher\\s*(.*?)\\s*```").matcher(rawQuery);
49+
String cleaned = m.find() ? m.group(1) : rawQuery;
50+
// 남은 ``` 제거
51+
cleaned = cleaned.replaceAll("```", "").trim();
52+
log.info("[GraphRAG] Cypher Query 생성됨:\n----\n{}\n----", cleaned);
5553

5654
// 2. 쿼리 실행 → 문맥(context) 및 노드 라벨 추출
57-
List<GraphQueryResult> queryResults = graphQueryExecutor.runQuery(dbId, cypherQuery);
55+
List<GraphQueryResult> queryResults = graphQueryExecutor.runQuery(dbId, cleaned);
56+
// 문장
5857
List<String> contextChunks = queryResults.stream()
5958
.map(GraphQueryResult::getSentence)
59+
.filter(s -> s != null && !s.isBlank())
60+
.distinct()
6061
.toList();
61-
62+
// 관계 트리플
63+
List<String> retrievedTriples = queryResults.stream()
64+
.map(GraphQueryResult::toTripleString)
65+
.distinct()
66+
.toList();
67+
// 노드
6268
List<String> sourceNodes = queryResults.stream()
6369
.map(GraphQueryResult::getNodeLabel)
70+
.filter(n -> n != null && !n.isBlank())
6471
.distinct()
6572
.toList();
73+
6674
log.info("[GraphRAG] Retrieved {} context chunks", contextChunks.size());
75+
retrievedTriples.forEach(triple ->
76+
log.info("[GraphRAG] Triple: {}", triple)
77+
);
78+
log.info("[GraphRAG] Retrieved {} triples", retrievedTriples.size());
6779

6880
// 3. 프롬프트 구성
69-
String finalPrompt = promptBuilder.buildPrompt(contextChunks, userQuestion);
81+
String finalPrompt = promptBuilder.buildPrompt(contextChunks, retrievedTriples, userQuestion);
7082
log.info("[GraphRAG] Final Prompt constructed");
7183

7284
// 4. RAG 응답 생성
73-
String response = contextChunks.isEmpty()
74-
? ragAnswerCreateService.chat(chatHistory, userQuestion)
75-
: ragAnswerCreateService.chatWithContext(chatHistory, finalPrompt);
85+
boolean hasContext = !contextChunks.isEmpty() || !retrievedTriples.isEmpty();
86+
String response = hasContext
87+
? ragAnswerCreateService.chatWithContext(chatHistory, finalPrompt)
88+
: ragAnswerCreateService.chat(chatHistory, userQuestion);
7689
log.info("[GraphRAG] Response generated by LLM");
7790

7891
// 5. 응답 저장
7992
Chatting answer = Chatting.ofGPT(graph, response);
8093
chattingRepository.save(answer);
81-
log.info("[GraphRAG] Response saved to DB");
82-
83-
// 임시 retrievedTriples 설정
84-
List<String> retrievedTriples = List.of(
85-
"(물)-[:RELATED {label: '상태변화'}]->(기화)",
86-
"(기화)-[:RELATED {label: '조건'}]->(높은 온도)",
87-
"(수증기)-[:RELATED {label: '응결'}]->(물방울)",
88-
"(물)-[:RELATED {label: '응고'}]->(얼음)",
89-
"(응고)-[:RELATED {label: '예시'}]->(겨울철 얼어붙은 길)"
90-
);
9194

9295
return CreateChatbotResponseDto.of(
9396
response,
9497
dbId.toString(),
9598
answer.getCreatedAt(),
99+
contextChunks,
96100
retrievedTriples,
97101
sourceNodes
98102
);

src/main/java/com/going/server/domain/rag/service/RagAnswerCreateService.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@ public class RagAnswerCreateService {
1818
private final OpenAIService openAIService;
1919

2020
private static final String SYSTEM_PROMPT = """
21-
당신은 초등학생의 이해를 돕는 친절하고 정확한 지식 튜터입니다.
22-
- 아래 제공된 데이터를 기반으로 질문에 대해 매우 길고 정확하게 설명해주세요.
23-
- 만약 참고 데이터가 없다면, 관련정보 없다고 하세요.
24-
- 반드시 한글로만 응답하고, 인사말이나 불필요한 문장은 생략한 대답만 반환하세요.
25-
""";
21+
당신은 초등학생의 이해를 돕는 친절하고 정확한 지식 튜터입니다.
22+
23+
- 아래에 제공된 '관계 정보'와 '설명 문장'은 질문과 관련된 지식그래프에서 추출된 정보입니다.
24+
- 반드시 이 정보를 바탕으로 질문에 대해 정확하고 구체적으로 설명해주세요.
25+
- 관계 간의 연결 흐름이나 개념 간 연관성을 쉽게 풀어 설명해 주세요.
26+
- 필요 이상으로 친절하거나 장황하게 말하지 말고, 정확하고 알기 쉽게 대답만 하세요.
27+
- 대답은 반드시 한글로만 작성하고, 인사말이나 부가 설명 없이 본문만 반환하세요.
28+
""";
2629

2730
private static final String MODEL_NAME = "gpt-4o";
2831
private static final double TEMPERATURE = 0.3;
29-
private static final int MAX_TOKENS = 1500;
32+
private static final int MAX_TOKENS = 1200;
3033

3134
public String chat(List<Chatting> chatHistory, String question) {
3235
List<ChatMessage> messages = new ArrayList<>();

src/main/java/com/going/server/domain/rag/util/PromptBuilder.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,25 @@
77
@Component
88
public class PromptBuilder {
99

10-
public String buildPrompt(List<String> chunks, String question) {
10+
public String buildPrompt(List<String> contextChunks, List<String> triples, String userQuestion) {
1111
StringBuilder sb = new StringBuilder();
1212

1313
sb.append("다음 정보를 참고하여 질문에 답해주세요.\n\n");
14-
sb.append("[관련 정보]\n");
1514

16-
for (String chunk : chunks) {
17-
sb.append("- ").append(chunk.trim()).append("\n");
15+
if (!triples.isEmpty()) {
16+
sb.append("[관계 정보]\n");
17+
triples.forEach(triple -> sb.append("- ").append(triple).append("\n"));
18+
sb.append("\n");
19+
}
20+
if (!contextChunks.isEmpty()) {
21+
sb.append("[설명 문장]\n");
22+
for (int i = 0; i < contextChunks.size(); i++) {
23+
sb.append(i + 1).append(". ").append(contextChunks.get(i)).append("\n");
24+
}
25+
sb.append("\n");
1826
}
1927

20-
sb.append("\n[질문]\n").append(question.trim()).append("\n\n");
28+
sb.append("질문: ").append(userQuestion);
2129
sb.append("[답변]\n");
2230

2331
return sb.toString();

0 commit comments

Comments
 (0)