79 changes: 37 additions & 42 deletions app/load_data.py
@@ -1,65 +1,60 @@
# app/load_data.py
import json
import os
import uuid
from glob import glob
from app.llm_client import LlamaEdgeClient
from app.vector_store import QdrantStore
import logging

def load_project_examples():
"""Load project examples into vector database"""
vector_store = QdrantStore()
llm_client = LlamaEdgeClient()

PROJECT_COLLECTION = "project_examples"
ERROR_COLLECTION = "error_examples"
PROJECT_DATA_PATH = "data/project_examples/*.json"
ERROR_DATA_PATH = "data/error_examples/*.json"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_examples(vector_store, llm_client, collection_name, file_pattern, text_key):
"""Load examples into vector database"""
# Ensure collections exist
vector_store.create_collection("project_examples")
vector_store.create_collection(collection_name)

example_files = glob("data/project_examples/*.json")
example_files = glob(file_pattern)

# Collect all embeddings and metadata first
embeddings = []
metadata = []

for file_path in example_files:
with open(file_path, 'r') as f:
example = json.load(f)

# Get embedding for query
embedding = llm_client.get_embeddings([example["query"]])[0]

# Store in vector DB with proper UUID
point_id = str(uuid.uuid4()) # Generate proper UUID

vector_store.upsert("project_examples",
[{"id": point_id, # Use UUID instead of filename
"vector": embedding,
"payload": example}])

print(f"Loaded project example: {example['query']}")
# Get embedding for query or error
try:
embedding = llm_client.get_embeddings([example[text_key]])[0]
embeddings.append(embedding)
metadata.append(example)
logger.info(f"Loaded {collection_name[:-1]} example: {example[text_key][:50]}...")
except Exception as e:
logger.error(f"Error loading {file_path}: {e}")

# Insert all documents in a single batch
if embeddings:
vector_store.insert_documents(collection_name, embeddings, metadata)

def load_project_examples():
"""Load project examples into vector database"""
vector_store = QdrantStore()
llm_client = LlamaEdgeClient()

load_examples(vector_store, llm_client, PROJECT_COLLECTION, PROJECT_DATA_PATH, "query")

def load_error_examples():
"""Load compiler error examples into vector database"""
vector_store = QdrantStore()
llm_client = LlamaEdgeClient()

# Ensure collections exist
vector_store.create_collection("error_examples")

error_files = glob("data/error_examples/*.json")

for file_path in error_files:
with open(file_path, 'r') as f:
example = json.load(f)

# Get embedding for error
embedding = llm_client.get_embeddings([example["error"]])[0]

# Store in vector DB with proper UUID
point_id = str(uuid.uuid4())

# Store in vector DB
vector_store.upsert("error_examples",
[{"id": point_id,
"vector": embedding,
"payload": example}])

print(f"Loaded error example: {example['error'][:50]}...")
load_examples(vector_store, llm_client, ERROR_COLLECTION, ERROR_DATA_PATH, "error")

if __name__ == "__main__":
load_project_examples()
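
For reference, here is a minimal way to exercise the refactored loader; the sample file path and JSON fields are hypothetical, and the `QdrantStore` and `LlamaEdgeClient` constructors are assumed to work with their defaults as shown in the diff above:

```python
# Hypothetical smoke test for the refactored loader (not part of this PR).
import json
import os

os.makedirs("data/project_examples", exist_ok=True)
with open("data/project_examples/hello.json", "w") as f:
    json.dump({"query": "Build a CLI that prints hello",
               "example": 'fn main() { println!("hello"); }'}, f)

from app.load_data import load_project_examples

load_project_examples()  # embeds each "query" and batch-inserts into "project_examples"
```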
97 changes: 82 additions & 15 deletions app/main.py
@@ -5,6 +5,7 @@
from typing import Dict, List, Optional
from dotenv import load_dotenv
import tempfile
import logging

# Load environment variables from .env file
load_dotenv()
@@ -21,16 +22,22 @@

app = FastAPI(title="Rust Project Generator API")

class AppConfig:
def __init__(self):
self.api_key = os.getenv("LLM_API_KEY", "")
self.skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
self.embed_size = int(os.getenv("LLM_EMBED_SIZE", "1536"))
# Extend with additional config values as needed

config = AppConfig()

# Get API key from environment variable (make optional)
api_key = os.getenv("LLM_API_KEY", "")
api_key = config.api_key
# Only validate if not using local setup
if not api_key and not (os.getenv("LLM_API_BASE", "").startswith("http://localhost") or
os.getenv("LLM_API_BASE", "").startswith("http://host.docker.internal")):
raise ValueError("LLM_API_KEY environment variable not set")

# Get embedding size from environment variable
llm_embed_size = int(os.getenv("LLM_EMBED_SIZE", "1536")) # Default to 1536 for compatibility

# Initialize components
llm_client = LlamaEdgeClient(api_key=api_key)
prompt_gen = PromptGenerator()
@@ -39,8 +46,8 @@

# Initialize vector store
try:
vector_store = QdrantStore(embedding_size=llm_embed_size)
if os.getenv("SKIP_VECTOR_SEARCH", "").lower() != "true":
vector_store = QdrantStore(embedding_size=config.embed_size)
if not config.skip_vector_search:  # skip_vector_search is already a bool
vector_store.create_collection("project_examples")
vector_store.create_collection("error_examples")

@@ -160,7 +167,21 @@ async def compile_rust(request: dict):

@app.post("/compile-and-fix")
async def compile_and_fix_rust(request: dict):
"""Endpoint to compile and fix Rust code"""
"""
Compile Rust code and automatically fix compilation errors.

Args:
request (dict): Dictionary containing:
- code (str): Multi-file Rust code with filename markers
- description (str): Project description
- max_attempts (int, optional): Maximum fix attempts (default: 10)

Returns:
JSONResponse: Result of compilation with fixed code if successful

Raises:
HTTPException: If required fields are missing or processing fails
"""
if "code" not in request or "description" not in request:
raise HTTPException(status_code=400, detail="Missing required fields")

@@ -225,7 +246,7 @@ async def compile_and_fix_rust(request: dict):

# Find similar errors in vector DB
similar_errors = []
if vector_store is not None and os.getenv("SKIP_VECTOR_SEARCH", "").lower() != "true":
if vector_store is not None and not config.skip_vector_search:
try:
# Find similar errors in vector DB
error_embedding = llm_client.get_embeddings([error_context["full_error"]])[0]
@@ -309,7 +330,7 @@ async def handle_project_generation(

try:
# Skip vector search if environment variable is set
skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
skip_vector_search = config.skip_vector_search

example_text = ""
if not skip_vector_search:
@@ -397,7 +418,7 @@ async def handle_project_generation(
error_context = compiler.extract_error_context(output)

# Skip vector search if environment variable is set
skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
skip_vector_search = config.skip_vector_search
similar_errors = []

if not skip_vector_search:
@@ -468,9 +489,11 @@ async def handle_project_generation(
})
save_status(project_dir, status)

def save_status(project_dir: str, status: Dict):
"""Save project status to file"""
with open(f"{project_dir}/status.json", 'w') as f:
def save_status(project_dir: str, status: Dict) -> None:
"""Save project status to file, creating the project directory if needed"""
status_path = f"{project_dir}/status.json"
os.makedirs(os.path.dirname(status_path), exist_ok=True)
with open(status_path, 'w') as f:
json.dump(status, f)

@app.get("/project/{project_id}/files/{file_path:path}")
@@ -536,7 +559,7 @@ async def generate_project_sync(request: ProjectRequest):
similar_errors = []

# Skip vector search if environment variable is set
skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
skip_vector_search = config.skip_vector_search

if not skip_vector_search:
try:
@@ -621,7 +644,7 @@ async def generate_project_sync(request: ProjectRequest):
error_context = compiler.extract_error_context(output)

# Skip vector search if environment variable is set
skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
skip_vector_search = config.skip_vector_search

if not skip_vector_search:
try:
@@ -700,3 +723,47 @@ async def generate_project_sync(request: ProjectRequest):

except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating project: {str(e)}")

logger = logging.getLogger(__name__)

def find_similar_projects(description, vector_store, llm_client):
"""Find similar projects in vector store"""
skip_vector_search = config.skip_vector_search
example_text = ""

if not skip_vector_search:
try:
query_embedding = llm_client.get_embeddings([description])[0]
similar_projects = vector_store.search("project_examples", query_embedding, limit=1)
if similar_projects:
example_text = f"\nHere's a similar project you can use as reference:\n{similar_projects[0]['example']}"
except Exception as e:
logger.warning(f"Vector search error (non-critical): {e}")

return example_text

def extract_and_find_similar_errors(error_output, vector_store, llm_client):
"""Extract error context and find similar errors"""
error_context = compiler.extract_error_context(error_output)
similar_errors = []

if vector_store and not config.skip_vector_search:
try:
error_embedding = llm_client.get_embeddings([error_context["full_error"]])[0]
similar_errors = vector_store.search("error_examples", error_embedding, limit=3)
except Exception as e:
logger.warning(f"Vector search error: {e}")

return error_context, similar_errors
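
A sketch of how the two helpers compose in an endpoint; the variables are hypothetical, and the module-level `vector_store` and `llm_client` are assumed from the rest of this file:

```python
# Hypothetical call site; `build_output` is stderr captured from a failed cargo build.
example_text = find_similar_projects("a TCP echo server", vector_store, llm_client)
error_context, similar_errors = extract_and_find_similar_errors(build_output, vector_store, llm_client)
for hit in similar_errors:
    print(hit)  # payload dicts, per the usage in find_similar_projects above
```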

# Suggested pattern for narrowing broad exception handlers:
# try:
# # ...specific operation
# except FileNotFoundError as e:
# logger.error(f"File not found: {e}")
# # Handle specifically
# except PermissionError as e:
# logger.error(f"Permission denied: {e}")
# # Handle specifically
# except Exception as e:
# logger.exception(f"Unexpected error: {e}")
# # Generic fallback
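
Applied concretely, e.g. to the per-file JSON loading in `load_data.py`, the narrowed pattern might look like this (a sketch; `file_path` and `logger` come from the loader shown earlier):

```python
# Sketch: narrowing the broad `except Exception` in load_examples().
try:
    with open(file_path, "r") as f:
        example = json.load(f)
except FileNotFoundError as e:
    logger.error(f"File not found: {e}")
except json.JSONDecodeError as e:
    logger.error(f"Malformed JSON in {file_path}: {e}")
except Exception as e:
    logger.exception(f"Unexpected error loading {file_path}: {e}")
```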
38 changes: 35 additions & 3 deletions app/vector_store.py
@@ -4,6 +4,9 @@
from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels
from qdrant_client import models
from glob import glob
import json
import uuid  # used for generating point IDs in insert_documents below
from app.llm_client import LlamaEdgeClient

class QdrantStore:
"""Interface for Qdrant vector database"""
@@ -55,9 +58,9 @@ def insert_documents(self, collection_name: str, embeddings: List[List[float]],
metadata: List[Dict[str, Any]]):
"""Insert documents with embeddings and metadata into collection"""
points = []
for i, (embedding, meta) in enumerate(zip(embeddings, metadata)):
for embedding, meta in zip(embeddings, metadata):
points.append(models.PointStruct(
id=i,
id=str(uuid.uuid4()), # Using UUID instead of index
vector=embedding,
payload=meta
))
@@ -121,4 +124,33 @@ def count(self, collection_name: str) -> int:
return collection_info.vectors_count
except Exception as e:
print(f"Error getting count for collection {collection_name}: {e}")
return 0
return 0

def load_project_examples():
"""Load project examples into vector database"""
vector_store = QdrantStore()
llm_client = LlamaEdgeClient()

# Ensure collections exist
vector_store.create_collection("project_examples")

example_files = glob("data/project_examples/*.json")

embeddings = []
metadata = []

for file_path in example_files:
with open(file_path, 'r') as f:
example = json.load(f)

# Get embedding for query
embedding = llm_client.get_embeddings([example["query"]])[0]

embeddings.append(embedding)
metadata.append(example)

print(f"Loaded project example: {example['query']}")

# Insert all documents in a single batch
if embeddings:
vector_store.insert_documents("project_examples", embeddings, metadata)
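
To close the loop, a minimal round trip through the batch API; this assumes a reachable Qdrant instance with `QdrantStore`'s defaults and an embedding model matching the collection's vector size:

```python
# Hypothetical round trip (not part of this PR).
store = QdrantStore()
store.create_collection("project_examples")

client = LlamaEdgeClient()
vectors = client.get_embeddings(["Build a REST API in Rust"])
store.insert_documents("project_examples",
                       vectors,
                       [{"query": "Build a REST API in Rust", "example": "axum scaffold"}])

hits = store.search("project_examples", vectors[0], limit=1)
print(hits)
```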