Skip to content

Commit b685952

Browse files
authored
Merge pull request #9 from 2025-Graduation-Design/new
feat : 일기 임베딩 관련 사항 수정
2 parents a03ac42 + 84593a0 commit b685952

File tree

2 files changed

+23
-14
lines changed

app/diary/router.py

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -214,12 +214,19 @@ async def create_diary_with_music_recommend_top3(
214214
# 8) 유사도 계산 및 Top-3 추천곡 선정
215215
heap = []
216216
counter = 0
217-
for song in songs:
218-
try:
219-
song_id = int(song["id"])
220-
cache_key = f"lyrics_emb:{song_id}"
221-
cached = await redis.get(cache_key)
222217

218+
# 1) 곡 ID 추출
219+
song_id_map = {int(song["id"]): song for song in songs}
220+
song_ids = list(song_id_map.keys())
221+
cache_keys = [f"lyrics_emb:{song_id}" for song_id in song_ids]
222+
223+
# 2) Redis 일괄 조회
224+
cached_values = await redis.mget(cache_keys)
225+
226+
# 3) 유사도 계산
227+
combined_np = np.array(combined_embedding) # (768,)
228+
for song_id, cached in zip(song_ids, cached_values):
229+
try:
223230
if cached:
224231
lyrics_embedding = np.array(json.loads(cached))
225232
else:
@@ -230,27 +237,29 @@ async def create_diary_with_music_recommend_top3(
230237
if not result:
231238
continue
232239
lyrics_embedding = np.array(json.loads(result[0]))
233-
await redis.set(cache_key, json.dumps(lyrics_embedding.tolist()))
240+
await redis.set(f"lyrics_emb:{song_id}", json.dumps(lyrics_embedding.tolist()), ex=60*60*24*30)
234241

235242
if len(lyrics_embedding.shape) != 2:
236243
continue
237244

245+
song = song_id_map[song_id]
238246
lyrics = song.get("lyrics", [])
239247
if len(lyrics) < 1 or len(lyrics_embedding) != len(lyrics):
240248
continue
241249

242-
for idx, block_emb in enumerate(lyrics_embedding):
243-
similarity = F.cosine_similarity(
244-
torch.tensor(combined_embedding).unsqueeze(0),
245-
torch.tensor(block_emb).unsqueeze(0)
246-
).item()
250+
# 4) 전체 블럭과 유사도 한 번에 계산
251+
dot = np.dot(lyrics_embedding, combined_np) # (n,)
252+
norm_block = np.linalg.norm(lyrics_embedding, axis=1)
253+
norm_query = np.linalg.norm(combined_np)
254+
similarities = dot / (norm_block * norm_query + 1e-8)
247255

256+
for idx, similarity in enumerate(similarities):
248257
heapq.heappush(heap, (
249258
similarity,
250259
counter,
251260
{
252-
"song_id": song["id"],
253-
"lyric_chunk": [lyrics[idx]], # 블럭 하나만
261+
"song_id": song_id,
262+
"lyric_chunk": [lyrics[idx]],
254263
"similarity": similarity,
255264
"metadata": {
256265
"song_name": song.get("song_name"),

app/main.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from fastapi import FastAPI
22

3-
from app.user.router import router as user_router# 🐾 유저 라우터 임포트
3+
from app.user.router import router as user_router
44
from app.diary.router import router as diary_router
55
from app.genre.router import router as genre_router
66
from app.crawling.router import router as crawling_router

0 commit comments

Comments
 (0)