Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions app/api/models/schemas.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,70 @@
from pydantic import BaseModel, model_validator
from typing import Dict, Self, Optional, Any
from pydantic import BaseModel, model_validator, Field
from typing import Dict, Self, Optional, Any, List
import ulid

class ChatRequest(BaseModel):
session_id: str=""
session_id: str = ""
message: str
model_params: Optional[Dict[str, Any]] = None

@model_validator(mode="after")
def set_session_id(self)->Self:
def set_session_id(self) -> Self:
if not self.session_id:
self.session_id = ulid.ulid()
return self
return self


# --- Branch Listing ---
class BranchListResponse(BaseModel):
branches: List[str] = Field(..., description="List of branch names in the repository.")


# --- Valid Target Branches ---
class ValidTargetBranchesRequest(BaseModel):
session_id: str = Field(..., description="Session identifier.")
repo: str = Field(..., description="Repository name.")
source_branch: str = Field(..., description="Source branch name.")

class ValidTargetBranchesResponse(BaseModel):
valid_target_branches: List[str] = Field(..., description="List of valid target branch names.")


# --- Pull Request Creation ---
class CreatePullRequestRequest(BaseModel):
session_id: str = Field(..., description="Session identifier.")
repo: str = Field(..., description="Repository name.")
source_branch: str = Field(..., description="Source branch name.")
target_branch: str = Field(..., description="Target branch name.")
title: Optional[str] = Field(None, description="Title of the pull request.")
description: str = Field(..., description="Description/body of the pull request. This field is required.")
draft: Optional[bool] = Field(False, description="Whether to create the PR as a draft.")
reviewers: Optional[List[str]] = Field(None, description="List of reviewer usernames.")
assignees: Optional[List[str]] = Field(None, description="List of assignee usernames.")
labels: Optional[List[str]] = Field(None, description="List of label names.")

# --- Pull Request Diff ---
class GetPullRequestDiffRequest(BaseModel):
session_id: str = Field(..., description="Session identifier.")
repo: str = Field(..., description="Repository name.")
source_branch: str = Field(..., description="Source branch name.")
target_branch: str = Field(..., description="Target branch name.")

class GetPullRequestDiffResponse(BaseModel):
commits: List[dict] = Field(..., description="List of commit dicts in the diff.")

class CreatePullRequestResponse(BaseModel):
url: str = Field(..., description="URL of the created pull request.")
number: int = Field(..., description="Pull request number.")
state: str = Field(..., description="State of the pull request (e.g., open, closed).")
success: bool = Field(..., description="Whether the pull request was created successfully.")
# Optionally, include the generated description if LLM was used
generated_description: Optional[str] = Field(None, description="LLM-generated PR description, if applicable.")


# --- Utility: Commit List for PR Description Generation ---
class CommitMessagesForPRDescriptionRequest(BaseModel):
commit_messages: List[str] = Field(..., description="List of commit messages to summarize.")
session_id: str = Field(..., description="Session identifier.")

class PRDescriptionResponse(BaseModel):
description: str = Field(..., description="LLM-generated pull request description.")
106 changes: 95 additions & 11 deletions app/api/server/routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from fastapi import APIRouter, HTTPException, Request, Query
from pydantic import BaseModel

from models.schemas import (
BranchListResponse,
ValidTargetBranchesRequest,
ValidTargetBranchesResponse,
CreatePullRequestRequest,
CreatePullRequestResponse,
)

from models.schemas import GetPullRequestDiffRequest, GetPullRequestDiffResponse
from services.llm_service import set_llm, get_llm, trim_messages
from services.fetcher_service import store_fetcher, get_fetcher
from git_recap.utils import parse_entries_to_txt, parse_releases_to_txt
Expand Down Expand Up @@ -214,7 +223,7 @@ async def get_release_notes(
# Get fetcher for session
try:
fetcher = get_fetcher(session_id)
except HTTPException as e:
except HTTPException:
raise

# Check if fetcher supports fetch_releases
Expand Down Expand Up @@ -273,13 +282,88 @@ async def get_release_notes(

return {"actions": "\n\n".join([actions_txt, releases_txt])}

# @router.post("/chat")
# async def chat(
# chat_request: ChatRequest
# ):
# try:
# llm = await initialize_llm_session(chat_request.session_id)
# response = await llm.acomplete(chat_request.message)
# return {"response": response}
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
# --- Branch and Pull Request Management Endpoints ---
@router.get("/branches", response_model=BranchListResponse)
async def get_branches(
session_id: str,
repo: str
):
"""
Get all branches for a given repository in the current session.
"""
fetcher = get_fetcher(session_id)
try:
fetcher.repo_filter = [repo]
branches = fetcher.get_branches()
except NotImplementedError:
raise HTTPException(status_code=400, detail="Branch listing is not supported for this provider.")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch branches: {str(e)}")
return BranchListResponse(branches=branches)

@router.post("/valid-target-branches", response_model=ValidTargetBranchesResponse)
async def get_valid_target_branches(
req: ValidTargetBranchesRequest
):
"""
Get all valid target branches for a given source branch in a repository.
"""
fetcher = get_fetcher(req.session_id)
try:
fetcher.repo_filter = [req.repo]
valid_targets = fetcher.get_valid_target_branches(req.source_branch)
except NotImplementedError:
raise HTTPException(status_code=400, detail="Target branch validation is not supported for this provider.")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to validate target branches: {str(e)}")
return ValidTargetBranchesResponse(valid_target_branches=valid_targets)

@router.post("/create-pull-request", response_model=CreatePullRequestResponse)
async def create_pull_request(
req: CreatePullRequestRequest
):
fetcher = get_fetcher(req.session_id)
fetcher.repo_filter = [req.repo]
if not req.description or not req.description.strip():
raise HTTPException(status_code=400, detail="Description is required for pull request creation.")
try:
result = fetcher.create_pull_request(
head_branch=req.source_branch,
base_branch=req.target_branch,
title=req.title or f"Merge {req.source_branch} into {req.target_branch}",
body=req.description,
draft=req.draft or False,
reviewers=req.reviewers,
assignees=req.assignees,
labels=req.labels,
)
except NotImplementedError:
raise HTTPException(status_code=400, detail="Pull request creation is not supported for this provider.")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create pull request: {str(e)}")
return CreatePullRequestResponse(
url=result.get("url"),
number=result.get("number"),
state=result.get("state"),
success=result.get("success", False),
generated_description=None
)

@router.post("/get-pull-request-diff", response_model=GetPullRequestDiffResponse)
async def get_pull_request_diff(req: GetPullRequestDiffRequest):
fetcher = get_fetcher(req.session_id)
fetcher.repo_filter = [req.repo]
provider = type(fetcher).__name__.lower()
if "github" not in provider:
raise HTTPException(status_code=400, detail="Pull request diff is only supported for GitHub provider.")
try:
commits = fetcher.fetch_branch_diff_commits(req.source_branch, req.target_branch)
except NotImplementedError:
raise HTTPException(status_code=400, detail="Branch diff is not supported for this provider.")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch pull request diff: {str(e)}")
return GetPullRequestDiffResponse(commits=commits)
112 changes: 89 additions & 23 deletions app/api/server/websockets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
import json
from typing import Optional
from typing import Literal, Optional
import asyncio

from services.prompts import SELECT_QUIRKY_REMARK_SYSTEM, SYSTEM, RELEASE_NOTES_SYSTEM, quirky_remarks
from services.llm_service import get_random_quirky_remarks, run_concurrent_tasks, get_llm
from services.prompts import (
PR_DESCRIPTION_SYSTEM,
SELECT_QUIRKY_REMARK_SYSTEM,
SYSTEM,
RELEASE_NOTES_SYSTEM,
quirky_remarks,
)
from services.llm_service import (
get_random_quirky_remarks,
run_concurrent_tasks,
get_llm,
)
from aicore.const import SPECIAL_TOKENS, STREAM_END_TOKEN
import asyncio

router = APIRouter()

Expand All @@ -26,54 +36,96 @@
{ACTIONS}
"""

TRIGGER_PULL_REQUEST_PROMPT = """
You will now receive a list of commit messages between two branches.
Using the system instructions provided above, generate a clear, concise, and professional **Pull Request Description** summarizing all changes.

Commits:
{COMMITS}

Please follow these steps:
1. Read and analyze the commit messages.
2. Identify and group related changes under appropriate markdown headers (e.g., Features, Bug Fixes, Improvements, Documentation, Tests).
3. Write a short **summary paragraph** explaining the overall purpose of this pull request.
4. Format the final output as a complete markdown-formatted PR description, ready to paste into GitHub.

Begin your response directly with the formatted PR description—no extra commentary or explanation.
"""


@router.websocket("/ws/{session_id}/{action_type}")
async def websocket_endpoint(
websocket: WebSocket,
websocket: WebSocket,
session_id: Optional[str] = None,
action_type: str="recap"
action_type: Literal["recap", "release", "pull_request"] = "recap"
):
"""
WebSocket endpoint for real-time LLM operations.

Handles three action types:
- recap: Generate commit summaries with quirky remarks
- release: Generate release notes based on git history
- pull_request: Generate PR descriptions from commit diffs

Args:
websocket: WebSocket connection instance
session_id: Session identifier for LLM and fetcher management
action_type: Type of operation to perform

Raises:
HTTPException: If action_type is invalid
"""
await websocket.accept()

# Select appropriate system prompt based on action type
if action_type == "recap":
QUIRKY_SYSTEM = SELECT_QUIRKY_REMARK_SYSTEM.format(
examples=json.dumps(get_random_quirky_remarks(quirky_remarks), indent=4)
)

system = [SYSTEM, QUIRKY_SYSTEM]

elif action_type == "release":
system = RELEASE_NOTES_SYSTEM

elif action_type == "pull_request":
system = PR_DESCRIPTION_SYSTEM
else:
raise HTTPException(status_code=404)
# Store the connection
raise HTTPException(status_code=404, detail="Invalid action type")

# Store the active WebSocket connection
active_connections[session_id] = websocket

# Initialize LLM
# Initialize LLM session
llm = get_llm(session_id)

try:
while True:
# Receive message from client
message = await websocket.receive_text()
msg_json = json.loads(message)
message = msg_json.get("actions")
message_content = msg_json.get("actions")
N = msg_json.get("n", 5)
assert int(N) <= 15
assert message

# Validate inputs
assert int(N) <= 15, "N must be <= 15"
assert message_content, "Message content is required"

# Build history/prompt based on action type
if action_type == "recap":
history = [
TRIGGER_PROMPT.format(
N=N,
ACTIONS=message
ACTIONS=message_content
)
]
elif action_type == "release":
history = [
TRIGGER_RELEASE_PROMPT.format(ACTIONS=message)
TRIGGER_RELEASE_PROMPT.format(ACTIONS=message_content)
]
elif action_type == "pull_request":
history = [
TRIGGER_PULL_REQUEST_PROMPT.format(COMMITS=message_content)
]

# Stream LLM response back to client
response = []
async for chunk in run_concurrent_tasks(
llm,
Expand All @@ -85,24 +137,38 @@ async def websocket_endpoint(
break
elif chunk in SPECIAL_TOKENS:
continue

await websocket.send_text(json.dumps({"chunk": chunk}))
response.append(chunk)


# Store response in history for potential follow-up
history.append("".join(response))

except WebSocketDisconnect:
# Clean up connection on disconnect
if session_id in active_connections:
del active_connections[session_id]
except AssertionError as e:
# Handle validation errors
if session_id in active_connections:
await websocket.send_text(json.dumps({"error": f"Validation error: {str(e)}"}))
del active_connections[session_id]
except Exception as e:
# Handle unexpected errors
if session_id in active_connections:
await websocket.send_text(json.dumps({"error": str(e)}))
del active_connections[session_id]


def close_websocket_connection(session_id: str):
"""
Clean up and close the active websocket connection associated with the given session_id.
Clean up and close the active WebSocket connection associated with the given session_id.

This function is called during session expiration to ensure proper cleanup
of WebSocket resources.

Args:
session_id: The session identifier whose WebSocket connection should be closed
"""
websocket = active_connections.pop(session_id, None)
if websocket:
asyncio.create_task(websocket.close())
asyncio.create_task(websocket.close())
Loading
Loading