diff --git a/main.go b/main.go index db5fdc2..5c9ea1c 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,9 @@ import ( "log" "os" "path/filepath" + "regexp" "runtime" + "sort" "strings" "github.com/tatsushid/minssh/pkg/minssh" @@ -25,6 +27,51 @@ var defaultIdentityFiles = []string{ "id_ed25519", } +func getValidOptions() (validOptions map[string]map[string]string) { + validOptions = make(map[string]map[string]string) + validOptions["StrictHostKeyChecking"] = map[string]string{ + "valids": "yes or no", + "default": "yes", + } + validOptions["Password"] = map[string]string{ + "valids": "any string", + "default": "no password", + } + return +} + +func getOptionData() ( + validOptions map[string]map[string]string, + validOptionKeys []string, + optionsMsg string) { + optionsMsg = "Use `option` to specify options for which there is no separate\n" + optionsMsg += " command-line flag. This can be called multiple times.\n" + optionsMsg += "\tValid options:\n" + validOptions = getValidOptions() + validOptionKeys = make([]string, 0, len(validOptions)) + for key := range validOptions { + validOptionKeys = append(validOptionKeys, key) + } + sort.Strings(validOptionKeys) + for keyIndex := range validOptionKeys { + key := validOptionKeys[keyIndex] + optionsMsg += fmt.Sprintf("\t%s=%s, default is %s\n", + key, + validOptions[key]["valids"], + validOptions[key]["default"]) + } + return +} + +func appendValidsMsg(valids string, + option string, + msgIn string) (msgOut string) { + msgOut = msgIn + msgOut += " Valid values:\n" + msgOut += " " + valids + "\n" + return +} + type strSliceValue []string func (v *strSliceValue) Set(s string) error { @@ -93,13 +140,18 @@ func (a *app) initApp() (err error) { func (a *app) parseArgs() (err error) { var ( + options []string logPath string useOpenSSHFiles bool showVersion bool ) + validOptions, validOptionKeys, optionsMsg := getOptionData() + a.flagSet.Var((*strSliceValue)(&a.conf.IdentityFiles), "i", "use `identity_file` for public key authentication. this can be called multiple times") + a.flagSet.Var((*strSliceValue)(&options), "o", optionsMsg) a.flagSet.IntVar(&a.conf.Port, "p", 22, "specify ssh server `port`") + a.flagSet.BoolVar(&a.conf.QuietMode, "q", false, "Quiet mode. Suppresses most warning and diagnostic messages, default is false.") a.flagSet.BoolVar(&a.conf.IsSubsystem, "s", false, "treat command as subsystem") a.flagSet.StringVar(&logPath, "E", "", "specify `log_file` path. if it isn't set, it discards all log outputs") a.flagSet.BoolVar(&useOpenSSHFiles, "U", false, "use keys and known_hosts files in OpenSSH's '.ssh' directory") @@ -129,6 +181,59 @@ func (a *app) parseArgs() (err error) { } } + errorMsg := "" + for _, option := range options { + var rex = regexp.MustCompile("(\\w+)=(.*)") + data := rex.FindAllStringSubmatch(option, -1) + if len(data) > 0 { + for _, keyVal := range data { + key := keyVal[1] + val := keyVal[2] + switch key { + case "StrictHostKeyChecking": + switch val { + case "yes": + a.conf.StrictHostKeyChecking = true + case "no": + a.conf.StrictHostKeyChecking = false + default: + invalidValMsg := "Option %s has invalid value: %s\n" + invalidValMsg = appendValidsMsg( + validOptions[key]["valids"], + key, + invalidValMsg) + errorMsg += fmt.Sprintf(invalidValMsg, key, val) + } + case "Password": + a.conf.PromptUserForPassword = false + a.conf.Password = val + default: + invalidOptionMsg := "Unknown option: %s\n" + validsString := "" + for keyIndex := range validOptionKeys { + key := validOptionKeys[keyIndex] + if keyIndex > 0 { + validsString += ", " + } + validsString += key + } + invalidOptionMsg = appendValidsMsg( + validsString, + "options", + invalidOptionMsg) + errorMsg += fmt.Sprintf(invalidOptionMsg, key) + } + } + } else { + invalidSyntaxMsg := "Option %s has invalid syntax\n" + invalidSyntaxMsg += " Please specify an option as a key=value pair\n" + errorMsg += fmt.Sprintf(invalidSyntaxMsg, option) + } + } + if len(errorMsg) > 0 { + return fmt.Errorf(errorMsg) + } + if useOpenSSHFiles { for _, f := range defaultKnownHostsFiles { f = filepath.Join(a.homeDir, ".ssh", f) diff --git a/pkg/minssh/config.go b/pkg/minssh/config.go index 4317cdd..3aa31ff 100644 --- a/pkg/minssh/config.go +++ b/pkg/minssh/config.go @@ -7,15 +7,19 @@ import ( ) type Config struct { - User string - Host string - Port int - Logger *log.Logger - KnownHostsFiles []string - IdentityFiles []string - Command string - IsSubsystem bool - NoTTY bool + User string + PromptUserForPassword bool + Password string + Host string + Port int + Logger *log.Logger + StrictHostKeyChecking bool + KnownHostsFiles []string + IdentityFiles []string + Command string + QuietMode bool + IsSubsystem bool + NoTTY bool } func NewConfig() *Config { @@ -24,6 +28,8 @@ func NewConfig() *Config { Host: "", Port: 22, Logger: log.New(ioutil.Discard, "minssh ", log.LstdFlags), + PromptUserForPassword: true, + StrictHostKeyChecking: true, } } diff --git a/pkg/minssh/minssh.go b/pkg/minssh/minssh.go index 171e23d..d5fea36 100644 --- a/pkg/minssh/minssh.go +++ b/pkg/minssh/minssh.go @@ -52,42 +52,80 @@ func IsTerminal() (bool, error) { return true, nil } -func readPassword(ttyin, ttyout *os.File, prompt string) (password string, err error) { - state, err := terminal.GetState(int(ttyin.Fd())) - if err != nil { - return "", fmt.Errorf("failed to get terminal state: %s", err) +func isStdinValid() (isValid bool) { + stat, _ := os.Stdin.Stat() + if stat != nil { + isValid = true + } else { + isValid = false } + return +} - stopC := make(chan struct{}) - defer func() { - close(stopC) - }() - - go func() { - sigC := make(chan os.Signal, 1) - signal.Notify(sigC, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - select { - case <-sigC: - terminal.Restore(int(ttyin.Fd()), state) - os.Exit(1) - case <-stopC: +func directedPrintf(quietMode bool, + ttyout *os.File, + stringToPrint string) () { + if !quietMode { + if ttyout != nil { + fmt.Fprintf(ttyout, stringToPrint) + } else { + // ttyout is not open; + // so just call Printf here + fmt.Printf(stringToPrint) } - }() - - if prompt == "" { - fmt.Fprint(ttyout, "Password: ") } else { - fmt.Fprint(ttyout, prompt) + // do nothing here, since suppressing prints } +} - b, err := terminal.ReadPassword(int(ttyin.Fd())) - if err != nil { - return "", fmt.Errorf("failed to read password: %s", err) - } +func readPassword(ms *MinSSH, ttyin, ttyout *os.File, prompt string) (password string, err error) { + pwd := "" + if ms.conf.PromptUserForPassword { + state, err := terminal.GetState(int(ttyin.Fd())) + if err != nil { + return "", fmt.Errorf("failed to get terminal state: %s", err) + } - fmt.Fprint(ttyout, "\n") + stopC := make(chan struct{}) + defer func() { + close(stopC) + }() - return string(b), nil + go func() { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + select { + case <-sigC: + terminal.Restore(int(ttyin.Fd()), state) + os.Exit(1) + case <-stopC: + } + }() + + if prompt == "" { + directedPrintf(ms.conf.QuietMode, ttyout, "Password: ") + } else { + directedPrintf(ms.conf.QuietMode, ttyout, prompt) + } + + b, err := terminal.ReadPassword(int(ttyin.Fd())) + pwd = string(b) + if err != nil { + return "", fmt.Errorf("failed to read password: %s", err) + } + + directedPrintf(ms.conf.QuietMode, ttyout, "\n") + + } else { + // use password from options + pwd = ms.conf.Password + directedPrintf( + ms.conf.QuietMode, + ttyout, + "Using password from options\n") + } + + return pwd, nil } func askAddingUnknownHostKey(address string, remote net.Addr, key ssh.PublicKey) (bool, error) { @@ -207,12 +245,14 @@ func (ms *MinSSH) verifyAndAppendNew(hostname string, remote net.Addr, key ssh.P return err } - if answer, err := askAddingUnknownHostKey(hostname, remote, key); err != nil || !answer { - msg := "host key verification failed" - if err != nil { - msg += ": " + err.Error() + if ms.conf.StrictHostKeyChecking { + if answer, err := askAddingUnknownHostKey(hostname, remote, key); err != nil || !answer { + msg := "host key verification failed" + if err != nil { + msg += ": " + err.Error() + } + return fmt.Errorf(msg) } - return fmt.Errorf(msg) } f, err := os.OpenFile(ms.conf.KnownHostsFiles[0], os.O_WRONLY|os.O_APPEND, 0600) @@ -237,12 +277,17 @@ func (ms *MinSSH) verifyAndAppendNew(hostname string, remote net.Addr, key ssh.P } func (ms *MinSSH) getSigners() (signers []ssh.Signer, err error) { - ttyin, ttyout, err := openTTY() - if err != nil { - return signers, fmt.Errorf("failed to open tty: %s", err) + ttyin := (*os.File)(nil) + ttyout := (*os.File)(nil) + err = (error)(nil) + if ms.conf.PromptUserForPassword { + ttyin, ttyout, err = openTTY() + if err != nil { + return signers, fmt.Errorf("failed to open tty: %s", err) + } + defer closeTTY(ttyin, ttyout) } - defer closeTTY(ttyin, ttyout) - + for _, identityFile := range ms.conf.IdentityFiles { identityFile = os.ExpandEnv(identityFile) key, err := ioutil.ReadFile(identityFile) @@ -260,7 +305,7 @@ func (ms *MinSSH) getSigners() (signers []ssh.Signer, err error) { } continue } - password, err := readPassword(ttyin, ttyout, "password for decrypting key: ") + password, err := readPassword(ms, ttyin, ttyout, "password for decrypting key: ") if err != nil { ms.conf.Logger.Printf("failed to decrypt private key: %s\n", err) continue @@ -285,12 +330,17 @@ func (ms *MinSSH) getSigners() (signers []ssh.Signer, err error) { } func (ms *MinSSH) keyboardInteractiveChallenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { - ttyin, ttyout, err := openTTY() - if err != nil { - return answers, fmt.Errorf("failed to open tty: %s", err) + ttyin := (*os.File)(nil) + ttyout := (*os.File)(nil) + err = (error)(nil) + if ms.conf.PromptUserForPassword { + ttyin, ttyout, err = openTTY() + if err != nil { + return answers, fmt.Errorf("failed to open tty: %s", err) + } + defer closeTTY(ttyin, ttyout) } - defer closeTTY(ttyin, ttyout) - + answers = make([]string, len(questions)) var strs []string if len(questions) > 0 { @@ -301,13 +351,23 @@ func (ms *MinSSH) keyboardInteractiveChallenge(user, instruction string, questio strs = append(strs) } if len(strs) > 0 { - fmt.Fprintln(ttyout, strings.Join(strs, " ")) + directedPrintf( + ms.conf.QuietMode, + ttyout, + strings.Join(strs, " ") + "\n") } else { - fmt.Fprintf(ttyout, "Keyboard interactive challenge for %s@%s\n", ms.conf.User, ms.conf.Host) + stringToPrint := fmt.Sprintf( + "Keyboard interactive challenge for %s@%s\n", + ms.conf.User, + ms.conf.Host) + directedPrintf( + ms.conf.QuietMode, + ttyout, + stringToPrint) } } for i, q := range questions { - res, err := readPassword(ttyin, ttyout, q) + res, err := readPassword(ms, ttyin, ttyout, q) if err != nil { return answers, err } @@ -317,14 +377,26 @@ func (ms *MinSSH) keyboardInteractiveChallenge(user, instruction string, questio } func (ms *MinSSH) passwordCallback() (secret string, err error) { - ttyin, ttyout, err := openTTY() - if err != nil { - return secret, fmt.Errorf("failed to open tty: %s", err) - } - defer closeTTY(ttyin, ttyout) - - fmt.Fprintf(ttyout, "Password authentication for %s@%s\n", ms.conf.User, ms.conf.Host) - return readPassword(ttyin, ttyout, "Password: ") + ttyin := (*os.File)(nil) + ttyout := (*os.File)(nil) + err = (error)(nil) + if ms.conf.PromptUserForPassword { + ttyin, ttyout, err = openTTY() + if err != nil { + return secret, fmt.Errorf("failed to open tty: %s", err) + } + defer closeTTY(ttyin, ttyout) + } + + stringToPrint := fmt.Sprintf( + "Password authentication for %s@%s\n", + ms.conf.User, + ms.conf.Host) + directedPrintf( + ms.conf.QuietMode, + ttyout, + stringToPrint) + return readPassword(ms, ttyin, ttyout, "Password: ") } func (ms *MinSSH) Close() { @@ -478,18 +550,32 @@ func (ms *MinSSH) invokeInOutPipes() { } func (ms *MinSSH) printExitMessage(err error) { - fmt.Printf("ssh connection to %s closed ", ms.conf.Host) + stringToPrint := fmt.Sprintf( + "ssh connection to %s closed ", + ms.conf.Host) if err != nil { switch e := err.(type) { case *ssh.ExitMissingError: - fmt.Printf("but remote didn't send exit status: %s\n", e) + stringToPrint += fmt.Sprintf( + "but remote didn't send exit status: %s\n", e) case *ssh.ExitError: - fmt.Printf("with error: %s\n", e) + stringToPrint += fmt.Sprintf( + "with error: %s\n", e) default: - fmt.Printf("with unknown error: %s\n", err) + stringToPrint += fmt.Sprintf( + "with unknown error: %s\n", err) } + // always print errors, so call Printf here + fmt.Printf(stringToPrint) } else { - fmt.Println("successfully") + stringToPrint += "successfully\n" + // no tty object here, so pass in nil + ttyout := (*os.File)(nil) + // conditionally print success messages + directedPrintf( + ms.conf.QuietMode, + ttyout, + stringToPrint) } } @@ -503,7 +589,15 @@ func (ms *MinSSH) Run() (err error) { } func (ms *MinSSH) RunCommand() error { - ms.sess.Stdin = os.Stdin + if isStdinValid() { + ms.sess.Stdin = os.Stdin + } else { + // if stdin is not valid, + // pass nil to sess; + // this avoids sess returning an + // invalid handle error + ms.sess.Stdin = nil + } ms.sess.Stdout = os.Stdout ms.sess.Stderr = os.Stderr @@ -528,7 +622,15 @@ func (ms *MinSSH) RunCommand() error { } func (ms *MinSSH) RunSubsystem() error { - ms.sess.Stdin = os.Stdin + if isStdinValid() { + ms.sess.Stdin = os.Stdin + } else { + // if stdin is not valid, + // pass nil to sess; + // this avoids sess returning an + // invalid handle error + ms.sess.Stdin = nil + } ms.sess.Stdout = os.Stdout ms.sess.Stderr = os.Stderr