Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Example environment variables for API keys
DOUBAO_API_KEY=your_doubao_api_key_here
QWEN_API_KEY=your_qwen_api_key_here
GPT_API_KEY=your_gpt_api_key_here
GEMINI_API_KEY=your_gemini_api_key_here
40 changes: 32 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,39 @@ Michael R. Lyu<sup>2</sup>, Xiangyu Yue<sup>1✉</sup>

It also supports customized modifications, allowing developers and designers to tweak layout and styling with ease. Whether you're prototyping quickly or building pixel-perfect interfaces, ScreenCoder bridges the gap between design and development — just copy, customize, and deploy.

## API Key Configuration


ScreenCoder uses multiple AI services (Doubao, Qwen, GPT, Gemini). API keys are managed via environment variables in a `.env` file. For team or open source use, copy `.env.example` to `.env` and fill in your keys:

```
DOUBAO_API_KEY=your_doubao_api_key_here
QWEN_API_KEY=your_qwen_api_key_here
GPT_API_KEY=your_gpt_api_key_here
GEMINI_API_KEY=your_gemini_api_key_here
```

All code modules use the unified `api_config.py` for API key loading. To switch models, use the `ApiService` enum:

```python
from api_config import get_api_key, ApiService
client = Doubao(get_api_key(ApiService.DOUBAO))
# or
client = Qwen(get_api_key(ApiService.QWEN))
# or
client = GPT(get_api_key(ApiService.GPT))
# or
client = Gemini(get_api_key(ApiService.GEMINI))
```

## Huggingface Demo
- Try our huggingface demo at [Demo](https://huggingface.co/spaces/Jimmyzheng-10/ScreenCoder)
Try our Hugging Face demo at [Demo](https://huggingface.co/spaces/Jimmyzheng-10/ScreenCoder)

- Run the demo locally (download from huggingface space):
Run the demo locally (after downloading the app from the Hugging Face Space):

```bash
python app.py
```

```bash
python app.py
```
## Demo Videos

A showcase of how **ScreenCoder** transforms UI screenshots into structured, editable HTML/CSS code using a modular multi-agent framework.
Expand Down Expand Up @@ -101,8 +125,8 @@ As shown above, our method produces results that are more accurate, visually ali
pip install -r requirements.txt
```
4. **Configure the model and API key**
- ***Choose a generation model***: Set the desired model in `block_parsor.py` and `html_generator.py`. Supported options: Doubao(default), Qwen, GPT, Gemini.
- ***Add the API key***: Create a plain-text file (`doubao_api.txt`, `qwen_api.txt`, `gpt_api.txt`, `gemini_api.txt`) in the project root directory that corresponds to your selected model, and paste your API key inside.
- ***Choose a generation model***: Set the desired model in `block_parsor.py` and `html_generator.py` by changing the client and using the corresponding `ApiService` enum value. Supported options: Doubao (default), Qwen, GPT, Gemini.
- ***Add the API key***: Copy `.env.example` to `.env` and fill in your API keys for each model as environment variables. No need for separate txt files.

## Usage

Expand Down
21 changes: 21 additions & 0 deletions api_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
from enum import Enum
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()


class ApiService(Enum):
DOUBAO = "DOUBAO_API_KEY"
QWEN = "QWEN_API_KEY"
GPT = "GPT_API_KEY"
GEMINI = "GEMINI_API_KEY"


def get_api_key(service: ApiService):
"""
Get API key for the given service from environment variables.
Usage: get_api_key(ApiService.DOUBAO)
"""
return os.getenv(service.value)
108 changes: 61 additions & 47 deletions block_parsor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Unified API key config
from api_config import get_api_key, ApiService
import os
import cv2
import json
from utils import Doubao, Qwen, GPT, Gemini, encode_image, image_mask

DEFAULT_IMAGE_PATH = "data/input/test1.png"
DEFAULT_API_PATH = "doubao_api.txt" # Change the API key path for different models (i.e. doubao, qwen, gpt, gemini).

# Get Doubao API key from unified config
DEFAULT_API_KEY = get_api_key(ApiService.DOUBAO)

# We provide prompts in both Chinese and English.
PROMPT_MERGE = "Return the bounding boxes of the sidebar, main content, header, and navigation in this webpage screenshot. Please only return the corresponding bounding boxes. Note: 1. The areas should not overlap; 2. All text information and other content should be framed inside; 3. Try to keep it compact without leaving a lot of blank space; 4. Output a label and the corresponding bounding box for each line."
Expand Down Expand Up @@ -34,7 +38,7 @@ def resolve_containment(bboxes: dict[str, tuple[int, int, int, int]]) -> dict[st
If a box is found to be fully contained within another, it is removed.
This is based on the assumption that major layout components should not contain each other.
"""

def contains(box_a, box_b):
"""Checks if box_a completely contains box_b."""
xa1, ya1, xa2, ya2 = box_a
Expand All @@ -48,12 +52,13 @@ def contains(box_a, box_b):
for j in range(len(names)):
if i == j or names[i] in removed or names[j] in removed:
continue

name1, box1 = names[i], bboxes[names[i]]
name2, box2 = names[j], bboxes[names[j]]

if contains(box1, box2) or contains(box2, box1):
print(f"Containment found: '{name1}' contains '{name2}'. Removing '{name2}'.")
print(
f"Containment found: '{name1}' contains '{name2}'. Removing '{name2}'.")
removed.add(name2)

return {name: bbox for name, bbox in bboxes.items() if name not in removed}
Expand All @@ -71,7 +76,7 @@ def contains(box_a, box_b):
# for j in range(len(names)):
# if i == j or names[i] in removed or names[j] in removed:
# continue

# box1, box2 = bboxes[names[i]], bboxes[names[j]]
# iou_score = iou(box1, box2)
# if iou_score > 0:
Expand All @@ -92,6 +97,8 @@ def contains(box_a, box_b):
# return intersection_area / (box1_area + box2_area - intersection_area)

# simple version of bbox parsing


def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int, int, int]]:
"""Parse bounding box string to dictionary of named coordinate tuples"""
bboxes = {}
Expand All @@ -102,16 +109,16 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int,
print(f"Error: Failed to read image {image_path}")
return bboxes
h, w = image.shape[:2]

try:
components = bbox_input.strip().split('\n')
# print("Split components:", components) # Debug print

for component in components:
component = component.strip()
if not component:
continue

if ':' in component:
name, bbox_str = component.split(':', 1)
else:
Expand All @@ -126,17 +133,17 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int,
name = 'main content'
else:
name = 'unknown'

name = name.strip().lower()
bbox_str = bbox_str.strip()

# print(f"Processing component: {name}, bbox_str: {bbox_str}") # Debug print

if BBOX_TAG_START in bbox_str and BBOX_TAG_END in bbox_str:
start_idx = bbox_str.find(BBOX_TAG_START) + len(BBOX_TAG_START)
end_idx = bbox_str.find(BBOX_TAG_END)
coords_str = bbox_str[start_idx:end_idx].strip()

try:
norm_coords = list(map(int, coords_str.split()))
if len(norm_coords) == 4:
Expand All @@ -147,27 +154,29 @@ def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int,
bboxes[name] = (x_min, y_min, x_max, y_max)
print(f"Successfully parsed {name}: {bboxes[name]}")
else:
print(f"Invalid number of coordinates for {name}: {norm_coords}")
print(
f"Invalid number of coordinates for {name}: {norm_coords}")
except ValueError as e:
print(f"Failed to parse coordinates for {name}: {e}")
else:
print(f"No bbox tags found in: {bbox_str}")

except Exception as e:
print(f"Coordinate parsing failed: {str(e)}")
import traceback
traceback.print_exc()

print("Final parsed bboxes:", bboxes)
return bboxes


def draw_bboxes(image_path: str, bboxes: dict[str, tuple[int, int, int, int]]) -> str:
"""Draw bounding boxes on image and save with different colors for each component"""
image = cv2.imread(image_path)
if image is None:
print(f"Error: Failed to read image {image_path}")
return ""
return ""

h, w = image.shape[:2]
colors = {
'sidebar': (0, 0, 255), # Red
Expand All @@ -176,46 +185,49 @@ def draw_bboxes(image_path: str, bboxes: dict[str, tuple[int, int, int, int]]) -
'main content': (255, 255, 0), # Cyan
'unknown': (0, 0, 0), # Black
}

for component, norm_bbox in bboxes.items():
# Convert normalized coordinates to pixel coordinates for drawing
x_min = int(norm_bbox[0] * w / 1000)
y_min = int(norm_bbox[1] * h / 1000)
x_max = int(norm_bbox[2] * w / 1000)
y_max = int(norm_bbox[3] * h / 1000)

color = colors.get(component.lower(), (0, 0, 255))
cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)

# Add label
cv2.putText(image, component, (x_min, y_min - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

# Output directory
output_dir = "data/tmp"
os.makedirs(output_dir, exist_ok=True)

# Get the original filename without path
original_filename = os.path.basename(image_path)
output_path = os.path.join(output_dir, os.path.splitext(original_filename)[0] + "_with_bboxes.png")

output_path = os.path.join(output_dir, os.path.splitext(
original_filename)[0] + "_with_bboxes.png")

if cv2.imwrite(output_path, image):
print(f"Successfully saved annotated image: {output_path}")
return output_path
print("Error: Failed to save image")
return ""


def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path: str) -> str:
"""Save bounding boxes information to a JSON file"""
# Output directory
output_dir = "data/tmp"
os.makedirs(output_dir, exist_ok=True)

original_filename = os.path.basename(image_path)
json_path = os.path.join(output_dir, os.path.splitext(original_filename)[0] + "_bboxes.json")

json_path = os.path.join(output_dir, os.path.splitext(
original_filename)[0] + "_bboxes.json")

bboxes_dict = {k: list(v) for k, v in bboxes.items()}

try:
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(bboxes_dict, f, indent=4, ensure_ascii=False)
Expand All @@ -233,13 +245,13 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
# bboxes = {}
# current_image_path = image_path
# ark_client = Doubao(api_path) # Change your client according to your needs: Qwen(api_path), GPT(api_path), Gemini(api_path)

# image = cv2.imread(image_path)
# if image is None:
# print(f"Error: Failed to read image {image_path}")
# return bboxes
# h, w = image.shape[:2]

# for i, (component_name, prompt) in enumerate(PROMPT_LIST):
# print(f"\n=== Processing {component_name} (Step {i+1}/{len(PROMPT_LIST)}) ===")

Expand All @@ -252,36 +264,36 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
# bbox_content = ark_client.ask(prompt, base64_image)
# print(f"Model response for {component_name}:")
# print(bbox_content)

# norm_bbox = parse_single_bbox(bbox_content, component_name)
# if norm_bbox:
# bboxes[component_name] = norm_bbox
# print(f"Successfully detected {component_name}: {norm_bbox}")

# masked_image = image_mask(current_image_path, norm_bbox)

# temp_image_path = f"data/temp_{component_name}_masked.png"
# masked_image.save(temp_image_path)
# current_image_path = temp_image_path

# print(f"Created masked image for next step: {temp_image_path}")
# else:
# print(f"Failed to detect {component_name}")

# return bboxes

# def parse_single_bbox(bbox_input: str, component_name: str) -> tuple[int, int, int, int]:
# """
# Parses a single component's bbox string and returns normalized coordinates.
# """
# print(f"Parsing bbox for {component_name}: {bbox_input}")

# try:
# if BBOX_TAG_START in bbox_input and BBOX_TAG_END in bbox_input:
# start_idx = bbox_input.find(BBOX_TAG_START) + len(BBOX_TAG_START)
# end_idx = bbox_input.find(BBOX_TAG_END)
# coords_str = bbox_input[start_idx:end_idx].strip()

# norm_coords = list(map(int, coords_str.split()))
# if len(norm_coords) == 4:
# return tuple(norm_coords)
Expand All @@ -291,7 +303,7 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
# print(f"No bbox tags found in response for {component_name}")
# except Exception as e:
# print(f"Failed to parse bbox for {component_name}: {e}")

# return None

# def main_content_processing(bboxes: dict[str, tuple[int, int, int, int]], image_path: str) -> dict[str, tuple[int, int, int, int]]:
Expand All @@ -307,16 +319,18 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
# int(bbox[1] * h / 1000),
# int(bbox[2] * w / 1000),
# int(bbox[3] * h / 1000))


if __name__ == "__main__":
image_path = DEFAULT_IMAGE_PATH
api_path = DEFAULT_API_PATH
api_key = DEFAULT_API_KEY

print("=== Starting Simple Component Detection ===")
print(f"Input image: {image_path}")
print(f"API path: {api_path}")
client = Doubao(api_path) # Change your models according to your needs: Qwen(api_path), GPT(api_path), Gemini(api_path)
print(f"API key: {api_key}")
# You can switch model by changing the client and API key source
# For example: Qwen(get_api_key(ApiService.QWEN)), GPT(get_api_key(ApiService.GPT)), Gemini(get_api_key(ApiService.GEMINI))
client = Doubao(api_key)
bbox_content = client.ask(PROMPT_MERGE, encode_image(image_path))
print(f"Model response: {bbox_content}\n")
bboxes = parse_bboxes(bbox_content, image_path)
Expand All @@ -325,17 +339,17 @@ def save_bboxes_to_json(bboxes: dict[str, tuple[int, int, int, int]], image_path
# print(f"Input image: {image_path}")
# print(f"API path: {api_path}")
# bboxes = sequential_component_detection(image_path, api_path)

bboxes = resolve_containment(bboxes)

if bboxes:
print(f"\n=== Detection Complete ===")
print(f"Found bounding boxes for components: {list(bboxes.keys())}")
print(f"Total components detected: {len(bboxes)}")

json_path = save_bboxes_to_json(bboxes, image_path)
draw_bboxes(image_path, bboxes)

print(f"\n=== Results ===")
for component, bbox in bboxes.items():
print(f"{component}: {bbox}")
Expand Down
Loading