From 763e10d9471310262e07b93dae3feb144a567634 Mon Sep 17 00:00:00 2001 From: Jussi Maki Date: Mon, 6 Oct 2025 11:13:43 +0200 Subject: [PATCH 1/3] shell: Port shell package from cilium/cilium Port over the shell server and client from cilium/cilium so it is easier to use in other projects. Signed-off-by: Jussi Maki --- shell/client.go | 242 ++++++++++++++++++++++++++++++++++++++++++++ shell/config.go | 28 +++++ shell/const.go | 15 +++ shell/server.go | 200 ++++++++++++++++++++++++++++++++++++ shell/shell_test.go | 92 +++++++++++++++++ 5 files changed, 577 insertions(+) create mode 100644 shell/client.go create mode 100644 shell/config.go create mode 100644 shell/const.go create mode 100644 shell/server.go create mode 100644 shell/shell_test.go diff --git a/shell/client.go b/shell/client.go new file mode 100644 index 0000000..a79251c --- /dev/null +++ b/shell/client.go @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package shell + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/signal" + "strings" + "sync/atomic" + "time" + + "github.com/cilium/hive/script" + "github.com/spf13/cobra" + "golang.org/x/term" +) + +// ShellCmd constructs a cobra command for dialing a shell server. +func ShellCmd(defaultSockPath string, prompt string, printGreeting func(w io.Writer)) *cobra.Command { + var sockPath *string + cmd := &cobra.Command{ + Use: "shell [command] [args]...", + Short: "Connect to the shell", + Run: func(cmd *cobra.Command, args []string) { + cfg := Config{ + ShellSockPath: *sockPath, + } + executeShell(cfg, prompt, printGreeting, args) + }, + } + sockPath = cmd.Flags().String(ShellSockPathName, defaultSockPath, "Path to the shell UNIX socket") + return cmd +} + +var stdReadWriter = struct { + io.Reader + io.Writer +}{ + Reader: os.Stdin, + Writer: os.Stdout, +} + +func dialShell(c Config, sigs <-chan os.Signal, w io.Writer) (net.Conn, error) { + var conn net.Conn + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + for { + var err error + var d net.Dialer + conn, err = d.DialContext(ctx, "unix", c.ShellSockPath) + if err == nil { + break + } + // Dialing failed. Server might not be fully up yet. Wait a bit and retry. + select { + case <-sigs: + return nil, fmt.Errorf("interrupted") + case <-ctx.Done(): + return nil, fmt.Errorf("dialing timed out: %w", err) + case <-time.After(time.Second): + fmt.Fprintf(w, "Dialing failed: %s. Retrying...\n", err) + } + } + return conn, nil +} + +// ShellExchange sends a single command to the shell. Output is written +// to the given writer [w]. +func ShellExchange(c Config, w io.Writer, format string, args ...any) error { + conn, err := dialShell(c, nil, os.Stderr) + if err != nil { + return err + } + defer conn.Close() + _, err = fmt.Fprintf(conn, format+"\nexit\n", args...) + if err != nil { + return err + } + bio := bufio.NewReader(conn) + for { + lineBytes, isPrefix, err := bio.ReadLine() + if err != nil { + return nil + } + line := string(lineBytes) + if line == stdoutMarker || line == stderrMarker { + // Commands that write to "stdout" instead of the log show the [stdout] as + // the first line. This is useful information in tests, but not useful in + // the shell, so just skip this. + continue + } + line, ended := strings.CutSuffix(line, endMarker) + if isPrefix { + // Partial line, don't print \n yet. + _, err = fmt.Fprint(w, line) + } else { + _, err = fmt.Fprintln(w, line) + } + if err != nil { + return err + } + if ended { + return nil + } + } +} + +func executeShell(cfg Config, prompt string, printGreeting func(io.Writer), args []string) { + if len(args) > 0 { + err := ShellExchange(cfg, os.Stdout, "%s", strings.Join(args, " ")) + if err != nil { + fmt.Fprintf(os.Stdout, "error: %s\n", err) + } + } else { + os.Exit(interactiveShell(cfg, prompt, printGreeting)) + } +} + +func interactiveShell(cfg Config, prompt string, printGreeting func(w io.Writer)) int { + // Try to set the terminal to raw mode (so that cursor keys work etc.) + restore, err := script.MakeRaw(0) + if err != nil { + fmt.Fprintf(os.Stderr, "Error setting terminal to raw mode: %s\n", err) + } else { + defer restore() + } + + console := term.NewTerminal(stdReadWriter, prompt) + if width, height, err := term.GetSize(0); err == nil { + console.SetSize(width, height) + } + if printGreeting != nil { + printGreeting(console) + } + + // Listen for SIGINT to stop. + sigs := make(chan os.Signal, 1) + defer func() { + signal.Stop(sigs) + close(sigs) + }() + signal.Notify(sigs, os.Interrupt) + + // Try to dial the shell.sock. Since it takes a moment for the server to come up and this + // is meant for interactive use we'll try to be helpful and retry the dialing until + // server comes up. + conn, err := dialShell(cfg, sigs, console) + if err != nil { + fmt.Fprintf(console, "Error dialing: %s\n", err) + return 1 + } + + // Use a boolean to decide whether to redial the connection on error or whether to stop. + // This allows interrupting a long-running command with Ctrl-C and dropping back to + // the prompt. + var redial atomic.Bool + + go func() { + for range sigs { + // Ask for a redial and close the connection + redial.Store(true) + conn.Close() + } + }() + + bio := bufio.NewReader(conn) + + // Read commands from the console and send them to the server for execution. +repl: + for { + line, err := console.ReadLine() + if err != nil { + break + } + + // Send the command to the server. + if _, err = fmt.Fprintln(conn, line); err != nil { + // Failed to send. See if should try reconnecting or whether we should + // print the error and stop. + if redial.Load() { + redial.Store(false) + conn, err = dialShell(cfg, sigs, console) + if err != nil { + fmt.Fprintf(console, "Error dialing: %s\n", err) + return 1 + } + bio = bufio.NewReader(conn) + + // Try again with the new connection. + if _, err = fmt.Fprintln(conn, line); err != nil { + fmt.Fprintf(console, "Error sending: %s\n", err) + break repl + } + } else { + fmt.Fprintf(console, "Error: %s\n", err) + break repl + } + } + + // Pipe the response to the console until a line ends with the + // [endMarker]. + for { + lineBytes, isPrefix, err := bio.ReadLine() + if err != nil { + if redial.Load() { + // Redialing requested, drop back to prompt. + continue repl + } + if !errors.Is(err, io.EOF) { + fmt.Fprintf(console, "Error reading: %s\n", err) + } + break repl + } + line := string(lineBytes) + + if line == "[stdout]" || line == "[stderr]" { + // Commands that write to "stdout" instead of the log show the [stdout] as + // the first line. This is useful information in tests, but not useful in + // the shell, so just skip this. + continue + } + + line, ended := strings.CutSuffix(line, endMarker) + if isPrefix { + fmt.Fprint(console, line) + } else { + fmt.Fprintln(console, line) + } + if ended { + break + } + } + } + conn.Close() + return 0 +} diff --git a/shell/config.go b/shell/config.go new file mode 100644 index 0000000..3ef0f9e --- /dev/null +++ b/shell/config.go @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package shell + +import ( + "github.com/spf13/pflag" +) + +const ShellSockPathName = "shell-sock-path" + +var DefaultConfig = Config{ShellSockPath: ""} + +// Config is the configuration for the shell server. +type Config struct { + ShellSockPath string +} + +// Flags adds flags for Config when running the shell server Cell. +func (def Config) Flags(flags *pflag.FlagSet) { + flags.String(ShellSockPathName, def.ShellSockPath, "Path to the shell UNIX socket") +} + +// Parse the config from the flags. +func (cfg *Config) Parse(flags *pflag.FlagSet) (err error) { + cfg.ShellSockPath, err = flags.GetString(ShellSockPathName) + return err +} diff --git a/shell/const.go b/shell/const.go new file mode 100644 index 0000000..595298f --- /dev/null +++ b/shell/const.go @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package shell + +const ( + // endMarker marks the end of output from a command. + endMarker = "<>" + + // stdoutMarker marks the output to be from the stdout buffer. + stdoutMarker = "[stdout]" + + // stderrMarker marks the output to be from the stderr buffer. + stderrMarker = "[stderr]" +) diff --git a/shell/server.go b/shell/server.go new file mode 100644 index 0000000..48a4cb2 --- /dev/null +++ b/shell/server.go @@ -0,0 +1,200 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package shell + +import ( + "bufio" + "context" + "errors" + "fmt" + "log/slog" + "net" + "os" + "runtime" + "sync" + + "github.com/cilium/hive" + "github.com/cilium/hive/cell" + "github.com/cilium/hive/job" + "github.com/cilium/hive/script" +) + +func ServerCell(defaultSocketPath string) cell.Cell { + return cell.Module( + "shell", + "Hive debug shell", + + cell.Config(Config{ShellSockPath: defaultSocketPath}), + cell.Invoke(registerShell), + ) +} + +// defaultCmdsToInclude specify which default script commands to include. +// Most of them are for testing, so no need to clutter the shell +// with them. +var defaultCmdsToInclude = []string{ + "cat", "exec", "help", +} + +func registerShell(in hive.ScriptCmds, log *slog.Logger, lc cell.Lifecycle, jobs job.Registry, health cell.Health, c Config) { + jg := jobs.NewGroup(health, lc) + + if c.ShellSockPath == "" { + log.Info("Shell socket path not set, not starting shell server") + return + } + + cmds := in.Map() + defCmds := script.DefaultCmds() + for _, name := range defaultCmdsToInclude { + cmds[name] = defCmds[name] + } + e := script.Engine{ + Cmds: cmds, + Conds: nil, + } + jg.Add(job.OneShot("listener", shell{jg, log, &e, c}.listener)) +} + +type shell struct { + jg job.Group + log *slog.Logger + engine *script.Engine + config Config +} + +func (sh shell) listener(ctx context.Context, health cell.Health) error { + // Remove any old UNIX sock file from previous runs. + os.Remove(sh.config.ShellSockPath) + + var lc net.ListenConfig + l, err := lc.Listen(ctx, "unix", sh.config.ShellSockPath) + if err != nil { + return fmt.Errorf("failed to listen on %q: %w", sh.config.ShellSockPath, err) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + <-ctx.Done() + l.Close() + wg.Done() + }() + defer wg.Wait() + + health.OK(fmt.Sprintf("Listening on %s", sh.config.ShellSockPath)) + sh.log.Info("Shell listening", "socket", sh.config.ShellSockPath) + connCount := 0 + for ctx.Err() == nil { + conn, err := l.Accept() + if err != nil { + // If context is cancelled, the listener was closed gracefully + if errors.Is(ctx.Err(), context.Canceled) { + return nil + } + return fmt.Errorf("accept failed: %w", err) + } + connID := connCount + connCount++ + + sh.jg.Add(job.OneShot( + fmt.Sprintf("shell-%d", connID), + func(ctx context.Context, h cell.Health) error { + sh.handleConn(ctx, connID, conn) + h.Close() // remove from health list + return nil + })) + } + return nil +} + +func (sh shell) handleConn(ctx context.Context, clientID int, conn net.Conn) { + sh.log.Info("Client connected", "id", clientID) + defer sh.log.Info("client disconnected", "id", clientID) + + ctx, cancel := context.WithCancel(ctx) + + // Wait for context cancellation in the background and close + // the connection if that happens. This allows teardown on + // errors or when parent context cancels. + var wg sync.WaitGroup + wg.Add(1) + go func() { + <-ctx.Done() + conn.Close() + wg.Done() + }() + defer wg.Wait() + defer cancel() + + // Catch panics to make sure the script commands can't bring the runtime down. + defer func() { + if err := recover(); err != nil { + // Log the panic and also write it to client. We keep processing + // more commands after this. + stack := make([]byte, 1024) + stack = stack[:runtime.Stack(stack, false)] + sh.log.Error("Panic in the shell handler", + "error", err, + "stacktrace", stack, + ) + fmt.Fprintf(conn, "PANIC: %s\n%s\n%s\n", err, stack, endMarker) + } + }() + + s, err := script.NewState(ctx, "/tmp", nil) + if err != nil { + sh.log.Error("NewState", "error", err) + return + } + + bio := bufio.NewReader(conn) + + // Wrap the connection into a writer that cancels the context we use to execute + // commands. This allows interrupting the command without having to have the commands + // handle write errors. + writer := interceptingWriter{ + conn: conn, + onError: func(error) { + cancel() + }, + } + + for { + bline, _, err := bio.ReadLine() + if err != nil { + break + } + line := string(bline) + switch line { + case "stop", "exit", "quit": + return + } + err = sh.engine.ExecuteLine(s, line, writer) + if err != nil { + _, err = fmt.Fprintln(writer, err) + if err != nil { + break + } + } + // Send the "end of command output" marker + _, err = fmt.Fprintln(writer, endMarker) + if err != nil { + break + } + } +} + +type interceptingWriter struct { + conn net.Conn + onError func(error) +} + +func (iw interceptingWriter) Write(buf []byte) (int, error) { + n, err := iw.conn.Write(buf) + if err != nil { + iw.onError(err) + } + return n, err +} diff --git a/shell/shell_test.go b/shell/shell_test.go new file mode 100644 index 0000000..b055a4e --- /dev/null +++ b/shell/shell_test.go @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package shell + +import ( + "context" + "flag" + "os" + "os/exec" + "path" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cilium/hive" + "github.com/cilium/hive/cell" + "github.com/cilium/hive/hivetest" + "github.com/cilium/hive/job" +) + +var client = flag.String("client", "", "Act as client to given unix socket") + +func TestMain(m *testing.M) { + flag.Parse() + if *client != "" { + interactiveShell(Config{ShellSockPath: *client}, "test> ", nil) + return + } else { + os.Exit(m.Run()) + } +} + +func fixture(t *testing.T, cfg Config) { + h := hive.New( + job.Cell, + cell.SimpleHealthCell, + cell.Provide( + func(r job.Registry, lc cell.Lifecycle, health cell.Health) job.Group { + return r.NewGroup(health, lc) + }, + ), + ServerCell(cfg.ShellSockPath), + ) + + log := hivetest.Logger(t) + require.NoError(t, + h.Start(log, context.TODO()), + "Start") + t.Cleanup(func() { + assert.NoError(t, + h.Stop(log, context.TODO()), + "Stop") + }) + + // Wait for the socket file to appear to avoid the 1s retry backoff + for range 100 { + _, err := os.Stat(cfg.ShellSockPath) + if err == nil { + break + } + time.Sleep(time.Millisecond) + } +} + +func TestShellExchange(t *testing.T) { + sock := path.Join(t.TempDir(), "shell.sock") + cfg := Config{sock} + fixture(t, cfg) + + var buf strings.Builder + err := ShellExchange(cfg, &buf, "help") + assert.NoError(t, err, "ShellExchangeWithConfig") + assert.Contains(t, buf.String(), "commands:") +} + +func TestInteractiveShell(t *testing.T) { + sock := path.Join(t.TempDir(), "shell.sock") + cfg := Config{sock} + fixture(t, cfg) + + cmd := exec.Command(os.Args[0], "-client", sock) + cmd.Stdin = strings.NewReader("help help\r\nexit\r\n") + out, err := cmd.CombinedOutput() + require.NoError(t, err, "CombinedOutput") + + require.Contains(t, string(out), "test> help help") + require.Contains(t, string(out), "log help text") +} From 6924ed592d4fbf54ff2aa65f7dce7d23b3f38d77 Mon Sep 17 00:00:00 2001 From: Jussi Maki Date: Mon, 20 Oct 2025 16:58:02 +0200 Subject: [PATCH 2/3] shell: Add support for auto-completion Adapted from Dylan's prior work. Authored-by: Dylan Reimerink Signed-off-by: Dylan Reimerink Signed-off-by: Jussi Maki --- shell/client.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/shell/client.go b/shell/client.go index a79251c..d6b4232 100644 --- a/shell/client.go +++ b/shell/client.go @@ -12,6 +12,7 @@ import ( "net" "os" "os/signal" + "slices" "strings" "sync/atomic" "time" @@ -170,6 +171,7 @@ func interactiveShell(cfg Config, prompt string, printGreeting func(w io.Writer) }() bio := bufio.NewReader(conn) + console.AutoCompleteCallback = autocomplete(conn, bio) // Read commands from the console and send them to the server for execution. repl: @@ -191,6 +193,7 @@ repl: return 1 } bio = bufio.NewReader(conn) + console.AutoCompleteCallback = autocomplete(conn, bio) // Try again with the new connection. if _, err = fmt.Fprintln(conn, line); err != nil { @@ -240,3 +243,93 @@ repl: conn.Close() return 0 } + +func autocomplete(conn net.Conn, bio *bufio.Reader) func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { + var ( + suggestionIndex int + suggestionPos int = -1 + ) + return func(line string, pos int, key rune) (string, int, bool) { + switch key { + case '\t': + default: + suggestionIndex = 0 + suggestionPos = -1 + + // Only handle tab completion. + return line, pos, false + } + + // If we have not queried the server yet, or the line has changed, we need to + // query the server for suggestions. + if suggestionPos == -1 { + suggestionPos = pos + } + + if suggestionPos > len(line) { + suggestionPos = len(line) + } + + line = line[:suggestionPos] + + // If the line does not contain a space, we are still typing out the initial command. + if !strings.Contains(line, " ") { + // Ask server for suggestions of root commands. + if _, err := fmt.Fprintln(conn, "help -a "+line); err != nil { + return "", 0, false + } + } else { + cmd, args, _ := strings.Cut(line, " ") + args = strings.Replace(args, "'", "\\'", -1) // Escape single quotes for the shell. + // Ask server for suggestions for the specific command. + if _, err := fmt.Fprintf(conn, "%s --autocomplete='%s'\n", cmd, args); err != nil { + return "", 0, false + } + } + + var suggestions []string + suggestion := "" + for { + lineBytes, isPrefix, err := bio.ReadLine() + if err != nil { + // Connection closed! + return "", 0, false + } + line := string(lineBytes) + + if line == "[stdout]" || line == "[stderr]" { + // Commands that write to "stdout" instead of the log show the [stdout] as + // the first line. This is useful information in tests, but not useful in + // the shell, so just skip this. + continue + } + + line, ended := strings.CutSuffix(line, endMarker) + suggestion += line + if !isPrefix { + if suggestion != "" { + suggestions = append(suggestions, suggestion) + } + suggestion = "" + } + if ended { + break + } + } + + slices.Sort(suggestions) + + if suggestionIndex > len(suggestions)-1 { + suggestionIndex = 0 + } + + if len(suggestions) == 0 { + // No suggestions available. + return line, pos, false + } + + currentSuggestion := suggestions[suggestionIndex] + suggestionIndex++ + return currentSuggestion, len(currentSuggestion), true + } +} From 2c0f11fcf0b489bcbf4cc644e54a03aed46701f9 Mon Sep 17 00:00:00 2001 From: Jussi Maki Date: Mon, 6 Oct 2025 11:14:15 +0200 Subject: [PATCH 3/3] example: Add support for the shell package Add the shell server to the example application and add the shell command. To start the example application: go run ./example In another terminal connect to the example application shell: go run ./example shell Signed-off-by: Jussi Maki --- example/main.go | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/example/main.go b/example/main.go index 72df34e..7e40477 100644 --- a/example/main.go +++ b/example/main.go @@ -4,6 +4,8 @@ package main import ( + "fmt" + "io" "log/slog" "os" @@ -12,13 +14,17 @@ import ( "github.com/cilium/hive" "github.com/cilium/hive/cell" "github.com/cilium/hive/job" + "github.com/cilium/hive/shell" ) +const exampleShellSocketPath = "/tmp/example-shell.sock" + var ( // Create a hive from a set of cells. Hive = hive.New( cell.SimpleHealthCell, job.Cell, + shell.ServerCell(exampleShellSocketPath), cell.Module( "example", @@ -78,7 +84,8 @@ var ( // ... // example> hive stop replCmd = &cobra.Command{ - Use: "repl", + Use: "repl", + Short: "Run the Hive repl for the example application", Run: func(_ *cobra.Command, args []string) { hive.RunRepl(Hive, os.Stdin, os.Stdout, "example> ") }, @@ -95,9 +102,21 @@ func main() { // Add the "repl" command to interactively run the application. replCmd, + + // Add the shell client command. + // + // After starting the application ("go run ./example") you can connect + // to the shell with "go run ./example shell" + shell.ShellCmd(exampleShellSocketPath, "example> ", shellGreeting), ) // And finally execute the command to parse the command-line flags and // run the hive cmd.Execute() } + +func shellGreeting(w io.Writer) { + fmt.Fprintln(w) + fmt.Fprintf(w, "... Welcome to the example application shell ...\n") + fmt.Fprintln(w) +}