diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..ce5ca98
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,5 @@
+# Example environment variables for API keys
+DOUBAO_API_KEY=your_doubao_api_key_here
+QWEN_API_KEY=your_qwen_api_key_here
+GPT_API_KEY=your_gpt_api_key_here
+GEMINI_API_KEY=your_gemini_api_key_here
diff --git a/README.md b/README.md
index 663a73a..204dae9 100644
--- a/README.md
+++ b/README.md
@@ -33,15 +33,39 @@ Michael R. Lyu2, Xiangyu Yue1✉
 It also supports customized modifications, allowing developers and designers to tweak layout and styling with ease.
 Whether you're prototyping quickly or building pixel-perfect interfaces, ScreenCoder bridges the gap between design and development — just copy, customize, and deploy.
 
+## API Key Configuration
+
+ScreenCoder uses multiple AI services (Doubao, Qwen, GPT, Gemini). API keys are managed via environment variables in a `.env` file. For team or open-source use, copy `.env.example` to `.env` and fill in your keys:
+
+```
+DOUBAO_API_KEY=your_doubao_api_key_here
+QWEN_API_KEY=your_qwen_api_key_here
+GPT_API_KEY=your_gpt_api_key_here
+GEMINI_API_KEY=your_gemini_api_key_here
+```
+
+All code modules load API keys through the unified `api_config.py`. To switch models, pick the matching client and `ApiService` value:
+
+```python
+from api_config import get_api_key, ApiService
+from utils import Doubao, Qwen, GPT, Gemini
+client = Doubao(get_api_key(ApiService.DOUBAO))
+# or
+client = Qwen(get_api_key(ApiService.QWEN))
+# or
+client = GPT(get_api_key(ApiService.GPT))
+# or
+client = Gemini(get_api_key(ApiService.GEMINI))
+```
+
 ## Huggingface Demo
 
-- Try our huggingface demo at [Demo](https://huggingface.co/spaces/Jimmyzheng-10/ScreenCoder)
+Try our Hugging Face demo at [Demo](https://huggingface.co/spaces/Jimmyzheng-10/ScreenCoder)
 
-- Run the demo locally (download from huggingface space):
+Run the demo locally (download it from the Hugging Face Space):
 
-  ```bash
-  python app.py
-  ```
-  
+```bash
+python app.py
+```
 ## Demo Videos
 
 A showcase of how **ScreenCoder** transforms UI screenshots into structured, editable HTML/CSS code using a modular multi-agent framework.
@@ -101,8 +125,8 @@ As shown above, our method produces results that are more accurate, visually ali
    pip install -r requirements.txt
    ```
 4. **Configure the model and API key**
-   - ***Choose a generation model***: Set the desired model in `block_parsor.py` and `html_generator.py`. Supported options: Doubao(default), Qwen, GPT, Gemini.
-   - ***Add the API key***: Create a plain-text file (`doubao_api.txt`, `qwen_api.txt`, `gpt_api.txt`, `gemini_api.txt`) in the project root directory that corresponds to your selected model, and paste your API key inside.
+   - ***Choose a generation model***: Set the desired model in `block_parsor.py` and `html_generator.py` by changing the client and using the corresponding `ApiService` enum value. Supported options: Doubao (default), Qwen, GPT, Gemini.
+   - ***Add the API key***: Copy `.env.example` to `.env` and fill in the key for each model you plan to use. The separate `*_api.txt` files are no longer needed.
 
 ## Usage
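With the `.env` workflow above, a key that is missing or left blank simply makes `get_api_key` return `None`, which typically only surfaces later as an authentication error from the selected client. A fail-fast check at startup avoids that; the sketch below is a hypothetical helper built on the new `api_config.py`, not something this patch adds:

```python
from api_config import get_api_key, ApiService


def require_key(service: ApiService) -> str:
    """Return the key for `service`, or raise with a hint when the .env entry is missing."""
    key = get_api_key(service)
    if not key:
        raise RuntimeError(
            f"{service.value} is not set; copy .env.example to .env and fill it in."
        )
    return key


# Example: fail immediately instead of at the first API call.
doubao_key = require_key(ApiService.DOUBAO)
```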
diff --git a/api_config.py b/api_config.py
new file mode 100644
index 0000000..fb3c5ea
--- /dev/null
+++ b/api_config.py
@@ -0,0 +1,21 @@
+import os
+from enum import Enum
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
+
+
+class ApiService(Enum):
+    DOUBAO = "DOUBAO_API_KEY"
+    QWEN = "QWEN_API_KEY"
+    GPT = "GPT_API_KEY"
+    GEMINI = "GEMINI_API_KEY"
+
+
+def get_api_key(service: ApiService):
+    """
+    Get the API key for the given service from environment variables; returns None if unset.
+    Usage: get_api_key(ApiService.DOUBAO)
+    """
+    return os.getenv(service.value)
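Because `load_dotenv()` is called with no arguments, it relies on python-dotenv's default `.env` discovery (it generally looks near the calling module and walks up parent directories) and it does not override variables already set in the shell. If the scripts ever need to run from an arbitrary working directory, the path can be pinned explicitly; this is an illustrative variant, not what `api_config.py` does in this patch:

```python
from pathlib import Path

from dotenv import load_dotenv

# Resolve .env next to this file so any working directory finds the same keys.
PROJECT_ROOT = Path(__file__).resolve().parent
load_dotenv(dotenv_path=PROJECT_ROOT / ".env")
```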
diff --git a/block_parsor.py b/block_parsor.py
index 7e879b1..ecff2d7 100644
--- a/block_parsor.py
+++ b/block_parsor.py
@@ -1,10 +1,14 @@
+# Unified API key config
+from api_config import get_api_key, ApiService
 import os
 import cv2
 import json
 from utils import Doubao, Qwen, GPT, Gemini, encode_image, image_mask
 
 DEFAULT_IMAGE_PATH = "data/input/test1.png"
-DEFAULT_API_PATH = "doubao_api.txt" # Change the API key path for different models (i.e. doubao, qwen, gpt, gemini).
+
+# Get Doubao API key from unified config
+DEFAULT_API_KEY = get_api_key(ApiService.DOUBAO)
 
 # We provide prompts in both Chinese and English.
 PROMPT_MERGE = "Return the bounding boxes of the sidebar, main content, header, and navigation in this webpage screenshot. Please only return the corresponding bounding boxes. Note: 1. The areas should not overlap; 2. All text information and other content should be framed inside; 3. Try to keep it compact without leaving a lot of blank space; 4. Output a label and the corresponding bounding box for each line."
@@ -34,7 +38,7 @@ def resolve_containment(bboxes: dict[str, tuple[int, int, int, int]]) -> dict[st
     If a box is found to be fully contained within another, it is removed.
     This is based on the assumption that major layout components should not contain each other.
     """
-    
+
     def contains(box_a, box_b):
         """Checks if box_a completely contains box_b."""
         xa1, ya1, xa2, ya2 = box_a
@@ -48,12 +52,13 @@ def contains(box_a, box_b):
         for j in range(len(names)):
             if i == j or names[i] in removed or names[j] in removed:
                 continue
-            
+
             name1, box1 = names[i], bboxes[names[i]]
             name2, box2 = names[j], bboxes[names[j]]
             if contains(box1, box2) or contains(box2, box1):
-                print(f"Containment found: '{name1}' contains '{name2}'. Removing '{name2}'.")
+                print(
+                    f"Containment found: '{name1}' contains '{name2}'. Removing '{name2}'.")
                 removed.add(name2)
 
     return {name: bbox for name, bbox in bboxes.items() if name not in removed}
 
@@ -71,7 +76,7 @@ def contains(box_a, box_b):
 #         for j in range(len(names)):
 #             if i == j or names[i] in removed or names[j] in removed:
 #                 continue
-            
+
 #             box1, box2 = bboxes[names[i]], bboxes[names[j]]
 #             iou_score = iou(box1, box2)
 #             if iou_score > 0:
@@ -92,6 +97,8 @@ def contains(box_a, box_b):
 #     return intersection_area / (box1_area + box2_area - intersection_area)
 
 # simple version of bbox parsing
+
+
 def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int, int, int]]:
     """Parse bounding box string to dictionary of named coordinate tuples"""
     bboxes = {}
@@ -102,16 +109,16 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int,
         print(f"Error: Failed to read image {image_path}")
         return bboxes
     h, w = image.shape[:2]
-    
+
     try:
         components = bbox_input.strip().split('\n')
         # print("Split components:", components)  # Debug print
-        
+
         for component in components:
             component = component.strip()
             if not component:
                 continue
-            
+
             if ':' in component:
                 name, bbox_str = component.split(':', 1)
             else:
@@ -126,17 +133,17 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int,
                     name = 'main content'
                 else:
                     name = 'unknown'
-            
+
             name = name.strip().lower()
             bbox_str = bbox_str.strip()
-            
+
             # print(f"Processing component: {name}, bbox_str: {bbox_str}")  # Debug print
-            
+
             if BBOX_TAG_START in bbox_str and BBOX_TAG_END in bbox_str:
                 start_idx = bbox_str.find(BBOX_TAG_START) + len(BBOX_TAG_START)
                 end_idx = bbox_str.find(BBOX_TAG_END)
                 coords_str = bbox_str[start_idx:end_idx].strip()
-                
+
                 try:
                     norm_coords = list(map(int, coords_str.split()))
                     if len(norm_coords) == 4:
@@ -147,27 +154,29 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int,
                         bboxes[name] = (x_min, y_min, x_max, y_max)
                         print(f"Successfully parsed {name}: {bboxes[name]}")
                     else:
-                        print(f"Invalid number of coordinates for {name}: {norm_coords}")
+                        print(
+                            f"Invalid number of coordinates for {name}: {norm_coords}")
                 except ValueError as e:
                     print(f"Failed to parse coordinates for {name}: {e}")
             else:
                 print(f"No bbox tags found in: {bbox_str}")
-        
+
     except Exception as e:
         print(f"Coordinate parsing failed: {str(e)}")
         import traceback
         traceback.print_exc()
-    
+
     print("Final parsed bboxes:", bboxes)
     return bboxes
 
+
 def draw_bboxes(image_path: str, bboxes: dict[str, tuple[int, int, int, int]]) -> str:
     """Draw bounding boxes on image and save with different colors for each component"""
     image = cv2.imread(image_path)
     if image is None:
         print(f"Error: Failed to read image {image_path}")
-        return "" 
-    
+        return ""
+
     h, w = image.shape[:2]
     colors = {
         'sidebar': (0, 0, 255),       # Red
@@ -176,46 +185,49 @@ def draw_bboxes(image_path: str, bboxes: dict[str, tuple[int, int, int, int]]) -
         'main content': (255, 255, 0),  # Cyan
         'unknown': (0, 0, 0),           # Black
     }
-    
+
     for component, norm_bbox in bboxes.items():
         # Convert normalized coordinates to pixel coordinates for drawing
         x_min = int(norm_bbox[0] * w / 1000)
         y_min = int(norm_bbox[1] * h / 1000)
         x_max = int(norm_bbox[2] * w / 1000)
         y_max = int(norm_bbox[3] * h / 1000)
-        
+
         color = colors.get(component.lower(), (0, 0, 255))
         cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)
-        
+
         # Add label
         cv2.putText(image, component, (x_min, y_min - 10),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
-    
+
     # Output directory
     output_dir = "data/tmp"
     os.makedirs(output_dir, exist_ok=True)
-    
+
     # Get the original filename without path
     original_filename = os.path.basename(image_path)
-    output_path = os.path.join(output_dir, os.path.splitext(original_filename)[0] + "_with_bboxes.png")
-    
+    output_path = os.path.join(output_dir, os.path.splitext(
+        original_filename)[0] + "_with_bboxes.png")
+
     if cv2.imwrite(output_path, image):
         print(f"Successfully saved annotated image: {output_path}")
         return output_path
     print("Error: Failed to save image")
     return ""
 
+
 def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path: str) -> str:
     """Save bounding boxes information to a JSON file"""
     # Output directory
     output_dir = "data/tmp"
     os.makedirs(output_dir, exist_ok=True)
-    
+
     original_filename = os.path.basename(image_path)
-    json_path = os.path.join(output_dir, os.path.splitext(original_filename)[0] + "_bboxes.json")
-    
+    json_path = os.path.join(output_dir, os.path.splitext(
+        original_filename)[0] + "_bboxes.json")
+
     bboxes_dict = {k: list(v) for k, v in bboxes.items()}
-    
+
     try:
         with open(json_path, 'w', encoding='utf-8') as f:
             json.dump(bboxes_dict, f, indent=4, ensure_ascii=False)
@@ -233,13 +245,13 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
 #     bboxes = {}
 #     current_image_path = image_path
 #     ark_client = Doubao(api_path) # Change your client according to your needs: Qwen(api_path), GPT(api_path), Gemini(api_path)
-    
+
 #     image = cv2.imread(image_path)
 #     if image is None:
 #         print(f"Error: Failed to read image {image_path}")
 #         return bboxes
 #     h, w = image.shape[:2]
-    
+
 #     for i, (component_name, prompt) in enumerate(PROMPT_LIST):
 #         print(f"\n=== Processing {component_name} (Step {i+1}/{len(PROMPT_LIST)}) ===")
 
@@ -252,22 +264,22 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
 #         bbox_content = ark_client.ask(prompt, base64_image)
 #         print(f"Model response for {component_name}:")
 #         print(bbox_content)
-        
+
 #         norm_bbox = parse_single_bbox(bbox_content, component_name)
 #         if norm_bbox:
 #             bboxes[component_name] = norm_bbox
 #             print(f"Successfully detected {component_name}: {norm_bbox}")
-            
+
 #             masked_image = image_mask(current_image_path, norm_bbox)
-            
+
 #             temp_image_path = f"data/temp_{component_name}_masked.png"
 #             masked_image.save(temp_image_path)
 #             current_image_path = temp_image_path
-            
+
 #             print(f"Created masked image for next step: {temp_image_path}")
 #         else:
 #             print(f"Failed to detect {component_name}")
-        
+
 #     return bboxes
 # def parse_single_bbox(bbox_input: str, component_name: str) -> tuple[int, int, int, int]:
 #     """
@@ -275,13 +287,13 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
 #     Parses a single component's bbox string and returns normalized coordinates.
 #     """
 #     print(f"Parsing bbox for {component_name}: {bbox_input}")
-    
+
 #     try:
 #         if BBOX_TAG_START in bbox_input and BBOX_TAG_END in bbox_input:
 #             start_idx = bbox_input.find(BBOX_TAG_START) + len(BBOX_TAG_START)
 #             end_idx = bbox_input.find(BBOX_TAG_END)
 #             coords_str = bbox_input[start_idx:end_idx].strip()
-            
+
 #             norm_coords = list(map(int, coords_str.split()))
 #             if len(norm_coords) == 4:
 #                 return tuple(norm_coords)
@@ -291,7 +303,7 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
 #                 print(f"No bbox tags found in response for {component_name}")
 #     except Exception as e:
 #         print(f"Failed to parse bbox for {component_name}: {e}")
-    
+
 #     return None
 
 # def main_content_processing(bboxes: dict[str, tuple[int, int, int, int]], image_path: str) -> dict[str, tuple[int, int, int, int]]:
@@ -307,16 +319,18 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
 #                     int(bbox[1] * h / 1000),
 #                     int(bbox[2] * w / 1000),
 #                     int(bbox[3] * h / 1000))
-    
-    
+
+
 if __name__ == "__main__":
     image_path = DEFAULT_IMAGE_PATH
-    api_path = DEFAULT_API_PATH
+    api_key = DEFAULT_API_KEY
 
     print("=== Starting Simple Component Detection ===")
     print(f"Input image: {image_path}")
-    print(f"API path: {api_path}")
-    client = Doubao(api_path) # Change your models according to your needs: Qwen(api_path), GPT(api_path), Gemini(api_path)
+    print(f"API key loaded: {api_key is not None}")
+    # You can switch model by changing the client and API key source
+    # For example: Qwen(get_api_key(ApiService.QWEN)), GPT(get_api_key(ApiService.GPT)), Gemini(get_api_key(ApiService.GEMINI))
+    client = Doubao(api_key)
     bbox_content = client.ask(PROMPT_MERGE, encode_image(image_path))
     print(f"Model response: {bbox_content}\n")
     bboxes = parse_bboxes(bbox_content, image_path)
@@ -325,17 +339,17 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
     # print(f"Input image: {image_path}")
     # print(f"API path: {api_path}")
     # bboxes = sequential_component_detection(image_path, api_path)
-    
+
     bboxes = resolve_containment(bboxes)
-    
+
     if bboxes:
         print(f"\n=== Detection Complete ===")
         print(f"Found bounding boxes for components: {list(bboxes.keys())}")
         print(f"Total components detected: {len(bboxes)}")
-        
+
         json_path = save_bboxes_to_json(bboxes, image_path)
         draw_bboxes(image_path, bboxes)
-        
+
         print(f"\n=== Results ===")
         for component, bbox in bboxes.items():
             print(f"{component}: {bbox}")
diff --git a/html_generator.py b/html_generator.py
index 3a14e5f..3eadb18 100644
--- a/html_generator.py
+++ b/html_generator.py
@@ -1,3 +1,4 @@
+from api_config import get_api_key, ApiService
 from utils import encode_image, Doubao, Qwen, GPT, Gemini
 from PIL import Image
 import bs4
@@ -54,22 +55,22 @@
 # Please fill in a complete HTML and Tailwind CSS code to accurately reproduce the given container.
 # Please ensure that all block layouts, icon styles, sizes, and text information are consistent with the original screenshot,
 # based on the user's additional conditions. Below is the code template to fill in:
-    
+
 #
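Throughout `block_parsor.py` the boxes stay in the model's 0-1000 normalized coordinate space and are only converted to pixels when drawing or post-processing, via `int(norm * size / 1000)`. A tiny worked example with made-up numbers (not taken from the repository's test data):

```python
# A detected header bbox in normalized 0-1000 space...
norm_bbox = (100, 50, 900, 950)   # (x_min, y_min, x_max, y_max)
w, h = 1280, 720                  # screenshot size in pixels

# ...maps to pixel coordinates exactly as draw_bboxes does it.
pixel_bbox = (
    int(norm_bbox[0] * w / 1000),  # 128
    int(norm_bbox[1] * h / 1000),  # 36
    int(norm_bbox[2] * w / 1000),  # 1152
    int(norm_bbox[3] * h / 1000),  # 684
)
print(pixel_bbox)  # (128, 36, 1152, 684)
```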