Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 121 additions & 4 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use anyhow::{bail, Context};
use std::fs;
use std::io::Read;
use std::path::PathBuf;
use std::process::Command;

use tlparse::{
// New reusable library API for multi-rank landing generation
Expand All @@ -15,10 +16,11 @@ use tlparse::{
};

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(author, version, about = "Parse TORCH_LOG logs produced by PyTorch torch.compile\n\nUsage modes:\n tlparse <path> Parse existing trace logs\n tlparse run -- <command> Run command with tracing enabled", long_about = None)]
#[command(propagate_version = true)]
pub struct Cli {
path: PathBuf,
/// Input path (log file or directory)
path: Option<PathBuf>,
/// Parse most recent log
#[arg(long)]
latest: bool,
Expand Down Expand Up @@ -67,8 +69,59 @@ pub struct Cli {
}

fn main() -> anyhow::Result<()> {
// Check if this is the "run" subcommand before clap parses
// This is needed because clap doesn't handle "run -- cmd args" well with other positional args
let raw_args: Vec<String> = std::env::args().collect();

// Find the position of "run" if present
if let Some(run_pos) = raw_args.iter().position(|s| s == "run") {
// Only treat as run subcommand if it's before any "--" separator
let dash_dash_pos = raw_args.iter().position(|s| s == "--");
if dash_dash_pos.map_or(true, |dd| run_pos < dd) {
// Extract everything after "run", filtering out leading "--" if present
let rest: Vec<String> = raw_args.iter().skip(run_pos + 1).cloned().collect();
let command: Vec<String> = if rest.first().map(|s| s.as_str()) == Some("--") {
rest.into_iter().skip(1).collect()
} else {
rest
};

// Parse options that come before "run"
let pre_run_args: Vec<String> = raw_args[..run_pos].to_vec();
let cli = match Cli::try_parse_from(pre_run_args) {
Ok(c) => c,
Err(_) => {
// Fall back to defaults if parsing fails
Cli {
path: None,
latest: false,
out: PathBuf::from("tl_out"),
overwrite: false,
strict: false,
strict_compile_id: false,
no_browser: false,
custom_header_html: String::new(),
verbose: false,
plain_text: false,
export: false,
inductor_provenance: false,
all_ranks_html: false,
serve: false,
port: None,
}
}
};
return handle_run_command(command, &cli);
}
}

let cli = Cli::parse();

// Default behavior: parse an existing log file/directory
let Some(ref input_path) = cli.path else {
bail!("No input path provided.\n\nUsage:\n tlparse <path> Parse existing trace logs\n tlparse run -- <command> Run command with tracing enabled");
};

// Early validation of incompatible flags
if cli.all_ranks_html && cli.latest {
bail!("--latest cannot be used with --all-ranks-html");
Expand All @@ -78,7 +131,6 @@ fn main() -> anyhow::Result<()> {
let open_browser = !cli.no_browser && !cli.serve;

let path = if cli.latest {
let input_path = cli.path;
// Path should be a directory
if !input_path.is_dir() {
bail!(
Expand All @@ -98,7 +150,7 @@ fn main() -> anyhow::Result<()> {
};
last_modified_file.path()
} else {
cli.path
input_path.clone()
};

let config = ParseConfig {
Expand Down Expand Up @@ -132,6 +184,71 @@ fn main() -> anyhow::Result<()> {
Ok(())
}

/// Handle the `run` subcommand: execute a command with TORCH_TRACE set and parse the output
fn handle_run_command(command: Vec<String>, cli: &Cli) -> anyhow::Result<()> {
if command.is_empty() {
bail!("No command provided. Usage: tlparse run -- <command> [args...]");
}

// Create a temporary directory for traces
let trace_dir = std::env::temp_dir().join(format!("tlparse_trace_{}", std::process::id()));
fs::create_dir_all(&trace_dir)?;

println!("Running command with TORCH_TRACE={}", trace_dir.display());

// Build and run the command
let mut child = Command::new(&command[0])
.args(&command[1..])
.env("TORCH_TRACE", &trace_dir)
.spawn()
.with_context(|| format!("Failed to execute command: {}", command[0]))?;

let status = child.wait()?;

if !status.success() {
eprintln!(
"Command exited with status: {}",
status.code().unwrap_or(-1)
);
}

// Check if any traces were generated
let has_traces = fs::read_dir(&trace_dir)?
.flatten()
.any(|e| e.path().is_file());

if !has_traces {
// Clean up empty trace directory
let _ = fs::remove_dir_all(&trace_dir);
bail!("No trace files were generated. Make sure your PyTorch code triggers compilation (e.g., uses torch.compile).");
}

println!("Parsing traces from {}", trace_dir.display());

// --serve implies --no-browser
let open_browser = !cli.no_browser && !cli.serve;

let config = ParseConfig {
strict: cli.strict,
strict_compile_id: cli.strict_compile_id,
custom_parsers: Vec::new(),
custom_header_html: cli.custom_header_html.clone(),
verbose: cli.verbose,
plain_text: cli.plain_text,
export: cli.export,
inductor_provenance: cli.inductor_provenance,
};

// Use --latest to parse the most recent trace file in the directory
handle_one_rank(&config, trace_dir, true, cli.out.clone(), open_browser, cli.overwrite)?;

if cli.serve {
serve_directory(&cli.out, cli.port)?;
}

Ok(())
}

/// Create the output directory
fn setup_output_directory(out_path: &PathBuf, overwrite: bool) -> anyhow::Result<()> {
if out_path.exists() {
Expand Down
Loading