diff --git a/vast.py b/vast.py index bda6c8d0..bbf3b04b 100755 --- a/vast.py +++ b/vast.py @@ -29,6 +29,7 @@ import textwrap from pathlib import Path import warnings +import platform ARGS = None TABCOMPLETE = False @@ -1464,19 +1465,118 @@ def create__env_var(args): else: print(f"Failed to create environment variable: {result.get('msg', 'Unknown error')}") +def get_ssh_key_paths(): + if platform.system() == "Windows": + base = os.environ["USERPROFILE"] + else: + base = os.path.expanduser("~") + key_path = os.path.join(base, ".ssh", "id_rsa") + pub_key_path = key_path + ".pub" + return key_path, pub_key_path + + +def ensure_ssh_key_exists(key_path, pub_key_path): + if os.path.exists(pub_key_path): + print("SSH key pair already exists, using the existing key.") + return True + + print("Generating a new RSA SSH key pair...") + try: + subprocess.run(["ssh-keygen", "-t", "rsa", "-f", key_path, "-q", "-N", ""], check=True) + return True + except FileNotFoundError: + return False + except subprocess.CalledProcessError: + print("Error occurred while generating the SSH key. Please check your setup or generate manually.") + return False + + +def install_openssh_client_linux(): + try: + print("Attempting to install OpenSSH client on Linux...") + subprocess.run(["sudo", "apt-get", "update"], check=True) + subprocess.run(["sudo", "apt-get", "install", "-y", "openssh-client"], check=True) + return True + except subprocess.CalledProcessError: + print("Failed to install openssh-client. Please install it manually.") + return False + + +def add_key_to_ssh_agent(key_path): + try: + subprocess.run(["ssh-add", key_path], check=True) + subprocess.run(["ssh-add", "-l"], check=True) + return True + except subprocess.CalledProcessError: + print("Unable to add SSH key to the agent. Make sure the ssh-agent is running.") + return False + + +def read_public_key(pub_key_path): + try: + with open(pub_key_path, "r") as f: + return f.read().strip() + except OSError: + print(f"Unable to read the public key from {pub_key_path}.") + return None + + +def generate_ssh_key_pair(): + system = platform.system() + key_path, pub_key_path = get_ssh_key_paths() + + os.makedirs(os.path.dirname(key_path), exist_ok=True) + + if system == "Windows": + if not ensure_ssh_key_exists(key_path, pub_key_path): + print("'ssh-keygen' not found. Please install the OpenSSH Client and ensure it's on your PATH.") + return None + return read_public_key(pub_key_path) + + elif system in ["Linux", "Darwin"]: + if not ensure_ssh_key_exists(key_path, pub_key_path): + if system == "Linux": + if install_openssh_client_linux(): + if not ensure_ssh_key_exists(key_path, pub_key_path): + return None + else: + return None + else: + print("'ssh-keygen' not found. Please install it manually.") + return None + + add_key_to_ssh_agent(key_path) + return read_public_key(pub_key_path) + + else: + print("Unsupported platform. Only Linux, macOS, and Windows are supported.") + return None + @parser.command( - argument("ssh_key", help="add the public key of your ssh key to your account (form the .pub file)", type=str), + argument("ssh_key", help="add the public key of your ssh key to your account (or use 'auto' to generate one automatically on Linux/MacOS)", type=str, default="auto"), usage="vastai create ssh-key ssh_key", help="Create a new ssh-key", epilog=deindent(""" - Use this command to create a new ssh key for your account. - All ssh keys are stored in your account and can be used to connect to instances they've been added to - All ssh keys should be added in rsa format + Use this command to create a new ssh key for your account. + All ssh keys are stored in your account and can be used to connect to instances they've been added to. + All ssh keys should be added in RSA format. + + Quickstart (for Linux/MacOS with 'auto'): + 1. Generates an RSA key pair (if needed) using: ssh-keygen -t rsa + 2. Loads the key into the SSH agent using: ssh-add; ssh-add -l + 3. Reads your public key from ~/.ssh/id_rsa.pub and uses it for the account. """) ) def create__ssh_key(args): + if args.ssh_key == "auto": + public_key = generate_ssh_key_pair() + if public_key is None: + return + else: + public_key = args.ssh_key + url = apiurl(args, "/ssh/") - r = http_post(args, url, headers=headers, json={"ssh_key": args.ssh_key}) + r = http_post(args, url, headers=headers, json={"ssh_key": public_key}) r.raise_for_status() print("ssh-key created {}".format(r.json()))