Skip to content

Commit 1438c6c

Browse files
committed
cp
1 parent e2de24d commit 1438c6c

File tree

18 files changed

+1459
-295
lines changed

18 files changed

+1459
-295
lines changed

metta/tools/play.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -129,21 +129,4 @@ def invoke(self, args: dict[str, str]) -> int | None:
129129
# Print agent inventories table
130130
episode.print_agent_stats()
131131

132-
# Print agent inventories
133-
if episode.agent_inventories and episode.resource_names:
134-
console.print("\n[bold cyan]Agent Inventories:[/bold cyan]")
135-
for agent_id, inventory in enumerate(episode.agent_inventories):
136-
inv_str = ", ".join(f"{k}={v}" for k, v in sorted(inventory.items()) if v > 0)
137-
console.print(f" Agent {agent_id}: {inv_str if inv_str else '(empty)'}")
138-
139-
# Print resource totals
140-
console.print("\n[bold cyan]Resource Totals:[/bold cyan]")
141-
resource_totals: dict[str, int] = {}
142-
for inventory in episode.agent_inventories:
143-
for resource, amount in inventory.items():
144-
resource_totals[resource] = resource_totals.get(resource, 0) + amount
145-
for resource in episode.resource_names:
146-
total = resource_totals.get(resource, 0)
147-
console.print(f" {resource}: {total}")
148-
149132
return None

projects/skydeck/CLAUDE.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# SkyDeck Development Guidelines
2+
3+
## SkyPilot Integration
4+
5+
**IMPORTANT**: Always use the SkyPilot API for accessing job and cluster information. Never query the SkyPilot SQLite databases directly (e.g., `~/.sky/jobs.db`, `~/.sky/state.db`).
6+
7+
### API Access
8+
9+
- API endpoint: `https://skypilot-api.softmax-research.net`
10+
- Authentication: OAuth2 cookies stored in `~/.sky/cookies.txt`
11+
- Configuration: `~/.sky/config.yaml`
12+
13+
### Why Use the API
14+
15+
1. **Centralized data**: The API provides access to managed jobs across all users and clusters
16+
2. **Up-to-date information**: The API reflects the current state of the jobs controller
17+
3. **Proper abstractions**: The API provides structured data with proper types
18+
4. **Security**: Direct database access bypasses authentication and auditing
19+
20+
## Database
21+
22+
**Database Location**: SkyDeck uses SQLite for persistent storage.
23+
24+
- **Default location**: `~/.skydeck/skydeck.db`
25+
- **Configuration**: Can be overridden with `--db-path` flag or `SKYDECK_DB_PATH` environment variable
26+
- **Schema**: Defined in `skydeck/database.py` with automatic migrations on startup
27+
28+
### Database Scripts
29+
30+
When working with the database directly:
31+
32+
```bash
33+
# Backfill checkpoint versions (example)
34+
uv run python -c "
35+
import asyncio
36+
from pathlib import Path
37+
from skydeck.backfill_versions import backfill_checkpoint_versions
38+
db_path = str(Path.home() / '.skydeck' / 'skydeck.db')
39+
asyncio.run(backfill_checkpoint_versions(db_path))
40+
"
41+
42+
# Query database directly
43+
sqlite3 ~/.skydeck/skydeck.db "SELECT COUNT(*) FROM experiments;"
44+
```
45+
46+
### Code Style
47+
48+
- Always use `uv` for pip and python operations
49+
- Imports should go at the top of the file if possible
50+
- Follow existing patterns in the codebase for consistency

projects/skydeck/discover_api.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Discover SkyPilot API endpoints."""
2+
3+
import http.cookiejar
4+
import json
5+
import urllib.request
6+
7+
from skydeck.services import ServiceEndpoints
8+
9+
10+
def load_cookies():
11+
"""Load cookies from ~/.sky/cookies.txt."""
12+
jar = http.cookiejar.MozillaCookieJar()
13+
jar.load("/Users/daveey/.sky/cookies.txt", ignore_discard=True, ignore_expires=True)
14+
return jar
15+
16+
17+
def try_endpoint(path):
18+
"""Try an API endpoint and return the response."""
19+
url = f"{ServiceEndpoints.SKYPILOT_API}{path}"
20+
jar = load_cookies()
21+
opener = urllib.request.build_opener(urllib.request.HTTPCookieProcessor(jar))
22+
23+
try:
24+
req = urllib.request.Request(url)
25+
req.add_header("Accept", "application/json")
26+
response = opener.open(req, timeout=10)
27+
data = response.read().decode("utf-8")
28+
try:
29+
parsed = json.loads(data)
30+
print(f"✓ {path}")
31+
print(f" Status: {response.status}")
32+
print(f" Response: {json.dumps(parsed, indent=2)[:200]}...")
33+
return True
34+
except json.JSONDecodeError:
35+
print(f"✗ {path} - Not JSON: {data[:100]}")
36+
return False
37+
except urllib.error.HTTPError as e:
38+
print(f"✗ {path} - HTTP {e.code}: {e.reason}")
39+
return False
40+
except Exception as e:
41+
print(f"✗ {path} - Error: {e}")
42+
return False
43+
44+
45+
if __name__ == "__main__":
46+
endpoints = [
47+
# Try different API paths
48+
"/api/jobs",
49+
"/api/v1/jobs",
50+
"/api/v2/jobs",
51+
"/api/managed_jobs",
52+
"/api/queue",
53+
"/api/status",
54+
"/jobs",
55+
"/queue",
56+
"/v1/jobs",
57+
"/v1/managed_jobs",
58+
# Try specific job
59+
"/api/jobs/9588",
60+
"/api/v1/jobs/9588",
61+
"/jobs/9588",
62+
# Try listing endpoints
63+
"/api",
64+
"/api/v1",
65+
"",
66+
]
67+
68+
print("Testing SkyPilot API endpoints...")
69+
print("=" * 60)
70+
71+
for endpoint in endpoints:
72+
try_endpoint(endpoint)
73+
print()

projects/skydeck/skydeck/app.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -259,14 +259,24 @@ async def update_experiment_starred(experiment_id: str, data: dict):
259259

260260
@app.delete("/api/experiments/{experiment_id}")
261261
async def delete_experiment(experiment_id: str):
262-
"""Delete an experiment."""
262+
"""Soft-delete an experiment."""
263263
try:
264264
await desired_state_manager.delete_experiment(experiment_id)
265265
return {"message": "Experiment deleted"}
266266
except ValueError as e:
267267
raise HTTPException(status_code=404, detail=str(e)) from e
268268

269269

270+
@app.post("/api/experiments/{experiment_id}/undelete")
271+
async def undelete_experiment(experiment_id: str):
272+
"""Restore a soft-deleted experiment."""
273+
try:
274+
await db.undelete_experiment(experiment_id)
275+
return {"message": "Experiment restored"}
276+
except Exception as e:
277+
raise HTTPException(status_code=404, detail=str(e)) from e
278+
279+
270280
@app.post("/api/experiments/{experiment_id}/state")
271281
async def update_desired_state(experiment_id: str, request: UpdateDesiredStateRequest):
272282
"""Update experiment desired state."""
@@ -284,9 +294,8 @@ async def get_experiment_status(experiment_id: str) -> ExperimentStatus:
284294
if not experiment:
285295
raise HTTPException(status_code=404, detail="Experiment not found")
286296

287-
current_job = None
288-
if experiment.current_job_id:
289-
current_job = await db.get_job(experiment.current_job_id)
297+
# Get current active job dynamically
298+
current_job = await db.get_current_job_for_experiment(experiment_id)
290299

291300
recent_jobs = await db.get_jobs_for_experiment(experiment_id, limit=10)
292301

@@ -321,7 +330,7 @@ async def update_experiment_flags(experiment_id: str, request: UpdateFlagsReques
321330
@app.get("/api/experiments/{experiment_id}/checkpoints")
322331
async def get_experiment_checkpoints(experiment_id: str, limit: int = 50):
323332
"""Get checkpoints for an experiment."""
324-
from urllib.parse import quote
333+
from .services import ObservatoryService
325334

326335
experiment = await desired_state_manager.get_experiment(experiment_id)
327336
if not experiment:
@@ -332,7 +341,7 @@ async def get_experiment_checkpoints(experiment_id: str, limit: int = 50):
332341
# Enrich checkpoints with Observatory URLs
333342
for checkpoint in checkpoints:
334343
policy_name = f"{experiment_id}.{checkpoint.epoch}"
335-
checkpoint.observatory_url = f"https://observatory.softmax-research.net/policy/{quote(policy_name, safe='')}"
344+
checkpoint.observatory_url = ObservatoryService.get_policy_web_url(policy_name)
336345

337346
return {"checkpoints": checkpoints}
338347

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Script to backfill policy versions for all checkpoints in the database."""
2+
3+
import asyncio
4+
import logging
5+
from datetime import datetime
6+
7+
from .database import Database
8+
from .services import ObservatoryService
9+
10+
logging.basicConfig(level=logging.INFO)
11+
logger = logging.getLogger(__name__)
12+
13+
14+
async def backfill_checkpoint_versions(db_path: str = "skydeck.db"):
15+
"""Backfill policy versions for all checkpoints missing them.
16+
17+
Args:
18+
db_path: Path to the SQLite database
19+
"""
20+
db = Database(db_path)
21+
await db.connect()
22+
23+
try:
24+
# Get all experiments
25+
experiments = await db.get_all_experiments()
26+
logger.info(f"Found {len(experiments)} experiments")
27+
28+
total_checkpoints = 0
29+
updated_checkpoints = 0
30+
31+
for exp in experiments:
32+
logger.info(f"Processing experiment: {exp.id}")
33+
34+
# Get all checkpoints for this experiment (no limit)
35+
cursor = await db._conn.execute(
36+
"""
37+
SELECT * FROM checkpoints
38+
WHERE experiment_id = ?
39+
ORDER BY epoch DESC
40+
""",
41+
(exp.id,),
42+
)
43+
rows = await cursor.fetchall()
44+
checkpoints = [db._row_to_checkpoint(row) for row in rows]
45+
46+
total_checkpoints += len(checkpoints)
47+
logger.info(f" Found {len(checkpoints)} checkpoints")
48+
49+
# Filter checkpoints that need backfill
50+
needs_backfill = [cp for cp in checkpoints if not cp.policy_version]
51+
52+
if not needs_backfill:
53+
logger.info(" All checkpoints have versions, skipping")
54+
continue
55+
56+
logger.info(f" {len(needs_backfill)} checkpoints need version backfill")
57+
58+
# Fetch policy version once per experiment
59+
policy_version = ObservatoryService.fetch_policy_version(exp.id)
60+
observatory_url = ObservatoryService.get_policy_api_url(exp.id, limit=500)
61+
62+
if policy_version:
63+
logger.info(f" Found policy version: {policy_version}")
64+
else:
65+
logger.warning(f" Could not fetch policy version for {exp.id}")
66+
67+
# Update all checkpoints that need backfill
68+
for cp in needs_backfill:
69+
# Update the checkpoint fields
70+
cp.policy_version = policy_version
71+
cp.observatory_url = observatory_url
72+
cp.synced_at = datetime.utcnow()
73+
74+
# Save the updated checkpoint
75+
await db.save_checkpoint(cp)
76+
updated_checkpoints += 1
77+
78+
logger.info(f" Updated {len(needs_backfill)} checkpoints")
79+
80+
logger.info("\nBackfill complete:")
81+
logger.info(f" Total checkpoints: {total_checkpoints}")
82+
logger.info(f" Updated checkpoints: {updated_checkpoints}")
83+
84+
finally:
85+
await db.close()
86+
87+
88+
async def main():
89+
"""Main entry point."""
90+
await backfill_checkpoint_versions()
91+
92+
93+
if __name__ == "__main__":
94+
asyncio.run(main())

0 commit comments

Comments
 (0)