From 8d8ade0023484230c660c1fc23ea69bd6d7af848 Mon Sep 17 00:00:00 2001 From: "along.li" Date: Tue, 12 Aug 2025 16:10:14 +0400 Subject: [PATCH] Add API key configuration and update documentation for environment variables --- .env.example | 5 +++ README.md | 40 +++++++++++++---- api_config.py | 21 +++++++++ block_parsor.py | 108 ++++++++++++++++++++++++++-------------------- html_generator.py | 84 ++++++++++++++++++++++-------------- 5 files changed, 170 insertions(+), 88 deletions(-) create mode 100644 .env.example create mode 100644 api_config.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ce5ca98 --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ +# Example environment variables for API keys +DOUBAO_API_KEY=your_doubao_api_key_here +QWEN_API_KEY=your_qwen_api_key_here +GPT_API_KEY=your_gpt_api_key_here +GEMINI_API_KEY=your_gemini_api_key_here diff --git a/README.md b/README.md index 663a73a..204dae9 100644 --- a/README.md +++ b/README.md @@ -33,15 +33,39 @@ Michael R. Lyu2, Xiangyu Yue1✉ It also supports customized modifications, allowing developers and designers to tweak layout and styling with ease. Whether you're prototyping quickly or building pixel-perfect interfaces, ScreenCoder bridges the gap between design and development — just copy, customize, and deploy. +## API Key Configuration + + +ScreenCoder uses multiple AI services (Doubao, Qwen, GPT, Gemini). API keys are managed via environment variables in a `.env` file. For team or open source use, copy `.env.example` to `.env` and fill in your keys: + +``` +DOUBAO_API_KEY=your_doubao_api_key_here +QWEN_API_KEY=your_qwen_api_key_here +GPT_API_KEY=your_gpt_api_key_here +GEMINI_API_KEY=your_gemini_api_key_here +``` + +All code modules use the unified `api_config.py` for API key loading. To switch models, use the `ApiService` enum: + +```python +from api_config import get_api_key, ApiService +client = Doubao(get_api_key(ApiService.DOUBAO)) +# or +client = Qwen(get_api_key(ApiService.QWEN)) +# or +client = GPT(get_api_key(ApiService.GPT)) +# or +client = Gemini(get_api_key(ApiService.GEMINI)) +``` + ## Huggingface Demo -- Try our huggingface demo at [Demo](https://huggingface.co/spaces/Jimmyzheng-10/ScreenCoder) +Try our huggingface demo at [Demo](https://huggingface.co/spaces/Jimmyzheng-10/ScreenCoder) -- Run the demo locally (download from huggingface space): +Run the demo locally (download from huggingface space): - ```bash - python app.py - ``` - +```bash +python app.py +``` ## Demo Videos A showcase of how **ScreenCoder** transforms UI screenshots into structured, editable HTML/CSS code using a modular multi-agent framework. @@ -101,8 +125,8 @@ As shown above, our method produces results that are more accurate, visually ali pip install -r requirements.txt ``` 4. **Configure the model and API key** - - ***Choose a generation model***: Set the desired model in `block_parsor.py` and `html_generator.py`. Supported options: Doubao(default), Qwen, GPT, Gemini. - - ***Add the API key***: Create a plain-text file (`doubao_api.txt`, `qwen_api.txt`, `gpt_api.txt`, `gemini_api.txt`) in the project root directory that corresponds to your selected model, and paste your API key inside. + - ***Choose a generation model***: Set the desired model in `block_parsor.py` and `html_generator.py` by changing the client and using the corresponding `ApiService` enum value. Supported options: Doubao (default), Qwen, GPT, Gemini. 
+ - ***Add the API key***: Copy `.env.example` to `.env` and fill in your API keys for each model as environment variables. No need for separate txt files. ## Usage diff --git a/api_config.py b/api_config.py new file mode 100644 index 0000000..fb3c5ea --- /dev/null +++ b/api_config.py @@ -0,0 +1,21 @@ +import os +from enum import Enum +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + + +class ApiService(Enum): + DOUBAO = "DOUBAO_API_KEY" + QWEN = "QWEN_API_KEY" + GPT = "GPT_API_KEY" + GEMINI = "GEMINI_API_KEY" + + +def get_api_key(service: ApiService): + """ + Get API key for the given service from environment variables. + Usage: get_api_key(ApiService.DOUBAO) + """ + return os.getenv(service.value) diff --git a/block_parsor.py b/block_parsor.py index 7e879b1..ecff2d7 100644 --- a/block_parsor.py +++ b/block_parsor.py @@ -1,10 +1,14 @@ +# Unified API key config +from api_config import get_api_key, ApiService import os import cv2 import json from utils import Doubao, Qwen, GPT, Gemini, encode_image, image_mask DEFAULT_IMAGE_PATH = "data/input/test1.png" -DEFAULT_API_PATH = "doubao_api.txt" # Change the API key path for different models (i.e. doubao, qwen, gpt, gemini). + +# Get Doubao API key from unified config +DEFAULT_API_KEY = get_api_key(ApiService.DOUBAO) # We provide prompts in both Chinese and English. PROMPT_MERGE = "Return the bounding boxes of the sidebar, main content, header, and navigation in this webpage screenshot. Please only return the corresponding bounding boxes. Note: 1. The areas should not overlap; 2. All text information and other content should be framed inside; 3. Try to keep it compact without leaving a lot of blank space; 4. Output a label and the corresponding bounding box for each line." @@ -34,7 +38,7 @@ def resolve_containment(bboxes: dict[str, tuple[int, int, int, int]]) -> dict[st If a box is found to be fully contained within another, it is removed. This is based on the assumption that major layout components should not contain each other. """ - + def contains(box_a, box_b): """Checks if box_a completely contains box_b.""" xa1, ya1, xa2, ya2 = box_a @@ -48,12 +52,13 @@ def contains(box_a, box_b): for j in range(len(names)): if i == j or names[i] in removed or names[j] in removed: continue - + name1, box1 = names[i], bboxes[names[i]] name2, box2 = names[j], bboxes[names[j]] if contains(box1, box2) or contains(box2, box1): - print(f"Containment found: '{name1}' contains '{name2}'. Removing '{name2}'.") + print( + f"Containment found: '{name1}' contains '{name2}'. 
Removing '{name2}'.") removed.add(name2) return {name: bbox for name, bbox in bboxes.items() if name not in removed} @@ -71,7 +76,7 @@ def contains(box_a, box_b): # for j in range(len(names)): # if i == j or names[i] in removed or names[j] in removed: # continue - + # box1, box2 = bboxes[names[i]], bboxes[names[j]] # iou_score = iou(box1, box2) # if iou_score > 0: @@ -92,6 +97,8 @@ def contains(box_a, box_b): # return intersection_area / (box1_area + box2_area - intersection_area) # simple version of bbox parsing + + def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int, int, int]]: """Parse bounding box string to dictionary of named coordinate tuples""" bboxes = {} @@ -102,16 +109,16 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int, print(f"Error: Failed to read image {image_path}") return bboxes h, w = image.shape[:2] - + try: components = bbox_input.strip().split('\n') # print("Split components:", components) # Debug print - + for component in components: component = component.strip() if not component: continue - + if ':' in component: name, bbox_str = component.split(':', 1) else: @@ -126,17 +133,17 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int, name = 'main content' else: name = 'unknown' - + name = name.strip().lower() bbox_str = bbox_str.strip() - + # print(f"Processing component: {name}, bbox_str: {bbox_str}") # Debug print - + if BBOX_TAG_START in bbox_str and BBOX_TAG_END in bbox_str: start_idx = bbox_str.find(BBOX_TAG_START) + len(BBOX_TAG_START) end_idx = bbox_str.find(BBOX_TAG_END) coords_str = bbox_str[start_idx:end_idx].strip() - + try: norm_coords = list(map(int, coords_str.split())) if len(norm_coords) == 4: @@ -147,27 +154,29 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int, bboxes[name] = (x_min, y_min, x_max, y_max) print(f"Successfully parsed {name}: {bboxes[name]}") else: - print(f"Invalid number of coordinates for {name}: {norm_coords}") + print( + f"Invalid number of coordinates for {name}: {norm_coords}") except ValueError as e: print(f"Failed to parse coordinates for {name}: {e}") else: print(f"No bbox tags found in: {bbox_str}") - + except Exception as e: print(f"Coordinate parsing failed: {str(e)}") import traceback traceback.print_exc() - + print("Final parsed bboxes:", bboxes) return bboxes + def draw_bboxes(image_path: str, bboxes: dict[str, tuple[int, int, int, int]]) -> str: """Draw bounding boxes on image and save with different colors for each component""" image = cv2.imread(image_path) if image is None: print(f"Error: Failed to read image {image_path}") - return "" - + return "" + h, w = image.shape[:2] colors = { 'sidebar': (0, 0, 255), # Red @@ -176,46 +185,49 @@ def draw_bboxes(image_path: str, bboxes: dict[str, tuple[int, int, int, int]]) - 'main content': (255, 255, 0), # Cyan 'unknown': (0, 0, 0), # Black } - + for component, norm_bbox in bboxes.items(): # Convert normalized coordinates to pixel coordinates for drawing x_min = int(norm_bbox[0] * w / 1000) y_min = int(norm_bbox[1] * h / 1000) x_max = int(norm_bbox[2] * w / 1000) y_max = int(norm_bbox[3] * h / 1000) - + color = colors.get(component.lower(), (0, 0, 255)) cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3) - + # Add label cv2.putText(image, component, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2) - + # Output directory output_dir = "data/tmp" os.makedirs(output_dir, exist_ok=True) - + # Get the original filename without path 
original_filename = os.path.basename(image_path) - output_path = os.path.join(output_dir, os.path.splitext(original_filename)[0] + "_with_bboxes.png") - + output_path = os.path.join(output_dir, os.path.splitext( + original_filename)[0] + "_with_bboxes.png") + if cv2.imwrite(output_path, image): print(f"Successfully saved annotated image: {output_path}") return output_path print("Error: Failed to save image") return "" + def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path: str) -> str: """Save bounding boxes information to a JSON file""" # Output directory output_dir = "data/tmp" os.makedirs(output_dir, exist_ok=True) - + original_filename = os.path.basename(image_path) - json_path = os.path.join(output_dir, os.path.splitext(original_filename)[0] + "_bboxes.json") - + json_path = os.path.join(output_dir, os.path.splitext( + original_filename)[0] + "_bboxes.json") + bboxes_dict = {k: list(v) for k, v in bboxes.items()} - + try: with open(json_path, 'w', encoding='utf-8') as f: json.dump(bboxes_dict, f, indent=4, ensure_ascii=False) @@ -233,13 +245,13 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path # bboxes = {} # current_image_path = image_path # ark_client = Doubao(api_path) # Change your client according to your needs: Qwen(api_path), GPT(api_path), Gemini(api_path) - + # image = cv2.imread(image_path) # if image is None: # print(f"Error: Failed to read image {image_path}") # return bboxes # h, w = image.shape[:2] - + # for i, (component_name, prompt) in enumerate(PROMPT_LIST): # print(f"\n=== Processing {component_name} (Step {i+1}/{len(PROMPT_LIST)}) ===") @@ -252,22 +264,22 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path # bbox_content = ark_client.ask(prompt, base64_image) # print(f"Model response for {component_name}:") # print(bbox_content) - + # norm_bbox = parse_single_bbox(bbox_content, component_name) # if norm_bbox: # bboxes[component_name] = norm_bbox # print(f"Successfully detected {component_name}: {norm_bbox}") - + # masked_image = image_mask(current_image_path, norm_bbox) - + # temp_image_path = f"data/temp_{component_name}_masked.png" # masked_image.save(temp_image_path) # current_image_path = temp_image_path - + # print(f"Created masked image for next step: {temp_image_path}") # else: # print(f"Failed to detect {component_name}") - + # return bboxes # def parse_single_bbox(bbox_input: str, component_name: str) -> tuple[int, int, int, int]: @@ -275,13 +287,13 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path # Parses a single component's bbox string and returns normalized coordinates. 
# """ # print(f"Parsing bbox for {component_name}: {bbox_input}") - + # try: # if BBOX_TAG_START in bbox_input and BBOX_TAG_END in bbox_input: # start_idx = bbox_input.find(BBOX_TAG_START) + len(BBOX_TAG_START) # end_idx = bbox_input.find(BBOX_TAG_END) # coords_str = bbox_input[start_idx:end_idx].strip() - + # norm_coords = list(map(int, coords_str.split())) # if len(norm_coords) == 4: # return tuple(norm_coords) @@ -291,7 +303,7 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path # print(f"No bbox tags found in response for {component_name}") # except Exception as e: # print(f"Failed to parse bbox for {component_name}: {e}") - + # return None # def main_content_processing(bboxes: dict[str, tuple[int, int, int, int]], image_path: str) -> dict[str, tuple[int, int, int, int]]: @@ -307,16 +319,18 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path # int(bbox[1] * h / 1000), # int(bbox[2] * w / 1000), # int(bbox[3] * h / 1000)) - - + + if __name__ == "__main__": image_path = DEFAULT_IMAGE_PATH - api_path = DEFAULT_API_PATH + api_key = DEFAULT_API_KEY print("=== Starting Simple Component Detection ===") print(f"Input image: {image_path}") - print(f"API path: {api_path}") - client = Doubao(api_path) # Change your models according to your needs: Qwen(api_path), GPT(api_path), Gemini(api_path) + print(f"API key: {api_key}") + # You can switch model by changing the client and API key source + # For example: Qwen(get_api_key(ApiService.QWEN)), GPT(get_api_key(ApiService.GPT)), Gemini(get_api_key(ApiService.GEMINI)) + client = Doubao(api_key) bbox_content = client.ask(PROMPT_MERGE, encode_image(image_path)) print(f"Model response: {bbox_content}\n") bboxes = parse_bboxes(bbox_content, image_path) @@ -325,17 +339,17 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path # print(f"Input image: {image_path}") # print(f"API path: {api_path}") # bboxes = sequential_component_detection(image_path, api_path) - + bboxes = resolve_containment(bboxes) - + if bboxes: print(f"\n=== Detection Complete ===") print(f"Found bounding boxes for components: {list(bboxes.keys())}") print(f"Total components detected: {len(bboxes)}") - + json_path = save_bboxes_to_json(bboxes, image_path) draw_bboxes(image_path, bboxes) - + print(f"\n=== Results ===") for component, bbox in bboxes.items(): print(f"{component}: {bbox}") diff --git a/html_generator.py b/html_generator.py index 3a14e5f..3eadb18 100644 --- a/html_generator.py +++ b/html_generator.py @@ -1,3 +1,4 @@ +from api_config import get_api_key, ApiService from utils import encode_image, Doubao, Qwen, GPT, Gemini from PIL import Image import bs4 @@ -54,22 +55,22 @@ # Please fill in a complete HTML and Tailwind CSS code to accurately reproduce the given container. # Please ensure that all block layouts, icon styles, sizes, and text information are consistent with the original screenshot, # based on the user's additional conditions. Below is the code template to fill in: - + #
<div> # your code here # </div>
- + # Only return the code within the <div> and </div>
tags.""", # "header": f"""This is a screenshot of a container. Here is the user's additional instruction: {user_instruction["header"]} # Please fill in a complete HTML and Tailwind CSS code to accurately reproduce the given container. # Please ensure that all blocks' relative positions, layout, text information, and colors within the bounding box # are consistent with the original screenshot, based on the user's additional conditions. Below is the code template to fill in: - + #
<div> # your code here # </div>
- + # Only return the code within the <div> and </div>
tags.""", # "navigation": f"""This is a screenshot of a container. Here is the user's additional instruction: {user_instruction["navigation"]} @@ -77,11 +78,11 @@ # Please ensure that all blocks' relative positions, text layout, and colors within the bounding box # are consistent with the original screenshot, based on the user's additional conditions. # Please use the same icons as in the original screenshot. Below is the code template to fill in: - + #
<div> # your code here # </div>
- + # Only return the code within the <div> and </div>
tags.""", # "main content": f"""This is a screenshot of a container. Here is the user's additional instruction: {user_instruction["main content"]} @@ -90,11 +91,11 @@ # text inside the images does not need to be recognized. # Please ensure that all blocks' relative positions, layout, text information, and colors within the bounding box # are consistent with the original screenshot, based on the user's additional conditions. Below is the code template to fill in: - + #
<div> # your code here # </div>
- + # Only return the code within the <div> and </div>
tags.""" # } @@ -102,17 +103,19 @@ # PROMPT_refinement = """Here is a prototype image of a webpage. I have an draft HTML file that contains most of the elements and their correct positions, but it has *inaccurate background*, and some missing or wrong elements. Please compare the draft and the prototype image, then revise the draft implementation. Return a single piece of accurate HTML+tail-wind CSS code to reproduce the website. Respond with the content of the HTML+tail-wind CSS code. The current implementation I have is: \n\n [CODE]""" # Generate code for each component + + def generate_code(bbox_tree, img_path, bot): """generate code for all the leaf nodes in the bounding box tree, return a dictionary: {'id': 'code'}""" img = Image.open(img_path) code_dict = {} - + def _generate_code(node): if node["children"] == []: bbox = node["bbox"] # bbox is already in pixel coordinates [x1, y1, x2, y2] cropped_img = img.crop(bbox) - + # Select prompt based on node type if "type" in node: if node["type"] == "sidebar": @@ -129,12 +132,13 @@ def _generate_code(node): else: print("Node type not found") return - + try: code = bot.ask(prompt, encode_image(cropped_img)) code_dict[node["id"]] = code except Exception as e: - print(f"Error generating code for {node.get('type', 'unknown')}: {str(e)}") + print( + f"Error generating code for {node.get('type', 'unknown')}: {str(e)}") code_dict[node["id"]] = f"" else: for child in node["children"]: @@ -144,11 +148,13 @@ def _generate_code(node): return code_dict # Generate code for each component in parallel + + def generate_code_parallel(bbox_tree, img_path, bot): """generate code for all the leaf nodes in the bounding box tree, return a dictionary: {'id': 'code'}""" code_dict = {} t_list = [] - + def _generate_code_with_retry(node, max_retries=3, retry_delay=2): """Generate code with retry mechanism for rate limit errors""" try: @@ -163,13 +169,14 @@ def _generate_code_with_retry(node, max_retries=3, retry_delay=2): prompt = PROMPT_DICT[node["type"]] else: print(f"Unknown component type: {node['type']}") - code_dict[node["id"]] = f"" + code_dict[node["id"] + ] = f"" return else: print("Node type not found") code_dict[node["id"]] = f"" return - + for attempt in range(max_retries): try: code = bot.ask(prompt, encode_image(cropped_img)) @@ -177,11 +184,13 @@ def _generate_code_with_retry(node, max_retries=3, retry_delay=2): return except Exception as e: if "rate_limit" in str(e).lower() and attempt < max_retries - 1: - print(f"Rate limit hit, retrying in {retry_delay} seconds... (Attempt {attempt + 1}/{max_retries})") + print( + f"Rate limit hit, retrying in {retry_delay} seconds... (Attempt {attempt + 1}/{max_retries})") time.sleep(retry_delay) retry_delay *= 2 # Exponential backoff else: - print(f"Error generating code for node {node['id']}: {str(e)}") + print( + f"Error generating code for node {node['id']}: {str(e)}") code_dict[node["id"]] = f"" return except Exception as e: @@ -198,14 +207,16 @@ def _generate_code(node): _generate_code(child) _generate_code(bbox_tree) - + # Wait for all threads to complete for t in t_list: t.join() - + return code_dict # Generate HTML from the bounding box tree + + def generate_html(bbox_tree, output_file="output.html", img_path="data/test1.png"): """ Generates an HTML file with nested containers based on the bounding box tree. 
@@ -284,11 +295,12 @@ def process_bbox(node, parent_width, parent_height, parent_left, parent_top, img current_width = bbox[2] - bbox[0] current_height = bbox[3] - bbox[1] for child in children: - html += process_bbox(child, current_width, current_height, bbox[0], bbox[1], img) + html += process_bbox(child, current_width, + current_height, bbox[0], bbox[1], img) html += ''' ''' - + # Close the box div html += ''' @@ -304,7 +316,8 @@ def process_bbox(node, parent_width, parent_height, parent_left, parent_top, img html_content = html_template_start for child in root_children: - html_content += process_bbox(child, root_width, root_height, root_x, root_y, img) + html_content += process_bbox(child, root_width, + root_height, root_x, root_y, img) html_content += html_template_end soup = bs4.BeautifulSoup(html_content, 'html.parser') @@ -314,6 +327,8 @@ def process_bbox(node, parent_width, parent_height, parent_left, parent_top, img f.write(html_content) # Substitute the code in the html file + + def code_substitution(html_file, code_dict): """substitute the code in the html file""" with open(html_file, "r") as f: @@ -346,26 +361,26 @@ def code_substitution(html_file, code_dict): # except Exception as e: # print(f"An error occurred during HTML refinement: {e}") + # Main if __name__ == "__main__": import json import time from PIL import Image - + # Load bboxes from block_parsing.py output boxes_data = json.load(open("data/tmp/test1_bboxes.json")) - img_path = "data/input/test1.png" with Image.open(img_path) as img: width, height = img.size - + # Create root node with actual image dimensions root = { "bbox": [0, 0, width, height], # Use actual image dimensions "children": [] } - + # Add each region as a child with its type for component_name, norm_bbox in boxes_data.items(): # The coordinates from block_parsor are normalized to 1000x1000 @@ -374,21 +389,21 @@ def code_substitution(html_file, code_dict): y1 = int(norm_bbox[1] * height / 1000) x2 = int(norm_bbox[2] * width / 1000) y2 = int(norm_bbox[3] * height / 1000) - + child = { "bbox": [x1, y1, x2, y2], "children": [], "type": component_name } root["children"].append(child) - + # Assign IDs to all nodes def assign_id(node, id): node["id"] = id for child in node.get("children", []): id = assign_id(child, id+1) return id - + assign_id(root, 0) # print(root) @@ -396,17 +411,20 @@ def assign_id(node, id): generate_html(root, 'data/tmp/test1_layout.html') # Initialize the bot - # Change your model & API ket path according to your needs - bot = Doubao("doubao_api.txt", model = "doubao-1.5-thinking-vision-pro-250428") + # Get Doubao API key from unified config + doubao_api_key = get_api_key(ApiService.DOUBAO) + bot = Doubao(doubao_api_key, model="doubao-1.5-thinking-vision-pro-250428") + # You can switch model by changing the client and API key source + # For example: Qwen(get_api_key(ApiService.QWEN)), GPT(get_api_key(ApiService.GPT)), Gemini(get_api_key(ApiService.GEMINI)) # bot = Qwen("qwen_api.txt", model="qwen2.5-vl-72b-instruct") # bot = GPT("gpt_api.txt", model="gpt-4o") # bot = Gemini("gemini_api.txt", model="gemini-1.5-flash-latest") - + # Generate code for each component # code_dict = generate_code(root, img_path, bot) code_dict = generate_code_parallel(root, img_path, bot) - + # Substitute the generated code into the HTML code_substitution('data/tmp/test1_layout.html', code_dict)
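For reference, a minimal sketch (not part of the patch) of how the new `api_config` module introduced here can be sanity-checked after copying `.env.example` to `.env`. It assumes `python-dotenv` is installed and the script is run from the project root where `.env` lives; the check script itself is hypothetical.

```python
# Hypothetical check script: confirms each key defined in ApiService is
# visible through the unified loader before running block_parsor.py or
# html_generator.py. Prints "MISSING" for any key left unset in .env.
from api_config import ApiService, get_api_key

for service in ApiService:
    key = get_api_key(service)
    print(f"{service.name}: {'set' if key else 'MISSING'}")
```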