diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md index cf0bfb2..b506414 100644 --- a/.github/ISSUE_TEMPLATE/bug.md +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -2,7 +2,7 @@ name: Bug about: Template for a problem with the library's behaviour title: "" -labels: Bug +type: Bug assignees: '' --- diff --git a/.github/ISSUE_TEMPLATE/task.md b/.github/ISSUE_TEMPLATE/task.md index 7ea27b9..622779f 100644 --- a/.github/ISSUE_TEMPLATE/task.md +++ b/.github/ISSUE_TEMPLATE/task.md @@ -2,7 +2,7 @@ name: Task about: Template for the smallest actionable chunk title: '' -labels: Task +type: Task assignees: '' --- diff --git a/.github/workflows/CI_test_suite.yml b/.github/workflows/CI_test_suite.yml index d5e18af..2a0f718 100644 --- a/.github/workflows/CI_test_suite.yml +++ b/.github/workflows/CI_test_suite.yml @@ -2,26 +2,29 @@ name: CI Test Suite on: push: - branches: [ "main" ] + branches: + - "main" + - "henk/*" pull_request: branches: [ "main" ] jobs: - SystemTests: + SystemTestsSequential: runs-on: ubuntu-latest + if: ${{ github.event_name == 'pull_request' }} steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version-file: go.mod - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.12' - name: System tests dependencies run: xargs -a system_test_requirements.txt sudo apt-get install @@ -31,9 +34,9 @@ jobs: run: pip install -r python_requirements.txt working-directory: test_suite - - name: Run system tests + - name: Run system tests sequentially id: system-test - run: ./system_tests.sh -c 0 + run: ./system_tests.sh -b working-directory: test_suite continue-on-error: true @@ -47,7 +50,42 @@ jobs: if: ${{ steps.system-test.outcome == 'failure' }} run: exit 1 - IntegrationTests: + SystemTestsParallel: + runs-on: ubuntu-latest + if: ${{ github.event_name == 'push' }} + steps: + - uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: 
docker/setup-buildx-action@v3 + + - name: Build Dockerfile to run system tests in parallel + uses: docker/build-push-action@v6 + with: + file: test_suite/Dockerfile + load: true + tags: system_tests:latest + build-args: BRANCH=${{ github.ref_name }} + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Run system tests in parallel + id: system-test + run: GITHUB_ACTION=true ./system_tests.sh -t 4 + working-directory: test_suite + continue-on-error: true + + - name: Upload system test logs + uses: actions/upload-artifact@v4 + with: + name: system-test-logs + path: test_suite/system_test_logs/ + + - name: Fail job if system test failed (for clarity in GitHub UI) + if: ${{ steps.system-test.outcome == 'failure' }} + run: exit 1 + + PerformanceTests: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -57,6 +95,99 @@ jobs: with: go-version: '1.22' + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: System tests dependencies + run: xargs -a system_test_requirements.txt sudo apt-get install + working-directory: test_suite + + - name: Performance tests dependencies + run: pip install -r python_requirements.txt + working-directory: test_suite + + - name: Run performance test with varying bitrate + id: system-test + run: ./system_tests.sh -b -p -L performance + working-directory: test_suite + continue-on-error: true + + - name: Upload performance test graphs + uses: actions/upload-artifact@v4 + with: + name: performance-test-graphs + path: test_suite/system_test_logs/performance/*/*.png + + - name: Upload performance test data + uses: actions/upload-artifact@v4 + with: + name: performance-test-data + path: test_suite/system_test_logs/performance/*/performance_test_data.json + + - name: Upload full performance test logs + uses: actions/upload-artifact@v4 + with: + name: performance-test-logs + path: test_suite/system_test_logs/ + + - name: Fail job if performance test failed (for clarity in GitHub UI) + if: 
${{ steps.system-test.outcome == 'failure' }} + run: exit 1 + + - name: Download artifact from target branch head + id: download_target_branch + if: ${{ github.event_name == 'pull_request' }} + uses: dawidd6/action-download-artifact@v9 + with: + branch: ${{ github.base_ref }} + name: performance-test-data + path: ./test_suite + skip_unpack: true + workflow_conclusion: "" + continue-on-error: true + + - name: Download artifact from previous commit + id: download_previous_commit + if: ${{ github.event_name == 'push' }} + uses: dawidd6/action-download-artifact@v9 + with: + commit: ${{ github.event.before }} + name: performance-test-data + path: ./test_suite + skip_unpack: true + workflow_conclusion: "" + continue-on-error: true + + - name: Stop job and give warning if downloading previous artifact failed + if: ${{ steps.download_target_branch.outcome == 'failure' || steps.download_previous_commit.outcome == 'failure' }} + run: | + echo "# ⚠️ Could not make performance comparison" >> $GITHUB_STEP_SUMMARY + echo "Downloading performance test data of target branch head/previous commit failed. 
See the corresponding PerformanceTests job step for details" >> $GITHUB_STEP_SUMMARY + + - name: Unzip artifact + working-directory: test_suite + if: ${{ steps.download_target_branch.outcome == 'success' || steps.download_previous_commit.outcome == 'success' }} + run: unzip performance-test-data.zip -d previous_performance + + - name: Compare current and artifact performance, and redirect results to job step summary + id: performance-comparison + working-directory: test_suite + if: ${{ steps.download_target_branch.outcome == 'success' || steps.download_previous_commit.outcome == 'success' }} + run: python compare_performance.py previous_performance/ system_test_logs/performance/ > $GITHUB_STEP_SUMMARY + + + IntegrationTests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - name: Run integration tests id: integration-test run: go test -coverpkg=./... -coverprofile cover.out -v ./... diff --git a/.github/workflows/golangci-lint-main.yml b/.github/workflows/golangci-lint-main.yml new file mode 100644 index 0000000..64e8b55 --- /dev/null +++ b/.github/workflows/golangci-lint-main.yml @@ -0,0 +1,23 @@ +name: golangci-lint +on: + push: + branches: + - main + +permissions: + contents: read + +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.60 + args: -E stylecheck,revive,gocritic,gofumpt diff --git a/.github/workflows/golangci-lint-pr.yml b/.github/workflows/golangci-lint-pr.yml new file mode 100644 index 0000000..fac0d8b --- /dev/null +++ b/.github/workflows/golangci-lint-pr.yml @@ -0,0 +1,23 @@ +name: golangci-lint +on: + pull_request: + +permissions: + contents: read + pull-requests: read + +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: 
actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.60 + args: -E stylecheck,revive,gocritic,gofumpt + only-new-issues: true diff --git a/cmd/control_server/main.go b/cmd/control_server/main.go index 92a544a..8bbeab3 100644 --- a/cmd/control_server/main.go +++ b/cmd/control_server/main.go @@ -6,10 +6,6 @@ import ( "errors" "flag" "fmt" - "github.com/edup2p/common/types/control" - "github.com/edup2p/common/types/control/controlhttp" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/relay" "io" "log" "log/slog" @@ -23,13 +19,16 @@ import ( "sync" "syscall" "time" + + "github.com/edup2p/common/types/control" + "github.com/edup2p/common/types/control/controlhttp" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/relay" ) var ( - //dev = flag.Bool("dev", false, "run in localhost development mode (overrides -a)") addr = flag.String("a", ":443", "server HTTP/HTTPS listen address, in form \":port\", \"ip:port\", or for IPv6 \"[ip]:port\". If the IP is omitted, it defaults to all interfaces. Serves HTTPS if the port is 443 and/or -certmode is manual, otherwise HTTP.") configPath = flag.String("c", "", "config file path") - //stunPort = flag.Int("stun-port", stunserver.DefaultPort, "The UDP port on which to serve STUN. 
The listener is bound to the same IP (if any) as specified in the -a flag.") publicFacingBaseString = flag.String("u", "", "public facing base URL (required)") publicFacingBase *url.URL @@ -44,7 +43,7 @@ var ( func main() { h := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: programLevel, - //AddSource: true, + // AddSource: true, }) slog.SetDefault(slog.New(h)) programLevel.Set(-8) @@ -92,9 +91,11 @@ func main() { mux.Handle("/", handleStaticHTML(ToverSokControlDefaultHTML)) - mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { browserHeaders(w) - io.WriteString(w, "User-agent: *\nDisallow: /\n") + if _, err := io.WriteString(w, "User-agent: *\nDisallow: /\n"); err != nil { + slog.Error("could not write robots.txt", "err", err) + } })) mux.Handle("/generate_204", http.HandlerFunc(serverCaptivePortalBuster)) @@ -117,7 +118,7 @@ func main() { Addr: *addr, Handler: mux, // TODO - //ErrorLog: slog.NewLogLogger(), + // ErrorLog: slog.NewLogLogger(), ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, @@ -125,7 +126,9 @@ func main() { go func() { <-ctx.Done() - httpsrv.Shutdown(ctx) + if err := httpsrv.Shutdown(ctx); err != nil { + slog.Error("could not shutdown control server", "err", err) + } }() // TODO setup TLS with autocert? 
@@ -134,7 +137,7 @@ func main() { err = httpsrv.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("control: error %s", err) + log.Fatalf("control: error %s", err) //nolint:gocritic } } @@ -149,7 +152,7 @@ type ControlServer struct { func (cs *ControlServer) HandleAuthRequest(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { - http.Redirect(w, r, "/auth/land", 302) + http.Redirect(w, r, "/auth/land", http.StatusFound) return } @@ -165,16 +168,15 @@ func (cs *ControlServer) HandleAuthRequest(w http.ResponseWriter, r *http.Reques return } - http.Redirect(w, r, "/auth/success", 302) + http.Redirect(w, r, "/auth/success", http.StatusFound) } else { // Fail - http.Redirect(w, r, "/auth/fail", 302) + http.Redirect(w, r, "/auth/fail", http.StatusFound) } - } func (cs *ControlServer) OnSessionCreate(id control.SessID, cid control.ClientID) { - println("OnSessionCreate") + slog.Info("OnSessionCreate", "id", id, "cid", cid) if cs.isKnown(key.NodePublic(cid)) { go func() { @@ -186,31 +188,30 @@ func (cs *ControlServer) OnSessionCreate(id control.SessID, cid control.ClientID return } - url, _ := url.Parse(string("/auth/land?session=" + id)) - if err := cs.server.SendAuthURL(id, publicFacingBase.ResolveReference(url).String()); err != nil { + redirectURL, _ := url.Parse(string("/auth/land?session=" + id)) + if err := cs.server.SendAuthURL(id, publicFacingBase.ResolveReference(redirectURL).String()); err != nil { slog.Error("error sending auth URL", "id", id, "err", err) } } -func (cs *ControlServer) OnSessionResume(id control.SessID, id2 control.ClientID) { - println("OnSessionResume") - return // noop +func (cs *ControlServer) OnSessionResume(sess control.SessID, cid control.ClientID) { + slog.Info("OnSessionResume", "sess", sess, "cid", cid) } -func (cs *ControlServer) OnDeviceKey(id control.SessID, key string) { - println("OnDeviceKey") - return // noop +func (cs *ControlServer) OnDeviceKey(sess control.SessID, 
deviceKey string) { + slog.Info("OnDeviceKey", "sess", sess, "deviceKey", deviceKey) } -func (cs *ControlServer) OnSessionFinalize(id control.SessID, id2 control.ClientID) (netip.Prefix, netip.Prefix) { - println("OnSessionFinalize") +func (cs *ControlServer) OnSessionFinalize(sess control.SessID, cid control.ClientID) (netip.Prefix, netip.Prefix, time.Time) { + slog.Info("OnSessionFinalize", "sess", sess, "cid", cid) + + ip4, ip6 := cs.getIPs(key.NodePublic(cid)) - return cs.getIPs(key.NodePublic(id2)) + return ip4, ip6, time.Now().Add(time.Hour * 24 * 7) } -func (cs *ControlServer) OnSessionDestroy(id control.SessID, id2 control.ClientID) { - println("OnSessionDestroy") - return // noop +func (cs *ControlServer) OnSessionDestroy(sess control.SessID, cid control.ClientID) { + slog.Info("OnSessionDestroy", "sess", sess, "cid", cid) } func LoadServer(ctx context.Context) *ControlServer { @@ -248,7 +249,9 @@ func (cs *ControlServer) loadExistingNodes() { continue } - if err := cs.server.UpsertVisibilityPair(client, client2, control.VisibilityPair{}); err != nil { + if err := cs.server.UpsertVisibilityPair(client, client2, control.VisibilityPair{ + MDNS: true, + }); err != nil { panic(err) } } @@ -261,7 +264,9 @@ func (cs *ControlServer) addNewNode(node key.NodePublic) { continue } - if err := cs.server.UpsertVisibilityPair(control.ClientID(node), control.ClientID(node2), control.VisibilityPair{}); err != nil { + if err := cs.server.UpsertVisibilityPair(control.ClientID(node), control.ClientID(node2), control.VisibilityPair{ + MDNS: true, + }); err != nil { panic(err) } } @@ -371,14 +376,16 @@ func handleStaticHTML(doc string) http.HandlerFunc { } } -func sendStaticHTML(doc string, w http.ResponseWriter, r *http.Request) { +func sendStaticHTML(doc string, w http.ResponseWriter, _ *http.Request) { browserHeaders(w) w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) - io.WriteString(w, doc) + if _, err := 
io.WriteString(w, doc); err != nil { + slog.Error("failed to write static HTML page", "error", err) + } } const ToverSokControlDefaultHTML = ` @@ -488,6 +495,7 @@ func loadConfig() Config { return writeNewConfig() case err != nil: log.Fatal(err) + //goland:noinspection GoUnreachableCode panic("unreachable") default: var cfg Config @@ -507,14 +515,14 @@ func writeNewConfig() Config { } func writeConfig(cfg Config, path string) { - if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0o777); err != nil { log.Fatal(err) } b, err := json.MarshalIndent(cfg, "", "\t") if err != nil { log.Fatal(err) } - if err := os.WriteFile(path, b, 0600); err != nil { + if err := os.WriteFile(path, b, 0o600); err != nil { log.Fatal(err) } } diff --git a/cmd/dev_client/main.go b/cmd/dev_client/main.go index e4c1f4b..8e2a154 100644 --- a/cmd/dev_client/main.go +++ b/cmd/dev_client/main.go @@ -7,8 +7,20 @@ import ( "errors" "flag" "fmt" + "log" + "log/slog" + "math" + "net/netip" + "os" + "path/filepath" + "runtime/pprof" + "strconv" + "strings" + "sync" + "time" + "github.com/abiosoft/ishell/v2" - "github.com/edup2p/common/ext_wg" + "github.com/edup2p/common/extwg" "github.com/edup2p/common/toversok" "github.com/edup2p/common/toversok/actors" "github.com/edup2p/common/types" @@ -19,22 +31,12 @@ import ( "github.com/edup2p/common/usrwg" "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl" - "log" - "log/slog" - "math" - "net/netip" - "os" - "path/filepath" - "runtime/pprof" - "strconv" - "strings" - "sync" ) var ( programLevel = new(slog.LevelVar) // Info by default - wgCtrl *ext_wg.WGCtrl + wgCtrl *extwg.WGCtrl usrWg *usrwg.UserSpaceWireGuardHost wg toversok.WireGuardHost @@ -43,9 +45,6 @@ var ( privKey *key.NodePrivate - //ip4 *netip.Prefix - //ip6 *netip.Prefix - fakeControl StokControl properControl toversok.DefaultControlHost usedControl toversok.ControlHost @@ -74,7 +73,9 @@ func main() { if err != nil { log.Fatal(err) } - 
pprof.StartCPUProfile(f) + if err := pprof.StartCPUProfile(f); err != nil { + panic(err) + } defer pprof.StopCPUProfile() } @@ -97,7 +98,7 @@ func main() { logCmd.AddCmd(&ishell.Cmd{ Name: "info", Help: "set log level to info", - Func: func(c *ishell.Context) { + Func: func(_ *ishell.Context) { programLevel.Set(slog.LevelInfo) }, }) @@ -105,7 +106,7 @@ func main() { logCmd.AddCmd(&ishell.Cmd{ Name: "debug", Help: "set log level to debug", - Func: func(c *ishell.Context) { + Func: func(_ *ishell.Context) { programLevel.Set(slog.LevelDebug) }, }) @@ -113,7 +114,7 @@ func main() { logCmd.AddCmd(&ishell.Cmd{ Name: "trace", Help: "set log level to trace", - Func: func(c *ishell.Context) { + Func: func(_ *ishell.Context) { programLevel.Set(-8) }, }) @@ -126,9 +127,6 @@ func main() { shell.AddCmd(pcCmd()) shell.AddCmd(fcCmd()) - //shell.AddCmd(tsCmd()) - //shell.AddCmd(ctrlCmd()) - shell.Run() if engine != nil { @@ -196,12 +194,13 @@ func keyCmd() *ishell.Cmd { line = c.Args[0] } - if p, err := key.UnmarshalPrivate(line); err != nil { + p, err := key.UnmarshalPrivate(line) + if err != nil { c.Err(err) return - } else { - privKey = p } + + privKey = p }, }) @@ -242,7 +241,6 @@ func getOrGenerateKey(file string, c *ishell.Context) (key.NodePrivate, error) { } data, err := os.ReadFile(file) - if err != nil { if os.IsNotExist(err) { c.Println(fmt.Sprintf("%s does not exist, generating new key...", file)) @@ -254,7 +252,7 @@ func getOrGenerateKey(file string, c *ishell.Context) (key.NodePrivate, error) { return k, fmt.Errorf("failed to marshal private key: %w", err) } - if err := os.WriteFile(file, jsonData, 0644); err != nil { + if err := os.WriteFile(file, jsonData, 0o644); err != nil { return k, fmt.Errorf("failed to write private key to file: %w", err) } @@ -288,7 +286,7 @@ func pcCmd() *ishell.Cmd { c.AddCmd(&ishell.Cmd{ Name: "use", Help: "start using the proper control", - Func: func(c *ishell.Context) { + Func: func(_ *ishell.Context) { usedControl = &properControl }, 
}) @@ -305,12 +303,13 @@ func pcCmd() *ishell.Cmd { line = c.Args[0] } - if p, err := key.UnmarshalControlPublic(line); err != nil { + p, err := key.UnmarshalControlPublic(line) + if err != nil { c.Err(err) return - } else { - properControl.Key = *p } + + properControl.Key = *p }, }) @@ -383,15 +382,12 @@ func fcCmd() *ishell.Cmd { c := &ishell.Cmd{ Name: "fc", Help: "fake controlhost variables and handling", - //Func: func(c *ishell.Context) { - // c.Println("fake control:", fakeControl) - //}, } c.AddCmd(&ishell.Cmd{ Name: "use", Help: "start using the proper control", - Func: func(c *ishell.Context) { + Func: func(_ *ishell.Context) { usedControl = &fakeControl }, }) @@ -411,7 +407,7 @@ func fcCmd() *ishell.Cmd { var ( err error peerKey *key.NodePublic - relay int64 + relayID int64 session key.SessionPublic ip4 netip.Addr ip6 netip.Addr @@ -423,7 +419,7 @@ func fcCmd() *ishell.Cmd { c.Err(err) return } - if relay, err = strconv.ParseInt(c.Args[1], 10, 64); err != nil { + if relayID, err = strconv.ParseInt(c.Args[1], 10, 64); err != nil { c.Err(err) return } @@ -433,24 +429,22 @@ func fcCmd() *ishell.Cmd { // We (semi-intentionally) break compatibility with any main network because of this. 
session = [32]byte(*peerKey) - if ip4, err = netip.ParseAddr(c.Args[2]); err != nil { + ip4, err = netip.ParseAddr(c.Args[2]) + + if err != nil { c.Err(err) return - } else { - if !ip4.Is4() { - c.Err(errors.New("ip4 isnt ipv4")) - return - } + } else if !ip4.Is4() { + c.Err(errors.New("ip4 isnt ipv4")) + return } if ip6, err = netip.ParseAddr(c.Args[3]); err != nil { c.Err(err) return - } else { - if !ip6.Is6() { - c.Err(errors.New("ip6 isnt ipv6")) - return - } + } else if !ip6.Is6() { + c.Err(errors.New("ip6 isnt ipv6")) + return } for _, e := range c.Args[4:] { @@ -465,7 +459,7 @@ func fcCmd() *ishell.Cmd { if err = fakeControl.addPeer(PeerDef{ Key: *peerKey, - HomeRelayID: relay, + HomeRelayID: relayID, SessionKey: session, Endpoints: endpoints, VIPs: toversok.VirtualIPs{ @@ -478,65 +472,6 @@ func fcCmd() *ishell.Cmd { }, }) - //peerCmd.AddCmd(&ishell.Cmd{ - // Name: "update", - // Aliases: []string{"u"}, - // Help: "update a peer: -r [relay] -e [endpoint,...]", - // Func: func(c *ishell.Context) { - // if len(c.Args) == 0 { - // c.Err(errors.New("did not define peer key")) - // return - // } - // - // peerKey, err := key.UnmarshalPublic(c.Args[0]) - // - // if err != nil { - // c.Err(fmt.Errorf("error parsing peer key: %w", err)) - // return - // } - // - // fs := flag.NewFlagSet("peer-update", flag.ContinueOnError) - // - // r := fs.Int64("r", math.MaxInt64, "relay (int64)") - // endpoints := fs.String("e", "", "endpoints (comma-seperated IPs)") - // - // if err := fs.Parse(c.Args[1:]); err != nil { - // c.Err(fmt.Errorf("could not parse flags: %w", err)) - // return - // } - // - // pu := toversok.PeerUpdate{ - // Key: *peerKey, - // } - // - // if *r != math.MaxInt64 { - // pu.HomeRelayId = gonull.NewNullable(*r) - // } - // - // if *endpoints != "" { - // as := *endpoints - // - // aps := make([]netip.AddrPort, 0) - // - // for _, addr := range strings.Split(as, ",") { - // a, err := netip.ParseAddrPort(addr) - // if err != nil { - // c.Err(err) - // 
return - // } - // - // aps = append(aps, a) - // } - // - // pu.Endpoints = gonull.NewNullable(aps) - // } - // - // if err = engine.Handle(pu); err != nil { - // c.Err(err) - // } - // }, - //}) - peerCmd.AddCmd(&ishell.Cmd{ Name: "delete", Aliases: []string{"del", "d"}, @@ -602,10 +537,10 @@ func fcCmd() *ishell.Cmd { ID: id, Key: *relayKey, Domain: *domain, - //IPs: gonull.Nullable[[]netip.Addr]{}, - //STUNPort: gonull.Nullable[uint16]{}, - //HTTPSPort: gonull.Nullable[uint16]{}, - //HTTPPort: gonull.Nullable[uint16]{}, + // IPs: gonull.Nullable[[]netip.Addr]{}, + // STUNPort: gonull.Nullable[uint16]{}, + // HTTPSPort: gonull.Nullable[uint16]{}, + // HTTPPort: gonull.Nullable[uint16]{}, IsInsecure: *insecure, } @@ -747,7 +682,7 @@ func wgCmd() *ishell.Cmd { device = names[choice] } - wgCtrl = ext_wg.NewWGCtrl(client, device) + wgCtrl = extwg.NewWGCtrl(client, device) wg = wgCtrl @@ -771,27 +706,25 @@ func wgCmd() *ishell.Cmd { Name: "init", Help: "Perform Init() on the wg configurator. 
wg init ", Func: func(c *ishell.Context) { - if len(c.Args) < 2 { + switch { + case len(c.Args) < 2: c.Err(errors.New("usage: privkey addr4 addr6")) return - } else if wg == nil { + case wg == nil: c.Err(errors.New("wg not setup")) - } else { + default: privkeyStr := c.Args[0] addr4Str := c.Args[1] addr6Str := c.Args[2] - privkeySlice, err := hex.DecodeString(privkeyStr) if err != nil { c.Err(err) return } else if len(privkeySlice) != key.Len { - c.Err(errors.New(fmt.Sprintf("unexpected key length, expected 32, got %d", len(privkeySlice)))) + c.Err(fmt.Errorf("unexpected key length, expected 32, got %d", len(privkeySlice))) return } - privkey := key.NodePrivateFrom((key.NakedKey)(privkeySlice)) - addr4, err := netip.ParsePrefix(addr4Str) if err != nil { c.Err(err) @@ -800,7 +733,6 @@ func wgCmd() *ishell.Cmd { c.Err(errors.New("first argument is not ipv4 address/cidr")) return } - addr6, err := netip.ParsePrefix(addr6Str) if err != nil { c.Err(err) @@ -809,13 +741,11 @@ func wgCmd() *ishell.Cmd { c.Err(errors.New("second argument is not ipv6 address/cidr")) return } - wgC, err = wg.Controller(privkey, addr4, addr6) if err != nil { c.Err(err) return } - c.Println("wg controller:", wgC) } }, @@ -825,7 +755,6 @@ func wgCmd() *ishell.Cmd { } func enCmd() *ishell.Cmd { - c := &ishell.Cmd{ Name: "en", Help: "toversok engine and subcommands", @@ -850,13 +779,15 @@ func enCmd() *ishell.Cmd { Func: func(c *ishell.Context) { var err error - if usedControl == nil { + switch { + case usedControl == nil: err = errors.New("no control host set") - } else if wg == nil { + case wg == nil: err = errors.New("wg is not set") - } else if privKey == nil { + case privKey == nil: err = errors.New("key is not set") } + if err != nil { c.Err(err) return @@ -871,20 +802,13 @@ func enCmd() *ishell.Cmd { } ctx, ccc := context.WithCancelCause(context.Background()) - //opts := toversok.EngineOptions{ - // Ctx: ctx, - // Ccc: ccc, - // PrivKey: key.UnveilPrivate(*privKey), - // ExtBindPort: 
engineExtPort, - // WG: wg, - // FW: nil, - //} fw := &StokFirewall{} e, err := toversok.NewEngine(ctx, wg, fw, usedControl, engineExtPort, *privKey) if err != nil { c.Err(err) + ccc(err) return } @@ -895,7 +819,7 @@ func enCmd() *ishell.Cmd { c.AddCmd(&ishell.Cmd{Name: "start", Help: "start the engine", Func: func(c *ishell.Context) { if engine != nil { - err := engine.Start() + _, err := engine.Start() if err != nil { c.Err(err) } @@ -961,6 +885,10 @@ func (s *StokControl) IPv6() netip.Prefix { return *s.ip6 } +func (s *StokControl) Expiry() time.Time { + return time.Time{} +} + func (s *StokControl) UpdateEndpoints(endpoints []netip.AddrPort) error { slog.Info("called UpdateEndpoints", "endpoints", endpoints) @@ -973,6 +901,10 @@ func (s *StokControl) UpdateHomeRelay(i int64) error { return nil } +func (s *StokControl) Context() context.Context { + return context.Background() +} + func (s *StokControl) InstallCallbacks(callbacks ifaces.ControlCallbacks) { s.callback = callbacks @@ -993,7 +925,7 @@ func (s *StokControl) InstallCallbacks(callbacks ifaces.ControlCallbacks) { } } -func (s *StokControl) CreateClient(parentCtx context.Context, getNode func() *key.NodePrivate, getSess func() *key.SessionPrivate, login types.LogonCallback) (ifaces.ControlSession, error) { +func (s *StokControl) CreateClient(context.Context, func() *key.NodePrivate, func() *key.SessionPrivate, types.LogonCallback) (ifaces.ControlSession, error) { return s, nil } diff --git a/cmd/mdns_monitor/main.go b/cmd/mdns_monitor/main.go new file mode 100644 index 0000000..b84a4dd --- /dev/null +++ b/cmd/mdns_monitor/main.go @@ -0,0 +1,110 @@ +package main + +import ( + "context" + "fmt" + "log" + "net" + "net/netip" + "time" + + "github.com/sethvargo/go-limiter/memorystore" + "golang.org/x/net/dns/dnsmessage" +) + +func walkInterfaces() { + ift, err := net.Interfaces() + if err != nil { + log.Fatal(err) + } + for _, ifi := range ift { + isLoopBack := ifi.Flags&net.FlagLoopback != 0 + isPtP := 
ifi.Flags&net.FlagPointToPoint != 0 + + fmt.Printf("iface %s: lo(%t) ptp(%t)\n", ifi.Name, isLoopBack, isPtP) + } +} + +func main() { + // this code is specific to macos, for now + + walkInterfaces() + + IP := "224.0.0.251:5353" + // IP := "[ff02::fb]:5353" + + ua := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(IP)) + + iface, err := net.InterfaceByName("lo0") + if err != nil { + log.Fatal(err) + } + + bind, err := net.ListenMulticastUDP("udp4", iface, ua) + if err != nil { + log.Fatal(err) + } + + fmt.Println("got multicast udp") + + store, err := memorystore.New(&memorystore.Config{ + // Number of tokens allowed per interval. + Tokens: 1, + + // Interval until tokens reset. + Interval: 20 * time.Second, + + SweepInterval: 1 * time.Minute, + SweepMinTTL: 1 * time.Minute, + }) + if err != nil { + log.Fatal(err) + } + + buf := make([]byte, 1<<16) + + QUBit := uint16(1 << 15) + + for { + n, ap, err := bind.ReadFromUDPAddrPort(buf) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("read %d bytes from %s\n", n, ap.String()) + + data := buf[:n] + + msg := dnsmessage.Message{} + if err = msg.Unpack(data); err != nil { + log.Printf("Error unpacking DNS message: %s\n", err) + continue + } + + _, _, _, ok, err := store.Take(context.Background(), msg.GoString()) + if err != nil { + log.Fatal(err) + } + + if !ok { + log.Println("message rate limited") + continue + } + + questions := msg.Questions + + msg.Questions = []dnsmessage.Question{} + + fmt.Printf("got mdns: %s\n", msg.GoString()) + + for _, q := range questions { + isQU := uint16(q.Class)&QUBit != 0 + + if isQU { + fmt.Printf("found QU: %s\n", q.GoString()) + } else { + fmt.Printf("found QM: %s\n", q.GoString()) + } + } + } +} diff --git a/cmd/mdns_test/main.go b/cmd/mdns_test/main.go new file mode 100644 index 0000000..19a8a4b --- /dev/null +++ b/cmd/mdns_test/main.go @@ -0,0 +1,427 @@ +package main + +import ( + "errors" + "fmt" + "log" + "log/slog" + "net" + "net/netip" + "os" + "sync" + "syscall" + "time" + 
+ "github.com/edup2p/common/types" + "golang.org/x/net/dns/dnsmessage" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var multicastIface string + +func init() { + ifaces, err := net.Interfaces() + if err != nil { + panic(fmt.Errorf("could not list network interfaces: %w", err)) + } + + for _, iface := range ifaces { + if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback != 0 { + multicastIface = iface.Name + } + } +} + +//nolint:unused +var ( + MDNSPort uint16 = 5353 + ip4MDNSBroadcastBare = netip.MustParseAddr("224.0.0.251") + ip6MDNSBroadcastBare = netip.MustParseAddr("ff02::fb") + + ip4MDNSUnspecifiedAP = netip.AddrPortFrom(netip.IPv4Unspecified(), MDNSPort) + ip6MDNSUnspecifiedAP = netip.AddrPortFrom(netip.IPv6Unspecified(), MDNSPort) + + ip4MDNSBroadcastAP = netip.AddrPortFrom(ip4MDNSBroadcastBare, MDNSPort) + ip6MDNSBroadcastAP = netip.AddrPortFrom(ip6MDNSBroadcastBare, MDNSPort) + + ip4MDNSLoopBackAP = netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), MDNSPort) + ip4MDNSLoopBackAPAlt = netip.AddrPortFrom(netip.MustParseAddr("127.0.0.2"), MDNSPort) + ip6MDNSLoopBackAP = netip.AddrPortFrom(netip.IPv6Loopback(), MDNSPort) + ip6MDNSLoopBackAPAlt = netip.AddrPortFrom(netip.MustParseAddr("::2"), MDNSPort) +) + +const bit15 = 2 << 14 + +func main() { + if len(os.Args) < 2 { + println("Usage: mdns_test <.local name>") + os.Exit(1) + } + + name := os.Args[1] + ".local." 
+ + dnsName, err := dnsmessage.NewName(name) + if err != nil { + panic(fmt.Errorf("failed to make DNS name: %w", err)) + } + allServicesName := dnsmessage.MustNewName("_services._dns-sd._udp.local.") + + nameQM := dnsmessage.Question{ + Name: dnsName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + } + + servicesQM := dnsmessage.Question{ + Name: allServicesName, + Type: dnsmessage.TypePTR, + Class: dnsmessage.ClassINET, + } + + ml4, p4, err := makeIPv4MDNSListener() + if err != nil { + panic(fmt.Errorf("failed to make ipv4 mdns listener: %w", err)) + } + defer ml4.Close() + + ml6, p6, err := makeIPv6MDNSListener() + if err != nil { + panic(fmt.Errorf("failed to make ipv6 mdns listener: %w", err)) + } + defer ml6.Close() + + u4, err := net.DialUDP("udp4", nil, net.UDPAddrFromAddrPort(ip4MDNSLoopBackAP)) + if err != nil { + panic(fmt.Errorf("failed to make ipv4 unicast listener: %w", err)) + } + defer u4.Close() + + u6, err := net.DialUDP("udp6", nil, net.UDPAddrFromAddrPort(ip6MDNSLoopBackAP)) + if err != nil { + panic(fmt.Errorf("failed to make ipv6 unicast listener: %w", err)) + } + defer u6.Close() + + var respMu sync.Mutex + var responses []*dnsmessage.Message + + appendResponse := func(msg *dnsmessage.Message) { + respMu.Lock() + defer respMu.Unlock() + responses = append(responses, msg) + } + + go func() { + buf := make([]byte, 1<<16) + + for { + n, cm, ap, err := p6.ReadFrom(buf) + if err != nil { + panic(fmt.Errorf("failed to read from ipv6 mdns listener: %w", err)) + } + + var dst net.IP + + if cm != nil && cm.Dst != nil { + dst = cm.Dst + } + + slog.Info("received ipv6 packet", "from", ap.String(), "len", n, "dst", dst) + + msg := new(dnsmessage.Message) + + if err := msg.Unpack(buf[:n]); err != nil { + slog.Error("could not unpack ipv6 packet into mdns", "error", err) + continue + } + + appendResponse(msg) + } + }() + + go func() { + buf := make([]byte, 1<<16) + + for { + n, cm, ap, err := p4.ReadFrom(buf) + if err != nil { + 
panic(fmt.Errorf("failed to read from ipv4 mdns listener: %w", err)) + } + + var dst net.IP + + if cm != nil && cm.Dst != nil { + dst = cm.Dst + } + + slog.Info("received ipv4 packet", "from", ap.String(), "len", n, "dst", dst) + + msg := new(dnsmessage.Message) + + if err := msg.Unpack(buf[:n]); err != nil { + slog.Error("could not unpack ipv4 packet into mdns", "error", err) + continue + } + + appendResponse(msg) + } + }() + + listener := func(name string, conn *net.UDPConn) { + buf := make([]byte, 1<<16) + + for { + n, ap, err := conn.ReadFromUDPAddrPort(buf) + if err != nil { + panic(fmt.Errorf(name+": failed to read: %w", err)) + } + + slog.Info(name+": received packet", "from", ap.String(), "len", n) + + msg := new(dnsmessage.Message) + + if err := msg.Unpack(buf[:n]); err != nil { + slog.Error(name+": could not unpack packet into mdns", "error", err) + continue + } + + appendResponse(msg) + } + } + + go listener("uni4", u4) + go listener("uni6", u6) + + nameQU := nameQM + // unicast-response + nameQU.Class |= bit15 + + servicesQU := servicesQM + servicesQU.Class |= bit15 + + questions := []dnsmessage.Question{ + // nameQM, + // nameQU, + servicesQM, + servicesQU, + } + + var queries []*dnsmessage.Message + + for _, q := range questions { + queries = append(queries, makeQuery(q)) + } + + type writeTo struct { + conn types.UDPConn + name string + to *netip.AddrPort + } + + toWrite := []writeTo{ + {conn: ml6, to: &ip6MDNSBroadcastAP, name: "ml6bc"}, + {conn: ml4, to: &ip4MDNSBroadcastAP, name: "ml4bc"}, + {conn: ml6, to: &ip6MDNSLoopBackAP, name: "ml6lo"}, + {conn: ml4, to: &ip4MDNSLoopBackAP, name: "ml4lo"}, + {conn: u4, name: "u4"}, + {conn: u6, name: "u6"}, + } + + qna := make(map[*dnsmessage.Message]map[string][]*dnsmessage.Message) + + for _, q := range queries { + current := make(map[string][]*dnsmessage.Message) + qna[q] = current + + for _, w := range toWrite { + doWrite(w.conn, w.name, w.to, q) + + time.Sleep(1 * time.Second) + + respMu.Lock() + 
current[w.name] = responses + responses = nil + respMu.Unlock() + } + } + + processQNA(qna) +} + +func processQNA(m map[*dnsmessage.Message]map[string][]*dnsmessage.Message) { + println("\n\n\n") + slog.Info("Printing QNA result") + + for query, result := range m { + println("\n") + slog.Info("Printing for query") + debugMDNS(query) + println() + + for name, responses := range result { + slog.Info("Printing responses for name", "name", name) + + for _, msg := range responses { + debugMDNS(msg) + } + + println() + } + } +} + +func doWrite(c types.UDPConn, name string, to *netip.AddrPort, msg *dnsmessage.Message) { + q, err := msg.Pack() + if err != nil { + log.Fatal(fmt.Errorf("%s: could not pack message: %w", name, err)) + } + + if to == nil { + if _, err := c.Write(q); err != nil { + log.Println(fmt.Errorf("%s: failed to write query: %w", name, err)) + } + } else { + if _, err := c.WriteToUDPAddrPort(q, *to); err != nil { + log.Println(fmt.Errorf("%s: failed to write query to addrport: %w", name, err)) + } + } +} + +func makeQuery(q dnsmessage.Question) *dnsmessage.Message { + return &dnsmessage.Message{ + Header: dnsmessage.Header{}, + Questions: []dnsmessage.Question{q}, + } +} + +func makeIPv4MDNSListener() (types.UDPConn, *ipv4.PacketConn, error) { + ua := net.UDPAddrFromAddrPort(ip4MDNSBroadcastAP) + + conn, err := net.ListenUDP("udp4", ua) + if err != nil { + return nil, nil, fmt.Errorf("ListenUDP error: %w", err) + } + + p4 := ipv4.NewPacketConn(conn) + + ift, err := net.Interfaces() + if err != nil { + return nil, nil, fmt.Errorf("cannot get interfaces: %w", err) + } + for _, ifi := range ift { + if ifi.Flags&net.FlagUp != 0 && ifi.Flags&net.FlagPointToPoint == 0 { + if err := p4.JoinGroup(&ifi, &net.UDPAddr{IP: ip4MDNSBroadcastBare.AsSlice()}); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + slog.Warn("p4 multicast JoinGroup failed", "err", err, "iface", ifi.Name) + } + } + } + + if loop, err := p4.MulticastLoopback(); err == nil { + if !loop { + 
if err := p4.SetMulticastLoopback(true); err != nil { + return nil, nil, fmt.Errorf("cannot set multicast loopback: %w", err) + } + slog.Info("Multicast Loopback enabled") + } else { + slog.Info("Multicast Loopback was enabled") + } + } else { + return nil, nil, fmt.Errorf("cannot get MulticastLoopback: %w", err) + } + + ifi, err := net.InterfaceByName(multicastIface) + if err != nil { + panic(err) + } + + if err := p4.SetMulticastInterface(ifi); err != nil { + return nil, nil, fmt.Errorf("cannot set multicast interface: %w", err) + } + + if err := p4.SetTTL(255); err != nil { + return nil, nil, fmt.Errorf("cannot set TTL: %w", err) + } + if err := p4.SetMulticastTTL(255); err != nil { + return nil, nil, fmt.Errorf("cannot set Multicast TTL: %w", err) + } + + if err = p4.SetControlMessage(ipv4.FlagDst, true); err != nil { + slog.Warn("cannot set control message dstflag", "err", err) + } + + return conn, p4, nil +} + +func makeIPv6MDNSListener() (types.UDPConn, *ipv6.PacketConn, error) { + ua := net.UDPAddrFromAddrPort(ip6MDNSBroadcastAP) + + conn, err := net.ListenUDP("udp6", ua) + if err != nil { + return nil, nil, fmt.Errorf("ListenUDP error: %w", err) + } + + p6 := ipv6.NewPacketConn(conn) + + ift, err := net.Interfaces() + if err != nil { + return nil, nil, fmt.Errorf("cannot get interfaces: %w", err) + } + for _, ifi := range ift { + if ifi.Flags&net.FlagUp != 0 && ifi.Flags&net.FlagPointToPoint == 0 { + if err := p6.JoinGroup(&ifi, &net.UDPAddr{IP: ip6MDNSBroadcastBare.AsSlice()}); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + slog.Warn("p6 multicast JoinGroup failed", "err", err, "iface", ifi.Name) + } + } + } + + if loop, err := p6.MulticastLoopback(); err == nil { + if !loop { + if err := p6.SetMulticastLoopback(true); err != nil { + return nil, nil, fmt.Errorf("cannot set multicast loopback: %w", err) + } + slog.Info("Multicast Loopback enabled") + } else { + slog.Info("Multicast Loopback was enabled") + } + } else { + return nil, nil, 
fmt.Errorf("cannot get MulticastLoopback: %w", err) + } + + ifi, err := net.InterfaceByName(multicastIface) + if err != nil { + panic(err) + } + + if err := p6.SetMulticastInterface(ifi); err != nil { + return nil, nil, fmt.Errorf("cannot set multicast interface: %w", err) + } + + if err = p6.SetControlMessage(ipv6.FlagDst, true); err != nil { + slog.Warn("cannot set control message dstflag", "err", err) + } + + return conn, p6, nil +} + +func debugMDNS(msg *dnsmessage.Message) { + slog.Info("debugMDNS: TXID", "txid", msg.ID) + + for _, q := range msg.Questions { + slog.Info( + "debugMDNS: Q", + "txid", msg.ID, + "name", q.Name, + "type", q.Type.GoString(), + "class", q.Class.GoString(), + ) + } + for _, a := range msg.Answers { + slog.Info( + "debugMDNS: A", + "txid", msg.ID, + "header", a.Header.GoString(), + "body", a.Body.GoString(), + ) + } +} diff --git a/cmd/relay_server/main.go b/cmd/relay_server/main.go index 3338c9a..1f030d4 100644 --- a/cmd/relay_server/main.go +++ b/cmd/relay_server/main.go @@ -6,10 +6,6 @@ import ( "errors" "flag" "fmt" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/relay" - "github.com/edup2p/common/types/relay/relayhttp" - stunserver "github.com/edup2p/common/types/stun" "io" "log" "log/slog" @@ -22,6 +18,11 @@ import ( "strings" "syscall" "time" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/relay" + "github.com/edup2p/common/types/relay/relayhttp" + stunserver "github.com/edup2p/common/types/stun" ) var ( @@ -45,9 +46,6 @@ const ToverSokRelayDefaultHTML = ` func main() { flag.Parse() - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer cancel() - if *dev { *addr = "127.0.0.1:3340" log.Printf("Running in dev mode.") @@ -71,12 +69,15 @@ func main() { log.Fatalf("could not parse stun-combined addrport: %v", err) } - stunServer := stunserver.NewServer(ctx) - go stunServer.ListenAndServe(ap) - - // TODO add STUN here + ctx, cancel := 
signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() - // TODO continue here + stunServer := stunserver.NewServer(ctx) + go func() { + if err := stunServer.ListenAndServe(ap); err != nil { + slog.Error("stun server listen error", "err", err) + } + }() cfg := loadConfig() @@ -88,27 +89,29 @@ func main() { mux.Handle("/relay", relayhttp.ServerHandler(server)) - mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { browserHeaders(w) w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) - io.WriteString(w, ToverSokRelayDefaultHTML) + if _, err := io.WriteString(w, ToverSokRelayDefaultHTML); err != nil { + slog.Error("failed to write default HTML response", "err", err) + } })) - mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { browserHeaders(w) - io.WriteString(w, "User-agent: *\nDisallow: /\n") + if _, err := io.WriteString(w, "User-agent: *\nDisallow: /\n"); err != nil { + slog.Error("failed to write robots.txt", "err", err) + } })) mux.Handle("/generate_204", http.HandlerFunc(serverCaptivePortalBuster)) httpsrv := &http.Server{ Addr: *addr, Handler: mux, - // TODO - //ErrorLog: slog.NewLogLogger(), ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, @@ -116,16 +119,18 @@ func main() { go func() { <-ctx.Done() - httpsrv.Shutdown(ctx) + if err := httpsrv.Shutdown(ctx); err != nil { + slog.Error("failed to shutdown server", "err", err) + } }() - // TODO setup TLS with autocert + // TODO setup TLS with autocert: https://github.com/eduP2P/relay-server/issues/2 slog.Info("relay: serving", "addr", *addr) err = httpsrv.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("relay: error %s", err) + 
log.Fatalf("relay: error %s", err) //nolint:gocritic } } @@ -183,6 +188,7 @@ func loadConfig() Config { return writeNewConfig() case err != nil: log.Fatal(err) + //goland:noinspection GoUnreachableCode panic("unreachable") default: var cfg Config @@ -194,7 +200,7 @@ func loadConfig() Config { } func writeNewConfig() Config { - if err := os.MkdirAll(filepath.Dir(*configPath), 0777); err != nil { + if err := os.MkdirAll(filepath.Dir(*configPath), 0o777); err != nil { log.Fatal(err) } cfg := newConfig() @@ -202,7 +208,7 @@ func writeNewConfig() Config { if err != nil { log.Fatal(err) } - if err := os.WriteFile(*configPath, b, 0600); err != nil { + if err := os.WriteFile(*configPath, b, 0o600); err != nil { log.Fatal(err) } return cfg diff --git a/cmd/stun_client/main.go b/cmd/stun_client/main.go index 14fa4ad..9f2bdd9 100644 --- a/cmd/stun_client/main.go +++ b/cmd/stun_client/main.go @@ -1,12 +1,13 @@ package main import ( - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/stun" "log" "net" "net/netip" "os" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/stun" ) func main() { diff --git a/docs/mdns.md b/docs/mdns.md new file mode 100644 index 0000000..3c14362 --- /dev/null +++ b/docs/mdns.md @@ -0,0 +1,29 @@ +# mDNS Notes + +[mDNS](https://en.wikipedia.org/wiki/Multicast_DNS) (multicast DNS) is defined in +[RFC 6762](https://datatracker.ietf.org/doc/html/rfc6762) as, essentially, UDP DNS packets sent to broadcast addresses +`224.0.0.251` and `FF02::FB` on port `5353`. + +On top of that, a new `UNICAST-RESPONSE` (`"QU"`) bit is added to the query section, which can be parsed as `QCLASS` +`2^15`, and a `CACHE-FLUSH` bit on every resource (answer/additional/authority) record, which can be parsed as `RRCLASS` +`2^15`. + +Exact implementation differs per operating system, but as a rule of thumb: + +- mDNS packets (like most broadcast packets) aren't sent over `PPP` (point to point) classes of networks, which Wireguard is.
+- Linux needs an additional system component to enable mDNS, such as "Avahi". +- macOS, Linux, and Windows have implementations with differing coverage; + e.g. Windows doesn't allow unicast queries, while macOS does. +- While mDNS works via loopback, some operating systems have quirks in how they handle it, and only fully support mDNS via + "regular" LAN interfaces. + +## Intercepting mDNS + +Because of the above limitations (no mDNS on PPP, etc.), we cannot intercept mDNS packets via TUN (level 3, IP), +and would have to listen to them on the regular interfaces, by listening on port `5353` (and fight with the system +listener to have it share its port), +grab mDNS packets, filter them (to prevent noise from the local LAN), transform them (to point to the right IP address +over the Wireguard interface), send them over to interested parties, and then inject them. + +This essentially makes mDNS packets get wiretapped, and "appear out of thin air" at the recipient, +which should tie it together. \ No newline at end of file diff --git a/ext_wg/wgctrl.go b/extwg/wgctrl.go similarity index 92% rename from ext_wg/wgctrl.go rename to extwg/wgctrl.go index 91b152a..5737d14 100644 --- a/ext_wg/wgctrl.go +++ b/extwg/wgctrl.go @@ -1,7 +1,16 @@ -package ext_wg +package extwg import ( + "errors" "fmt" + "log" + "log/slog" + "net" + "net/netip" + "runtime" + "strings" + "sync" + "github.com/edup2p/common/toversok" "github.com/edup2p/common/types" "github.com/edup2p/common/types/key" @@ -9,12 +18,6 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "log/slog" - "net" - "net/netip" - "runtime" - "strings" - "sync" ) // A wireguard configurator by the help of wgtools shell commands.
@@ -53,8 +56,12 @@ func NewWGCtrl(client *wgctrl.Client, device string) *WGCtrl { } func (w *WGCtrl) Reset() error { + var errs []error + for _, m := range w.localMapping { - m.conn.Close() + if err := m.conn.Close(); err != nil { + errs = append(errs, err) + } } maps.Clear(w.localMapping) @@ -66,7 +73,11 @@ func (w *WGCtrl) Reset() error { ReplacePeers: true, Peers: []wgtypes.PeerConfig{}, }); err != nil { - return fmt.Errorf("error resetting wg device: %w", err) + errs = append(errs, fmt.Errorf("error resetting wg device: %w", err)) + } + + if len(errs) > 0 { + return fmt.Errorf("errors while wg device: %w", errors.Join(errs...)) } return nil @@ -130,7 +141,6 @@ func (w *WGCtrl) Controller(privateKey key.NodePrivate, addr4, addr6 netip.Prefi var device *wgtypes.Device device, err = w.client.Device(w.name) - if err != nil { return nil, err } @@ -217,6 +227,15 @@ func (w *WGCtrl) GetStats(publicKey key.NodePublic) (*toversok.WGStats, error) { }, nil } +func (w *WGCtrl) GetInterface() *net.Interface { + i, err := net.InterfaceByName(w.name) + if err != nil { + log.Println("cannot find interface ", w.name, ":", err) + return nil + } + return i +} + func (w *WGCtrl) ensureLocalConn(peer key.NodePublic) *mapping { m, ok := w.localMapping[peer] @@ -246,7 +265,6 @@ func (w *WGCtrl) rebindMapping(m *mapping) error { func (w *WGCtrl) bindLocal() *mapping { conn, err := w.getWGConn(nil) - if err != nil { panic(fmt.Sprintf("error when first binding to wgport: %s", err)) } @@ -255,7 +273,7 @@ func (w *WGCtrl) bindLocal() *mapping { } func (w *WGCtrl) getWGConn(fromPort *uint16) (*net.UDPConn, error) { - var laddr *net.UDPAddr = nil + var laddr *net.UDPAddr if fromPort != nil { laddr = net.UDPAddrFromAddrPort( diff --git a/go.mod b/go.mod index f6eb0c5..eac7946 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,18 @@ module github.com/edup2p/common go 1.22 require ( - github.com/LukaGiorgadze/gonull v1.2.0 github.com/abiosoft/ishell/v2 v2.0.2 github.com/dblohm7/wingoes 
v0.0.0-20240801171404-fc12d7c70140 github.com/go-ole/go-ole v1.3.0 - go.mongodb.org/mongo-driver v1.15.0 + github.com/google/gopacket v1.1.19 + github.com/sethvargo/go-limiter v1.0.0 + github.com/stretchr/testify v1.9.0 go4.org/mem v0.0.0-20220726221520-4f986261bf13 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba - golang.org/x/crypto v0.17.0 + golang.org/x/crypto v0.33.0 golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f - golang.org/x/sys v0.15.0 + golang.org/x/net v0.33.0 + golang.org/x/sys v0.30.0 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -24,7 +26,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.12.0 // indirect github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect - github.com/google/go-cmp v0.5.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mattn/go-colorable v0.1.8 // indirect github.com/mattn/go-isatty v0.0.12 // indirect @@ -32,8 +34,6 @@ require ( github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.9.0 // indirect - golang.org/x/net v0.10.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 21910c9..097b178 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,12 @@ -github.com/LukaGiorgadze/gonull v1.2.0 h1:I+/pHqr9dySqf6A4agJazrFA8XlrUohqdb10nFIaxJU= -github.com/LukaGiorgadze/gonull v1.2.0/go.mod h1:iGbXOBV6y4VkT14x//F3yZiIxe1ylZYor05pZb0/9TM= github.com/abiosoft/ishell v2.0.0+incompatible h1:zpwIuEHc37EzrsIYah3cpevrIc8Oma7oZPxr03tlmmw= github.com/abiosoft/ishell v2.0.0+incompatible/go.mod 
h1:HQR9AqF2R3P4XXpMpI0NAzgHf/aS6+zVXRj14cVk9qg= github.com/abiosoft/ishell/v2 v2.0.2 h1:5qVfGiQISaYM8TkbBl7RFO6MddABoXpATrsFbVI+SNo= github.com/abiosoft/ishell/v2 v2.0.2/go.mod h1:E4oTCXfo6QjoCart0QYa5m9w4S+deXs/P/9jA77A9Bs= github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db h1:CjPUSXOiYptLbTdr1RceuZgSFDQ7U15ITERUGrUORx8= github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db/go.mod h1:rB3B4rKii8V21ydCbIzH5hZiCQE7f5E9SzUb/ZZx530= +github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -19,8 +19,12 @@ github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BMXYYRWT github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:rZfgFAXFS/z/lEd6LJmf9HVZ1LkgYiHx5pHhV5DR16M= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= 
+github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= @@ -33,43 +37,53 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/dblohm7/wingoes v0.0.0-20240801171404-fc12d7c70140/go.mod h1:SUxUaAK/0UG5lYyZR1L1nC4AaYYvSSYTWQSH3FPcxKU= -github.com/fatih/color v1.12.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= -github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:rZfgFAXFS/z/lEd6LJmf9HVZ1LkgYiHx5pHhV5DR16M= -github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= -github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-isatty v0.0.12/go.mod 
h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sethvargo/go-limiter v1.0.0 h1:JqW13eWEMn0VFv86OKn8wiYJY/m250WoXdrjRV0kLe4= +github.com/sethvargo/go-limiter v1.0.0/go.mod h1:01b6tW25Ap+MeLYBuD4aHunMrJoNO5PVUFdS9rac3II= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -go.mongodb.org/mongo-driver v1.15.0 h1:rJCKC8eEliewXjZGf0ddURtl7tTVy1TK3bfl0gkUSLc= -go.mongodb.org/mongo-driver v1.15.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= +github.com/tc-hib/winres v0.2.1/go.mod h1:C/JaNhH3KBvhNKVbvdlDWkbMDO9H4fKKDaN7/07SSuk= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto 
v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/image v0.12.0 h1:w13vZbU4o5rKOFFR8y7M+c4A5jXDC0uXTdHYRP8X2DQ= +golang.org/x/image v0.12.0/go.mod h1:Lu90jvHG7GfemOIcldsh9A2hS01ocl6oNO7ype5mEnk= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= @@ -78,7 +92,10 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= +gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA= diff --git a/test_suite/CHANGELOG.md b/test_suite/CHANGELOG.md new file mode 100644 index 0000000..45a528e --- /dev/null +++ b/test_suite/CHANGELOG.md @@ -0,0 +1,37 @@ +# Changelog + +In this file, the test suite features that have been made possible thanks to [funding from NLnet](./README.md#funding) are documented. + + +## Parallel system tests (April 11, 2025) + +### Added +- [Dockerfile](Dockerfile) that installs all requirements to run the test suite's system tests, clones the repository state at the head of a specified branch, and builds the eduP2P client, control server and relay server test binaries. +- Docker Engine as new requirement in [system test requirements](README.md#system-test-specific-requirements). +- `-t` flag in [system_tests.sh](system_tests.sh) to run the system tests in parallel with the specified number of threads. Each thread is a Docker container in which a portion of the system tests is executed. +- SystemTestsParallel job in [test suite CI workflow](../.github/workflows/CI_test_suite.yml) that builds the Dockerfile with caching and runs the system tests with the `-t` flag. +- Explanation of the `-t` flag, motivation for using Docker and reason why parallel system tests currently do not speed up CI runs in the [system test documentation](README.md#system-tests). +- Log level `trace` (most detailed) in [test_client/main.go](test_client/main.go).
+ +### Changed +- Building of eduP2P binaries in [system_tests.sh](system_tests.sh) now only happens when explicitly providing the new `-b` flag to save time when the binaries have already been built (e.g. in Dockerfile or on local machine). +- Optimizations in [test_client/setup_client.sh](test_client/setup_client.sh), such as smaller sleep durations to make the system tests run faster. + +### Fixed +- Bug in [visualize_performance_tests.py](visualize_performance_tests.py): quotation marks inside a format string caused error in some Python versions. +- Handshake error between eduP2P peers in the system tests caused by both peers initializing a handshake simultaneously. Fixed by desynchronizing the peers with a conditional sleep in [test_client/setup_client.sh](test_client/setup_client.sh). + +## Repeated performance tests (March 7, 2025) + +### Added +- `-r` flag in [performance_tests.sh](performance_tests.sh) to repeat the same performance test multiple times and aggregate the results of each repetition by taking their average. +- Explanation of the `-r` flag in the [performance test documentation](./README.md#performance-tests). +- Report on how aggregating the performance test results can improve their reliability in the [performance test results](./README.md#consistency-of-results). + +## Simulating network delay (March 4, 2025) + +### Added +- `-d` flag in [system_tests.sh](system_tests.sh) to add artificial network delay in the system tests. +- New value `delay` for the `-k` flag in [performance_tests.sh](performance_tests.sh) to add variable artificial network delay during the performance tests. +- Explanation of the delay variable in the [performance test documentation](./README.md#performance-tests). +- Report on how the delay affects eduP2P network performance in the [performance test results](./README.md#results-with-varying-one-way-delay). 
\ No newline at end of file diff --git a/test_suite/Dockerfile b/test_suite/Dockerfile new file mode 100644 index 0000000..d6c6447 --- /dev/null +++ b/test_suite/Dockerfile @@ -0,0 +1,37 @@ +# Git branch +ARG BRANCH="main" + +# Stage that clones repository, to be copied by next stage for cache optimization +FROM alpine AS get-requirements +ARG BRANCH + +# Clone repository, prevent cache from using old version by adding current version to Dockerfile +RUN apk add --no-cache git +ADD https://api.github.com/repos/eduP2P/common/git/refs/heads/${BRANCH} version.json +RUN git clone -b ${BRANCH} --single-branch https://github.com/eduP2P/common + +# Stage that makes optimal use of cache by first copying and installing requirements from previous stage, and only cloning repository afterwards +FROM golang:1.22 AS run-system-tests +ARG BRANCH + +# Command-line requirements +COPY --from=get-requirements /common/test_suite/system_test_requirements.txt /go/system_test_requirements.txt +RUN apt-get update &&\ + apt-get -y install sudo &&\ + xargs -a system_test_requirements.txt sudo apt-get -y install + +# Go packages +COPY --from=get-requirements /common/go.mod /go/src/go.mod +WORKDIR /go/src +RUN go mod download + +# Clone repository, prevent cache from using old version by adding current version to Dockerfile +WORKDIR /go +ADD https://api.github.com/repos/eduP2P/common/git/refs/heads/${BRANCH} version.json +RUN git clone -b ${BRANCH} --single-branch https://github.com/eduP2P/common + +# Build eduP2P binaries +WORKDIR /go/common/test_suite +RUN for binary in test_client control_server relay_server; do go build -o $binary/$binary $binary/*.go; done + +ENTRYPOINT ["/go/common/test_suite/system_tests.sh"] \ No newline at end of file diff --git a/test_suite/README.md b/test_suite/README.md index 8db5b8d..31f6017 100644 --- a/test_suite/README.md +++ b/test_suite/README.md @@ -63,6 +63,10 @@ sudo and xargs packages): xargs -a system_test_requirements.txt sudo apt-get install 
+Optionally, the system tests can be run in parallel. In this mode the +system tests are distributed over Docker containers, so it requires +installing [Docker Engine](https://docs.docker.com/engine/install/). + ### Performance test-specific requirements The performance tests are run as a part of the system tests, and require @@ -106,6 +110,28 @@ physical setup for two reasons: network congestion than a physical network. The simulated network setup is described in detail in the next section. +As mentioned in the [system test +requirements](#system-test-specific-requirements), the tests may be run +in parallel using Docker. The user can specify the number of “threads” +with the `-t` flag, which determines over how many Docker containers the +tests will be distributed. The reason for using Docker is that it allows +the concurrent tests to be executed in isolated networks. Naturally, +running the tests in parallel allows the tests to run much faster. One +disadvantage of the parallel tests is that a Docker image must be built +before running the tests. This takes quite a while the first time the +image is built, but by making use of Docker’s caching, building the +image again after revisions to the code is significantly faster. + +The parallel system tests are also being used in the CI GitHub workflow +when new code is pushed to a branch. For pull requests, the sequential +version of the tests is still used because we cannot simply clone the +branch head inside the Docker image in this case. Currently, running the +system tests in parallel in CI does not cause the system test job to run +faster. The reason for this is that loading and storing the Docker image +from the GitHub Actions cache takes too long for the speed-up in +actually running the tests to matter. This problem could potentially be +solved by deploying a self-hosted runner with persistent memory to +perform the CI workflow.
### Network Simulation Setup @@ -455,12 +481,10 @@ configuring the following parameters: [\[8\]](#ref-man_netem)). - **Performance test baseline**: with this optional parameter, a - ‘baseline’ is added to the performance test results, created by - repeating the performance tests for: - - 1. two peers that use WireGuard; - 2. two peers that use their physical IP addresses in the simulated - network setup. + ‘baseline’ is added to the performance test results. This baseline is + created by repeating the performance tests for two peers that use + WireGuard, and/or two peers that use their physical IP addresses in + the simulated network setup. With this baseline, it is easier to investigate whether any performance deficiencies in eduP2P are truly the result of a problem @@ -474,6 +498,12 @@ configuring the following parameters: amount of values assigned to the independent variable. If the baseline parameter is used, the duration is additionally multiplied by 3. +- **Performance test repetition**: the result of the performance tests + may be affected by external factors such as other processes running on + the same machine. To mitigate these undesirable external influences, + this parameter allows the performance tests to be repeated multiple + times in order to improve the reliability of their results. 
+ To run performance tests manually, [system_tests.sh](system_tests.sh) can be used with the `-f` option to specify a file containing system tests, which may use the above parameters to also execute a performance @@ -486,7 +516,7 @@ The following command runs tests from a file named Suppose `performance_test.txt` contains the following line: - run_system_test -k bitrate -v 100,200 -d 5 TS_PASS_DIRECT router1-router2 : : + run_system_test -k bitrate -v 100,200 -d 5 -b wireguard -r 3 TS_PASS_DIRECT router1-router2 : : Then, running the system tests with the `-f performance_test.txt` option will execute a performance test with the following parameters: @@ -494,6 +524,9 @@ will execute a performance test with the following parameters: - the independent variable to be tested is bitrate; - the values it should take are 100 and 200 Mbps; - the duration of the test for each value is 5 seconds. +- the performance test is executed for two peers using WireGuard, + besides the standard execution for two peers using eduP2P. +- the performance test is executed 3 times. The other parameters in `performance_test.txt` are not relevant to the performance test itself, but are necessary to run the system test in the @@ -511,9 +544,11 @@ during the test. Using the Python script [visualize_performance_tests.py](visualize_performance_tests.py), these performance metrics are extracted from the json files, and graphs are automatically created that plot the independent variable on the X axis -against each performance metric on the Y axis. Some of these graphs are -shown in the [performance test results -section](#performance-test-results) +against each performance metric on the Y axis. Furthermore, if `-r` has +a value greater than one, another graph is created to show the variance +of the measurements across the different repetitions of the test. Some +of these graphs are shown in the [performance test results +section](#performance-test-results). 
## Integration Tests @@ -1151,7 +1186,7 @@ reproducibility. Command used: - run_system_test -k bitrate -v 800,1600,2400,3200,4000 -d 3 -b TS_PASS_DIRECT router1-router2 : wg0:wg0 + run_system_test -k bitrate -v 800,1600,2400,3200,4000 -d 3 -b both -r 5 TS_PASS_DIRECT router1-router2 : wg0:wg0 With this command, we compare the performance of eduP2P, WireGuard and a direct connection between two peers in the test suite’s network setup. @@ -1162,8 +1197,8 @@ limit, as seen in the graph below: ![](./images/performance_tests/ext_wg_x_bitrate_y_bitrate.png) With the direct connection, the maximum bitrate that can be reached on -my machine is approximately 3600 Mbps, whereas eduP2P and WireGuard both -end at a bitrate of approximately 2800 Mbps. +my machine is approximately 3400 Mbps, whereas eduP2P and WireGuard both +end at a bitrate of approximately 2700 Mbps. As the measured bitrate increases, it becomes clear that there are large differences between the packet loss the three connections suffer, as @@ -1172,9 +1207,10 @@ seen in the following graph: ![](./images/performance_tests/ext_wg_x_bitrate_y_packet_loss.png) The direct connection does not suffer any packet loss, whereas -WireGuard’s packet loss slowly climbs up to approximately 3%, and the -packet loss of eduP2P quickly increases to end at over 50%. +WireGuard’s packet loss reaches a maximum of approximately 5%, and the packet +loss of eduP2P quickly increases to end at over 70%. +The root cause of eduP2P’s packet loss has yet to be determined. Although the final amount of packet loss in eduP2P seems very alarming, it must be noted that the maximum bitrate used in this performance test is very high, and when eduP2P would be used in the real world it is @@ -1182,12 +1218,6 @@ unlikely that the network bandwidth limits would allow for such a high bitrate. However, the packet loss of eduP2P is also quite sizeable even for lower bitrates, which would be a problem in the real world. 
-The reason that eduP2P suffers so much packet loss probably has to do -with the fact that it uses Go channels internally to pass packets -between its isolated components. For such high bitrates, these channels -may become full, and consequently packets sent to these channels are -dropped. - It must also be noted that although the direct connection and WireGuard do not suffer much packet loss on my machine, this is not the case on every machine I tried this performance test on. When repeating this test @@ -1209,7 +1239,7 @@ differ on other machines. Command used: - run_system_test -k bitrate -v 800,1600,2400,3200,4000 -d 3 -b TS_PASS_DIRECT router1-router2 : : + run_system_test -k bitrate -v 800,1600,2400,3200,4000 -d 3 -b both -r 5 TS_PASS_DIRECT router1-router2 : : This command repeats the performance test of the previous section, with the only difference being that now both peers use userspace WireGuard @@ -1230,7 +1260,7 @@ further, however: Command used: - run_system_test -k delay -v 0,1,2,3 -d 3 -b TS_PASS_DIRECT router1-router2 : : + run_system_test -k delay -v 0,1,2,3 -d 3 -b both -r 3 TS_PASS_DIRECT router1-router2 : : ![](./images/performance_tests/x_ow_delay_y_http_latency.png) @@ -1252,6 +1282,36 @@ eduP2P into account, we can conclude that increasing the one-way delay does not increase eduP2P’s HTTP latency more than the HTTP latency of WireGuard or the direct connection. +### Consistency of results + +The results of the performance tests may be affected by external +factors, such as other processes running on the same machine. Therefore, +the performance test results may be inconsistent: two runs of the same +performance tests may have different results. + +To improve the reliability of the performance tests, the `-r` option has +been introduced to repeat the tests and aggregate their results. The +performance tests also generate a graph illustrating the variance over +multiple repetitions. 
Below, this graph is shown for the performance +test described in [bitrate performance test with external +WireGuard](./README.md#results-with-varying-bitrate-and-peers-using-external-wireguard). + +![](./images/performance_tests/performance_test_variance.png) + +This graph shows that there is quite a lot of variance between certain +measurements. For example, for the eduP2P bitrate measurement (top-left) +and WireGuard packet loss measurement (bottom-center), the absolute +difference between the minimum and maximum measurements grows quite +large as the target bitrate increases. + +Some of the measurements also contain outliers, such as the WireGuard +jitter measurement (center) with large spikes in the third and fifth +repetition. + +As seen from the black lines in the graphs, calculating the average over +the repetitions improves the reliability of the results where the +variance is large or outliers are present. + ## Integration Test Results Currently, the integration tests focus on the lowest level components @@ -1394,18 +1454,3 @@ Conservancy](https://commonsconservancy.org/). The test suite features that have been made possible thanks to this funding are described below. - -### Simulating network delay (finished 04-03-2025) - -This feature makes it possible to add artificial network delay in the -system and performance tests. - -The feature can be used with the system tests by calling -`system_tests.sh` with the option `-d `. - -In the performance tests, this artificial delay can be configured as the -independent test variable. More details are given in the [performance -test documentation](./README.md#performance-tests). Furthermore, the -effect of the artificial delay on the eduP2P network performance is -reported in the [performance test -results](./README.md#results-with-varying-one-way-delay). 
diff --git a/test_suite/compare_performance.py b/test_suite/compare_performance.py new file mode 100644 index 0000000..b8c7ec4 --- /dev/null +++ b/test_suite/compare_performance.py @@ -0,0 +1,128 @@ +import json +import os +import sys +from pathlib import Path + +# Exit codes +EXIT_PERFORMANCE_SIMILAR = 0 +EXIT_COMPARISON_FAILED = 1 +EXIT_PERFORMANCE_WORSE = 1 +EXIT_PERFORMANCE_BETTER = 0 + +# Ensure both parameters are provided +if len(sys.argv) - 1 != 2: + print(f""" +Usage: python {sys.argv[0]} + +The two parameters should be either system test logs containing only performance tests, or extracted performance-test-data artifacts from GitHub Actions + +The output of this script is formatted as markdown such that it can be used in GitHub job step summaries""") + exit(EXIT_COMPARISON_FAILED) + +baseline=sys.argv[1] +new=sys.argv[2] + +# This dictionary defines which measurements to compare, and when to consider two measurements worse/better/similar +COMPARISON_CONFIG = { + "Target bitrate": { + "packet_loss": { + "better": lambda new, baseline: new < 0.8 * baseline and new < baseline - 5, + "worse": lambda new, baseline: new > 1.2 * baseline and new > baseline + 5 + } + } +} + +# Keep track of whether performance is worse or better +performance_worse = False +performance_better = False + +def failure(reason: str): + print("# ❌ Performance comparison failed") + print(reason) + exit(EXIT_COMPARISON_FAILED) + +def check_same_performance_test(new_data: dict, baseline_data: dict, rel_path: str): + """Check whether the two data files contain the same performance test, otherwise comparison is not possible""" + + same_test_var = new_data["test_var"] == baseline_data["test_var"] + + if not same_test_var: + failure(f"mismatch in test variable or its values for {rel_path}") + + baseline_metrics = set(baseline_data["measurements"].keys()) + new_metrics = set(new_data["measurements"].keys()) + all_required_metrics = baseline_metrics <= new_metrics + + if not 
all_required_metrics: + failure(f"for {rel_path}, some of the measurements in the baseline data are not present in the new data") + +def report_performance_change(better: bool, metric: str, idx: int, test_var: str, test_var_values: list[float], baseline: float, new: float): + change = "improved" if better else "degraded" + print(f"- For {test_var} = {test_var_values[idx]}, {metric} {change}: {baseline:.1f} -> {new:.1f}") + +def compare_measurements(new_data: dict, baseline_data: dict): + test_var = baseline_data["test_var"]["label"] + test_var_values = baseline_data["test_var"]["values"] + + if not(test_var in COMPARISON_CONFIG.keys()): + return + + # Keep track of performance difference by modifying the global variables + global performance_worse, performance_better + + metrics_to_compare = COMPARISON_CONFIG[test_var].keys() + new_measurements = new_data["measurements"] + baseline_measurements = baseline_data["measurements"] + + for metric in metrics_to_compare: + metric_label = baseline_measurements[metric]["label"] + new_values = new_measurements[metric]["values"]["average"]["eduP2P"] + baseline_values = baseline_measurements[metric]["values"]["average"]["eduP2P"] + is_worse = COMPARISON_CONFIG[test_var][metric]["worse"] + is_better = COMPARISON_CONFIG[test_var][metric]["better"] + + for i, (new_val, baseline_val) in enumerate(zip(new_values, baseline_values)): + if is_worse(new_val, baseline_val): + report_performance_change(False, metric_label, i, test_var, test_var_values, baseline_val, new_val) + performance_worse = True + + if is_better(new_val, baseline_val): + report_performance_change(True, metric_label, i, test_var, test_var_values, baseline_val, new_val) + performance_better = True and not performance_worse # Worse performance has higher priority than better performance + +# Iterate over all data files from baseline performance test data +cwd = os.getcwd() +baseline_files = Path(f"{cwd}/{baseline}").rglob("performance_test_data.json*") +print("# 
Comparison details") + +for path in baseline_files: + path = str(path) + + # Get relative path by removing current working directory + baseline directory prefix + rel_path = path[len(cwd) + len(baseline) + 1:] + print(f"### Comparing {rel_path}...") + + # Attempt to open same performance test file in new data + try: + with open(f"{cwd}/{new}/{rel_path}") as f_new: + new_data = json.load(f_new) + except FileNotFoundError: + failure(f"{rel_path} is present in {baseline}, but not in {new}") + + with open(path) as f_baseline: + baseline_data = json.load(f_baseline) + + check_same_performance_test(new_data, baseline_data, rel_path) + compare_measurements(new_data, baseline_data) + +# Print final conclusion about performance +if performance_worse: + print(f"# 📉 Total performance has degraded") + exit(EXIT_PERFORMANCE_WORSE) +elif performance_better: + print(f"# 📈 Total performance has improved!") + exit(EXIT_PERFORMANCE_BETTER) + +print(f"# ✅ No significant performance change") +exit(EXIT_PERFORMANCE_SIMILAR) + diff --git a/test_suite/control_server/main.go b/test_suite/control_server/main.go index aa4e62e..64eced6 100644 --- a/test_suite/control_server/main.go +++ b/test_suite/control_server/main.go @@ -25,10 +25,8 @@ import ( ) var ( - //dev = flag.Bool("dev", false, "run in localhost development mode (overrides -a)") addr = flag.String("a", ":443", "server HTTP/HTTPS listen address, in form \":port\", \"ip:port\", or for IPv6 \"[ip]:port\". If the IP is omitted, it defaults to all interfaces. Serves HTTPS if the port is 443 and/or -certmode is manual, otherwise HTTP.") configPath = flag.String("c", "", "config file path") - //stunPort = flag.Int("stun-port", stunserver.DefaultPort, "The UDP port on which to serve STUN. 
The listener is bound to the same IP (if any) as specified in the -a flag.") programLevel = new(slog.LevelVar) // Info by default ) @@ -36,7 +34,6 @@ var ( func main() { h := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: programLevel, - //AddSource: true, }) slog.SetDefault(slog.New(h)) programLevel.Set(-8) @@ -54,21 +51,19 @@ func main() { mux.Handle("/control", controlhttp.ServerHandler(cserver.server)) - // TODO below is dup from relayserver main.go; dedup in a common library? - mux.Handle("/", handleStaticHTML(ToverSokControlDefaultHTML)) - mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { browserHeaders(w) - io.WriteString(w, "User-agent: *\nDisallow: /\n") + if _, err := io.WriteString(w, "User-agent: *\nDisallow: /\n"); err != nil { + slog.Error("could not write robots.txt", "err", err) + } })) mux.Handle("/generate_204", http.HandlerFunc(serverCaptivePortalBuster)) httpsrv := &http.Server{ Addr: *addr, Handler: mux, - // TODO - //ErrorLog: slog.NewLogLogger(), ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, @@ -76,16 +71,16 @@ func main() { go func() { <-ctx.Done() - httpsrv.Shutdown(ctx) + if err := httpsrv.Shutdown(ctx); err != nil { + slog.Error("control: failed to shutdown control server", "error", err) + } }() - // TODO setup TLS with autocert? 
- slog.Info("control: serving", "addr", *addr) err := httpsrv.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("control: error %s", err) + log.Fatalf("control: error %s", err) //nolint:gocritic } } @@ -99,7 +94,7 @@ type ControlServer struct { } func (cs *ControlServer) OnSessionCreate(id control.SessID, cid control.ClientID) { - println("OnSessionCreate") + slog.Info("OnSessionCreate", "id", id, "cid", cid) go func() { if err := cs.server.AcceptAuthentication(id); err != nil { @@ -108,25 +103,24 @@ func (cs *ControlServer) OnSessionCreate(id control.SessID, cid control.ClientID }() } -func (cs *ControlServer) OnSessionResume(id control.SessID, id2 control.ClientID) { - println("OnSessionResume") - return // noop +func (cs *ControlServer) OnSessionResume(sess control.SessID, cid control.ClientID) { + slog.Info("OnSessionResume", "sess", sess, "cid", cid) } -func (cs *ControlServer) OnDeviceKey(id control.SessID, key string) { - println("OnDeviceKey") - return // noop +func (cs *ControlServer) OnDeviceKey(sess control.SessID, deviceKey string) { + slog.Info("OnDeviceKey", "sess", sess, "deviceKey", deviceKey) } -func (cs *ControlServer) OnSessionFinalize(id control.SessID, id2 control.ClientID) (netip.Prefix, netip.Prefix) { - println("OnSessionFinalize") +func (cs *ControlServer) OnSessionFinalize(sess control.SessID, cid control.ClientID) (netip.Prefix, netip.Prefix, time.Time) { + slog.Info("OnSessionFinalize", "sess", sess, "cid", cid) + + ip4, ip6 := cs.getIPs(key.NodePublic(cid)) - return cs.getIPs(key.NodePublic(id2)) + return ip4, ip6, time.Time{} } -func (cs *ControlServer) OnSessionDestroy(id control.SessID, id2 control.ClientID) { - println("OnSessionDestroy") - return // noop +func (cs *ControlServer) OnSessionDestroy(sess control.SessID, cid control.ClientID) { + slog.Info("OnSessionDestroy", "sess", sess, "cid", cid) } func LoadServer(ctx context.Context) *ControlServer { @@ -183,7 +177,7 @@ func (cs 
*ControlServer) addNewNode(node key.NodePublic) { } } -func (cs *ControlServer) isKnown(node key.NodePublic) bool { +func (cs *ControlServer) isKnown(node key.NodePublic) bool { //nolint:unused cs.cfgMu.Lock() defer cs.cfgMu.Unlock() @@ -234,7 +228,6 @@ func findNewIP(ipp netip.Prefix, used func(netip.Addr) bool) (netip.Prefix, neti // we exceeded the boundary, try a back-sweep backwards = true } else { - // TODO find better way to deal with this panic("address space exhausted") } } @@ -287,14 +280,16 @@ func handleStaticHTML(doc string) http.HandlerFunc { } } -func sendStaticHTML(doc string, w http.ResponseWriter, r *http.Request) { +func sendStaticHTML(doc string, w http.ResponseWriter, _ *http.Request) { browserHeaders(w) w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) - io.WriteString(w, doc) + if _, err := io.WriteString(w, doc); err != nil { + slog.Error("failed to write static HTML page", "error", err) + } } const ToverSokControlDefaultHTML = ` @@ -347,6 +342,7 @@ func loadConfig() Config { return writeNewConfig() case err != nil: log.Fatal(err) + //goland:noinspection GoUnreachableCode panic("unreachable") default: var cfg Config @@ -366,14 +362,14 @@ func writeNewConfig() Config { } func writeConfig(cfg Config, path string) { - if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0o777); err != nil { log.Fatal(err) } b, err := json.MarshalIndent(cfg, "", "\t") if err != nil { log.Fatal(err) } - if err := os.WriteFile(path, b, 0600); err != nil { + if err := os.WriteFile(path, b, 0o600); err != nil { log.Fatal(err) } } @@ -382,7 +378,6 @@ func newConfig() Config { return Config{ ControlKey: key.NewControlPrivate(), - //// TODO REPLACE WITH CONFIGURABLE VALUES IP4: netip.MustParsePrefix("10.42.0.0/16"), IP6: netip.MustParsePrefix("fd42:dead:beef::/64"), diff --git a/test_suite/images/performance_tests/ext_wg_x_bitrate_y_bitrate.png 
b/test_suite/images/performance_tests/ext_wg_x_bitrate_y_bitrate.png index 0c481db..d29f3e0 100644 Binary files a/test_suite/images/performance_tests/ext_wg_x_bitrate_y_bitrate.png and b/test_suite/images/performance_tests/ext_wg_x_bitrate_y_bitrate.png differ diff --git a/test_suite/images/performance_tests/ext_wg_x_bitrate_y_packet_loss.png b/test_suite/images/performance_tests/ext_wg_x_bitrate_y_packet_loss.png index 89208be..1205a5a 100644 Binary files a/test_suite/images/performance_tests/ext_wg_x_bitrate_y_packet_loss.png and b/test_suite/images/performance_tests/ext_wg_x_bitrate_y_packet_loss.png differ diff --git a/test_suite/images/performance_tests/performance_test_variance.png b/test_suite/images/performance_tests/performance_test_variance.png new file mode 100644 index 0000000..92c582b Binary files /dev/null and b/test_suite/images/performance_tests/performance_test_variance.png differ diff --git a/test_suite/images/performance_tests/usr_wg_x_bitrate_y_bitrate.png b/test_suite/images/performance_tests/usr_wg_x_bitrate_y_bitrate.png index 433eec8..f1e72c3 100644 Binary files a/test_suite/images/performance_tests/usr_wg_x_bitrate_y_bitrate.png and b/test_suite/images/performance_tests/usr_wg_x_bitrate_y_bitrate.png differ diff --git a/test_suite/images/performance_tests/usr_wg_x_bitrate_y_packet_loss.png b/test_suite/images/performance_tests/usr_wg_x_bitrate_y_packet_loss.png index 8922c6a..64dc020 100644 Binary files a/test_suite/images/performance_tests/usr_wg_x_bitrate_y_packet_loss.png and b/test_suite/images/performance_tests/usr_wg_x_bitrate_y_packet_loss.png differ diff --git a/test_suite/performance_test.sh b/test_suite/performance_test.sh index ffc2737..12096f2 100755 --- a/test_suite/performance_test.sh +++ b/test_suite/performance_test.sh @@ -1,11 +1,11 @@ #!/usr/bin/env bash usage_str=""" -Usage: ${0} [OPTIONAL ARGUMENTS] +Usage: ${0} [OPTIONAL ARGUMENTS] [OPTIONAL ARGUMENTS]: - -b - With this flag, eduP2P's performance is compared to the 
performance of a direct connection, and a connection using only WireGuard + -b + With this flag, eduP2P's performance is compared to the performance of a direct connection and/or a connection using only WireGuard This flag should only be used when both peers reside in the 'public' network Executes performance tests between the peers using iperf3, where peer 1 acts as the server and peer 2 as the client @@ -24,10 +24,24 @@ This script must be executed with root permissions . ./util.sh # Validate optional arguments -while getopts ":bh" opt; do +while getopts ":b:h" opt; do case $opt in b) - baseline=true + baseline=$OPTARG + validate_str $baseline "^direct|wireguard|both$" + + case $baseline in + "direct") + baseline_direct=true + ;; + "wireguard") + baseline_wireguard=true + ;; + "both") + baseline_direct=true + baseline_wireguard=true + ;; + esac ;; h) echo "$usage_str" @@ -43,7 +57,7 @@ done shift $((OPTIND-1)) # Make sure all required arguments have been passed -if [[ $# -ne 7 ]]; then +if [[ $# -ne 8 ]]; then exit_with_error "expected 7 positional parameters, but received $#" fi @@ -53,13 +67,14 @@ peer1_ip=$3 performance_test_var=$4 performance_test_values=$5 performance_test_duration=$6 -log_dir=$7 +performance_test_reps=$7 +log_dir=$8 function clean_exit() { exit_code=$1 # Delete baseline WireGuard interfaces and private keys, and kill keep-alive process - if [[ $baseline == true ]]; then + if [[ $baseline_wireguard == true ]]; then for ns in $peer1 $peer2; do sudo ip netns exec $ns ip link del wg_$ns rm private_$ns @@ -68,8 +83,8 @@ function clean_exit() { kill $keep_alive_pid fi - # Undo the restrictive permissions which iperf3 sets on test_dir - chmod --recursive 777 $test_dir + # Undo the restrictive permissions which iperf3 sets on subdirectories of log_dir + chmod --recursive 777 $log_dir exit $exit_code } @@ -171,8 +186,49 @@ function performance_test() { wait $server_pid # Measure delay and store it in the iperf3 log file - delay=$(measure_delay 
$server_ip) - store_delay $delay $log_path + if [[ $performance_test_var != "bitrate" ]]; then + delay=$(measure_delay $server_ip) + store_delay $delay $log_path + fi +} + +# Function to do performance tests for all performance test values +function performance_tests() { + performance_test_value_array=$1 + performance_test_dir=$2 + performance_test_rep=$3 + + # String describing the current repetition, empty if only one repetition is performed + if [[ $performance_test_reps -gt 1 ]]; then + rep_description="Repetition $performance_test_rep/$performance_test_reps: " + fi + + # Variables to display a progress bar + n_values=${#performance_test_value_array[@]} + progress=0 + + # Iterate over performance test values + for performance_test_val in ${performance_test_value_array[@]}; do + bar=$(progress_bar $progress $n_values) + echo -ne "\033[2K\t$rep_description$bar Performance testing with $performance_test_var = $performance_test_val\r" # \033[2K = Ctrl+K, clears rest of line from cursor; \r returns to beginning of line + + # Run performance test for eduP2P + performance_test $performance_test_val $performance_test_dir "eduP2P" $peer1_ip + + # If -b is set, the performance test is repeated over a direct/WireGuard connection instead of over the eduP2P connection + if [[ $baseline_direct == true ]]; then + performance_test $performance_test_val $performance_test_dir "Direct" $peer1_pub_ip + fi + + if [[ $baseline_wireguard == true ]]; then + performance_test $performance_test_val $performance_test_dir "WireGuard" 10.0.0.1 + fi + + let "progress++" + done + + bar=$(progress_bar $n_values $n_values) + echo -e "\t$rep_description$bar Performance testing with $performance_test_var finished" } # Set up WireGuard connection between the peers (for performance test baseline) @@ -207,19 +263,8 @@ function wg_setup() { ip netns exec $peer2 wg set wg_$peer2 peer $pub1 allowed-ips 10.0.0.1/32 endpoint 192.168.1.254:$port1 } -# Directory to store performance test results 
-performance_test_dir=$log_dir/performance_tests_$performance_test_var - -# Replace commas by spaces to convert string to array -performance_test_values=$(echo $performance_test_values | tr ',' ' ') -performance_test_values=($performance_test_values) - -# Variables to display a progress bar -n_values=${#performance_test_values[@]} -progress=0 - # For the baseline comparison, we need the peers' public IPs, which are also needed to setup a WireGuard connection between them -if [[ $baseline == true ]]; then +if [[ $baseline_direct == true || $baseline_wireguard == true ]]; then peer1_pub_ip=$(ip netns exec $peer1 ip address | grep -Eo "inet 192.168.[0-9.]+" | cut -d ' ' -f2) peer2_pub_ip=$(ip netns exec $peer2 ip address | grep -Eo "inet 192.168.[0-9.]+" | cut -d ' ' -f2) @@ -234,24 +279,16 @@ if [[ $baseline == true ]]; then keep_alive_pid=$! fi -# Iterate over performance test values -for performance_test_val in ${performance_test_values[@]}; do - bar=$(progress_bar $progress $n_values) - echo -ne "\033[2K\t$bar Performance testing with $performance_test_var = $performance_test_val\r" # \033[2K = Ctrl+K, clears rest of line from cursor; \r returns to beginning of line - - # Run performance test for eduP2P - performance_test $performance_test_val $performance_test_dir "eduP2P" $peer1_ip +# Replace commas by spaces to convert string to array +performance_test_value_array=$(echo $performance_test_values | tr ',' ' ') +performance_test_value_array=($performance_test_value_array) - # If -b is set, the performance test is repeated over a direct/WireGuard connection instead of over the eduP2P connection - if [[ $baseline == true ]]; then - performance_test $performance_test_val $performance_test_dir "Direct" $peer1_pub_ip - performance_test $performance_test_val $performance_test_dir "WireGuard" 10.0.0.1 - fi +for ((i=1;i<=$performance_test_reps;i++)); do + # Directory to store performance test results for this repetition + 
performance_test_dir=$log_dir/performance_tests_$performance_test_var/repetition$i - let "progress++" + performance_tests $performance_test_value_array $performance_test_dir $i done -bar=$(progress_bar $n_values $n_values) -echo -e "\t$bar Performance testing with $performance_test_var finished" clean_exit 0 diff --git a/test_suite/relay_server/main.go b/test_suite/relay_server/main.go index 3338c9a..dbe376c 100644 --- a/test_suite/relay_server/main.go +++ b/test_suite/relay_server/main.go @@ -6,10 +6,6 @@ import ( "errors" "flag" "fmt" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/relay" - "github.com/edup2p/common/types/relay/relayhttp" - stunserver "github.com/edup2p/common/types/stun" "io" "log" "log/slog" @@ -22,6 +18,11 @@ import ( "strings" "syscall" "time" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/relay" + "github.com/edup2p/common/types/relay/relayhttp" + stunserver "github.com/edup2p/common/types/stun" ) var ( @@ -45,9 +46,6 @@ const ToverSokRelayDefaultHTML = ` func main() { flag.Parse() - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer cancel() - if *dev { *addr = "127.0.0.1:3340" log.Printf("Running in dev mode.") @@ -71,8 +69,16 @@ func main() { log.Fatalf("could not parse stun-combined addrport: %v", err) } + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + stunServer := stunserver.NewServer(ctx) - go stunServer.ListenAndServe(ap) + go func() { + if err := stunServer.ListenAndServe(ap); err != nil { + // This is okay, because running a STUN server is basically also the entire point of the relay server + panic(err) + } + }() // TODO add STUN here @@ -88,19 +94,23 @@ func main() { mux.Handle("/relay", relayhttp.ServerHandler(server)) - mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, _ 
*http.Request) { browserHeaders(w) w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) - io.WriteString(w, ToverSokRelayDefaultHTML) + if _, err := io.WriteString(w, ToverSokRelayDefaultHTML); err != nil { + slog.Error("Failed to write default HTML response", "err", err) + } })) - mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { browserHeaders(w) - io.WriteString(w, "User-agent: *\nDisallow: /\n") + if _, err := io.WriteString(w, "User-agent: *\nDisallow: /\n"); err != nil { + slog.Error("Failed to write robots.txt response", "err", err) + } })) mux.Handle("/generate_204", http.HandlerFunc(serverCaptivePortalBuster)) @@ -108,7 +118,7 @@ func main() { Addr: *addr, Handler: mux, // TODO - //ErrorLog: slog.NewLogLogger(), + // ErrorLog: slog.NewLogLogger(), ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, @@ -116,7 +126,9 @@ func main() { go func() { <-ctx.Done() - httpsrv.Shutdown(ctx) + if err := httpsrv.Shutdown(ctx); err != nil { + slog.Error("Failed to shutdown server", "err", err) + } }() // TODO setup TLS with autocert @@ -125,7 +137,7 @@ func main() { err = httpsrv.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("relay: error %s", err) + log.Fatalf("relay: error %s", err) //nolint:gocritic } } @@ -183,6 +195,7 @@ func loadConfig() Config { return writeNewConfig() case err != nil: log.Fatal(err) + //goland:noinspection GoUnreachableCode panic("unreachable") default: var cfg Config @@ -194,7 +207,7 @@ func loadConfig() Config { } func writeNewConfig() Config { - if err := os.MkdirAll(filepath.Dir(*configPath), 0777); err != nil { + if err := os.MkdirAll(filepath.Dir(*configPath), 0o777); err != nil { log.Fatal(err) } cfg := newConfig() @@ -202,7 +215,7 @@ func writeNewConfig() Config { if err != nil { log.Fatal(err) } - 
if err := os.WriteFile(*configPath, b, 0600); err != nil { + if err := os.WriteFile(*configPath, b, 0o600); err != nil { log.Fatal(err) } return cfg diff --git a/test_suite/system_test.sh b/test_suite/system_test.sh index 6e473e2..43d0a82 100755 --- a/test_suite/system_test.sh +++ b/test_suite/system_test.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# Amount of seconds to wait for one system test to finish +SYSTEM_TEST_TIMEOUT=60 + usage_str=""" Usage: ${0} [OPTIONAL ARGUMENTS] [NAT CONFIGURATION 1]:[NAT CONFIGURATION 2] [WIREGUARD INTERFACE 1]:[WIREGUARD INTERFACE 2] @@ -12,8 +15,9 @@ Usage: ${0} [OPTIONAL ARGUMENTS] [NAT CO -k -v -d - -b - With this flag, eduP2P's performance is compared to the performance of a direct connection, and a connection using only WireGuard + -r + -b + With this flag, eduP2P's performance is compared to the performance of a direct connection and/or a connection using only WireGuard This flag should only be used when both peers reside in the 'public' network specifies the peer and router namespaces to be used in this system test. It should be a string with one of the following formats: @@ -33,15 +37,16 @@ If [WIREGUARD INTERFACE 1] or [WIREGUARD INTERFACE 2] is not provided, the corre is a string of IP addresses separated by a space that may be the destination IP of packets crossing this NAT device, and is necessary to simulate an Address-Dependent Mapping - should be one of {trace|debug|info} (in order of most to least log messages), but can NOT be info if one if the peers is using userspace WireGuard (then IP of the other peer is not logged)""" + should be one of {trace|debug|info|warn|error}, and MUST be trace/debug if one of the peers uses userspace WireGuard (the other peer's IP address is not logged otherwise)""" # Use functions and constants from util.sh . 
./util.sh -performance_test_duration=0 # Default value in case -d is not used +performance_test_duration=5 # Default value in case -d is not used +performance_test_reps=1 # Default value in case -r is not used # Validate optional arguments -while getopts ":k:v:d:bh" opt; do +while getopts ":k:v:d:r:b:h" opt; do case $opt in k) performance_test_var=$OPTARG @@ -61,8 +66,20 @@ while getopts ":k:v:d:bh" opt; do performance_test_duration=$OPTARG validate_str $performance_test_duration "^[0-9]+$" ;; + r) + performance_test_reps=$OPTARG + validate_str $performance_test_duration "^[0-9]+$" + + if [[ $performance_test_reps -eq 0 ]]; then + exit_with_error "value of -r should be at least 1" + fi + ;; + b) - baseline="-b" + performance_test_baseline=$OPTARG + validate_str $performance_test_baseline "^direct|wireguard|both$" + + baseline="-b $performance_test_baseline" ;; h) echo "$usage_str" @@ -145,6 +162,11 @@ wg_interface_regex="^([^:]*):([^:]*)$" validate_str $wg_interface_str $wg_interface_regex wg_interfaces=(${BASH_REMATCH[1]} ${BASH_REMATCH[2]}) +# Remove conntrack entries from potential previous tests +for router_ns in ${router_ns_list[@]}; do + sudo ip netns exec $router_ns conntrack -D &> /dev/null +done + # Prepare a string describing the NAT setup NAT_TYPES=("EI" "AD" "APD") @@ -217,6 +239,9 @@ function clean_exit() { # Kill background processes, such as the setup_client.sh scripts sudo kill $(jobs -p) &> /dev/null + # Remove restrictive permissions on certain log files + sudo chmod --recursive 777 $log_dir + exit $exit_code } @@ -249,7 +274,7 @@ for i in {0..1}; do touch $peer_logfile # Make sure file already exists so tail command later in script does not fail sudo ip netns exec $peer_ns ./setup_client.sh `# Run script in peer's network namespace` \ - $peer_id $peer_ns $test_target $control_pub_key $control_ip $control_port $log_lvl ${wg_interfaces[$i]} `# Positional parameters` \ + $peer_id $peer_ns $test_target $control_pub_key $control_ip $control_port 
$log_lvl $log_dir ${wg_interfaces[$i]} `# Positional parameters` \ 2>&1 | tee $peer_logfile &> /dev/null & # Combination of tee and redirect to /dev/null is necessary to avoid weird behaviour caused by redirecting a script run with sudo done @@ -257,7 +282,7 @@ done for i in {0..1}; do peer_id="peer$((i+1))" export LOG_FILE=${log_dir}/$peer_id.txt # Export to use in bash -c - timeout 30s bash -c 'tail -n +1 -f $LOG_FILE | sed -n "/TS_PASS/q2; /TS_FAIL/q3"' # bash -c is necessary to use timeout with | and still get the right exit codes + timeout ${SYSTEM_TEST_TIMEOUT}s bash -c 'tail -f -n +1 -s0.1 $LOG_FILE | sed -n "/TS_PASS/q2; /TS_FAIL/q3"' # bash -c is necessary to use timeout with | and still get the right exit codes # Branch on exit code of previous command case $? in @@ -300,7 +325,7 @@ if [[ -n $performance_test_var ]]; then peer1_ns=${peer_ns_list[0]} peer1_ip=$(extract_ipv4 $peer1_ns $peer1_interface) - sudo ./performance_test.sh $baseline $peer1_ns ${peer_ns_list[1]} $peer1_ip $performance_test_var $performance_test_values $performance_test_duration $log_dir + sudo ./performance_test.sh ${baseline} $peer1_ns ${peer_ns_list[1]} $peer1_ip $performance_test_var $performance_test_values $performance_test_duration $performance_test_reps $log_dir if [[ $? 
-ne 0 ]]; then clean_exit 1 diff --git a/test_suite/system_tests.sh b/test_suite/system_tests.sh index 3900f84..719fc18 100755 --- a/test_suite/system_tests.sh +++ b/test_suite/system_tests.sh @@ -22,9 +22,17 @@ The following options can be used to configure additional parameters during the -d Add delay to packets transmitted by the eduP2P clients, control server and relay server The delay should be provided as an integer that represents the one-way delay in milliseconds - -l + -l Specifies the log level used in the eduP2P client of the two peers - The log level 'info' should not be used if a system test is run where one of the peers uses userspace WireGuard (the other peer's IP address is not logged in this case)""" + If one of the peers uses userspace WireGuard, the log level trace/debug must be used, since the other peer's IP address is not logged otherwise + -L + Specifies the alphanumeric name of the directory inside system_test_logs/ where the test logs will be stored + If this argument is not provided, the directory name is the current timestamp + -t + Run the system tests in parallel with the specified number of threads. + It is not recommended to combine this flag with -p, as multithreading will likely degrade the performance and the graphs will not be created automatically + -b + Build the client, control server and relay server binaries before running the tests""" # Use functions and constants from util.sh . 
./util.sh @@ -33,7 +41,7 @@ The following options can be used to configure additional parameters during the log_lvl="debug" # Validate optional arguments -while getopts ":c:d:ef:l:ph" opt; do +while getopts ":c:d:ef:l:L:t:bph" opt; do case $opt in c) connectivity=true @@ -71,10 +79,24 @@ while getopts ":c:d:ef:l:ph" opt; do l) log_lvl=$OPTARG - # Log level should be info/debug/trace - log_lvl_regex="^info|debug|trace?$" + log_lvl_regex="^trace|debug|info|warn|error?$" validate_str $log_lvl $log_lvl_regex ;; + L) + alphanum_regex="^[a-zA-Z0-9]+$" + validate_str $OPTARG $alphanum_regex + log_dir_rel=system_test_logs/$OPTARG + ;; + t) + n_threads=$OPTARG + + # Make sure n_threads is an integer between 2 and 8 + threads_regex="^[2-8]$" + validate_str $n_threads $int_regex + ;; + b) + build=true + ;; p) performance=true ;; @@ -88,12 +110,23 @@ while getopts ":c:d:ef:l:ph" opt; do esac done -# Shift positional parameters indexing by accounting for the optional arguments -shift $((OPTIND-1)) - # Store repository's root directory for later use repo_dir=$(cd ..; pwd) +function create_log_dir() { + if [[ -z $log_dir_rel ]]; then + timestamp=$(date +"%Y-%m-%dT%H_%M_%S") + log_dir_rel=system_test_logs/${timestamp} # Relative path for pretty printing + fi + + log_dir=${repo_dir}/test_suite/${log_dir_rel} # Absolute path for use in scripts running from different directories + mkdir -p ${log_dir} + echo "Logging to ${log_dir_rel}" +} + +create_log_dir + +# ================================ FUNCTIONS FOR SEQUENTIAL SYSTEM TESTS ================================ function cleanup () { # Kill the two servers if they have already been started by the script sudo pkill control_server @@ -103,9 +136,6 @@ function cleanup () { sudo kill $test_pid &> /dev/null } -# Run cleanup when script exits -trap cleanup EXIT - function build_go() { for binary in test_client control_server relay_server; do binary_dir="${repo_dir}/test_suite/$binary" @@ -113,25 +143,11 @@ function build_go() { done } 
-build_go - -function create_log_dir() { - timestamp=$(date +"%Y-%m-%dT%H_%M_%S") - log_dir_rel=system_test_logs/${timestamp} # Relative path for pretty printing - log_dir=${repo_dir}/test_suite/${log_dir_rel} # Absolute path for use in scripts running from different directories - mkdir -p ${log_dir} - echo "Logging to ${log_dir_rel}" -} - -create_log_dir - function setup_networks() { cd nat_simulation/ adm_ips=$(sudo ./setup_networks.sh) # setup_networks.sh returns an array of IPs used by hosts in the network simulation setup, this list is needed to simulate a NAT device with an Address-Dependent Mapping } -setup_networks - function extract_server_pub_key() { server_type=$1 # control_server or relay_server ip=$2 @@ -189,27 +205,133 @@ function setup_servers() { start_server "relay_server" $relay_ip $relay_port } -# Choose ports for the control and relay servers, then start them -control_port=9999 -relay_port=3340 -echo "Setting up servers" -setup_servers +function sequential_setup() { + # Run cleanup when script exits + trap cleanup EXIT + + # Go build binaries unless -b flag was specified + if [[ $build == true ]]; then + echo "Building binaries..." + build_go + else + echo "Skipped building binaries" + fi + + setup_networks + + # Choose ports for the control and relay servers, then start them + control_port=9999 + relay_port=3340 + echo "Setting up servers" + setup_servers + + cd $repo_dir/test_suite + + if [[ -n $packet_loss ]]; then + sudo ./set_packet_loss.sh $packet_loss + fi + + if [[ -n $delay ]]; then + sudo ./set_delay.sh $delay + fi + + # Test counters + n_tests=0 + n_failed=0 +} + +# Log messages that should not be printed when the tests are run in parallel +function log_sequential() { + msg=$1 + + if [[ -z $n_threads ]]; then + echo -e $msg + fi +} + +# ================================ FUNCTIONS FOR PARALLEL SYSTEM TESTS ================================ +function parallel_setup() { + echo """ +Dividing the system tests among $n_threads threads. 
The output of each thread can be found in the logs.""" + + # The current system tests command will be run in parallel docker containers with a few modifications: + system_test_opts=$(echo $@ | sed -r -e "s/-f \S+//" `# Potential -f flag is removed, as each docker container will be assigned a file containing a subset of the current system tests` \ + -e "s/-t [2-8]//") # -t flag is removed, since each docker container will run the tests in parallel` + + # Tests will be assigned to the containers in a round-robin manner, so we keep track of the current thread + current_thread=0 + + # Arrays of length n_threads representing number of assigned and completed tests per thread, initialized to 0 + assigned=() + completed=() + + for i in $(seq 1 $n_threads); do + # Initialize above arrays to 0 + assigned+=(0) + completed+=(0) + + # Create a directory and file for each thread to store the system test logs and commands + mkdir $log_dir/thread$i + touch $log_dir/thread$i/tests.txt + done +} + +function log_parallel() { + thread=$1 + msg=$2 + + let "move_cursor = n_threads - thread" -# Test counters -n_tests=0 -n_failed=0 + # Move cursor up to the line for the current thread + echo -ne "\033[${move_cursor}A\tThread $((thread+1)): $msg\r" + + # Move cursor back to original position + echo -ne "\033[${move_cursor}B\r" + +} + +function monitor_thread_progress() { + current_thread=0 + + while : # while True + do + for i in $(seq 0 $((n_threads-1))); do + if [[ ${completed[$i]} -ne ${assigned[$i]} ]]; then + completed[$i]=$(docker logs ${container_ids[$i]} | grep -Ec "result=\S+TS_(PASS|FAIL)") # \S+ matches special characters that give color to test result + log_parallel $i $(progress_bar ${completed[$i]} ${assigned[$i]}) + fi + done + + sleep 1s + done +} + +# ================================ SYSTEM TESTS LOGIC ================================ +# Check if -t flag was specified +if [[ -n $n_threads ]]; then + parallel_setup +else + sequential_setup +fi # Usage: run_system_test 
[optional arguments of system_test.sh] function run_system_test() { - let "n_tests++" - - # Run in background and wait for test to finish to allow for interrupting from the terminal - ./system_test.sh $@ $n_tests $control_pub_key $control_ip $control_port "$adm_ips" $log_lvl $log_dir $repo_dir & - test_pid=$! - wait $test_pid - - if [[ $? -ne 0 ]]; then - let "n_failed++" + if [[ -n $n_threads ]]; then # Save the system test to the file corresponding to the current thread + current_thread_dir="thread$((current_thread+1))" + echo "run_system_test $@" >> $log_dir/$current_thread_dir/tests.txt + let "assigned[$current_thread]++" + let "current_thread = (current_thread+1) % n_threads" + else # Run the system test now + let "n_tests++" + + # Run in background and wait for test to finish to allow for interrupting from the terminal + ./system_test.sh $@ $n_tests $control_pub_key $control_ip $control_port "$adm_ips" $log_lvl $log_dir $repo_dir & + test_pid=$! + wait $test_pid + + if [[ $? -ne 0 ]]; then + let "n_failed++" + fi fi } @@ -250,36 +372,26 @@ function connectivity_test_logic() { fi } -cd $repo_dir/test_suite - -if [[ -n $packet_loss ]]; then - sudo ./set_packet_loss.sh $packet_loss -fi - -if [[ -n $delay ]]; then - sudo ./set_delay.sh $delay -fi - if [[ $performance == true ]]; then - echo -e "\nPerformance tests (without NAT)" - run_system_test -k bitrate -v 25,50,75,100 -d 3 -b TS_PASS_DIRECT router1-router2 : : - run_system_test -k packet_loss -v 0,1.5,3,4.5 -d 3 -b TS_PASS_DIRECT router1-router2 : : + log_sequential "\nPerformance tests (without NAT)" + run_system_test -k bitrate -v 100,200,300,400,500 -d 3 -b both TS_PASS_DIRECT router1-router2 : wg0:wg0 elif [[ -n $file ]]; then echo -e "\nTests from file: $file" - while read test_cmd; do + # Read line by line from $file (also last line which may not end with a newline, but still contain a command) + while IFS= read -r test_cmd || [ -n "$test_cmd" ]; do eval $test_cmd done < $file else rfc_3489_nats=("0-0" 
"0-1" "0-2" "2-2") - echo """ + log_sequential """ Starting connectivity tests between two peers (possibly) behind NATs with various combinations of mapping and filtering behaviour: - Endpoint-Independent Mapping/Filtering (EIM/EIF) - Address-Dependent Mapping/Filtering (ADM/ADF) - Address and Port-Dependent Mapping/Filtering (ADPM/ADPF)""" - echo -e "\nTests with one peer behind a NAT" + log_sequential "\nTests with one peer behind a NAT" for nat_mapping in {0..2}; do for nat_filter in {0..2}; do nat=$nat_mapping-$nat_filter @@ -291,7 +403,7 @@ Starting connectivity tests between two peers (possibly) behind NATs with variou done done - echo -e "\nTests with both peers behind a NAT" + log_sequential "\nTests with both peers behind a NAT" for nat1_mapping in {0..2}; do for nat1_filter in {0..2}; do for nat2_mapping in {0..2}; do @@ -302,7 +414,7 @@ Starting connectivity tests between two peers (possibly) behind NATs with variou done done - echo -e "\nTest hairpinning" + log_sequential "\nTest hairpinning" for nat_mapping in {0..2}; do for nat_filter in {0..2}; do nat=$nat_mapping-$nat_filter @@ -316,6 +428,9 @@ Starting connectivity tests between two peers (possibly) behind NATs with variou fi function print_summary() { + n_failed=$1 + n_tests=$2 + if [[ $n_failed -eq 0 ]]; then echo -e "${GREEN}All tests passed!${NC}" else @@ -324,7 +439,71 @@ function print_summary() { fi } -print_summary +if [[ -n $n_threads ]]; then + # Keep track of docker container IDs + container_ids=() + + docker_log_dir="/go/common/test_suite/system_test_logs" + + for i in $(seq 1 $n_threads); do + thread="thread$i" + + # Run the thread in a docker container, and store its ID + container_id=$(docker run \ + --network=host `# Host driver gives faster curl connectivity check` \ + --cap-add CAP_SYS_ADMIN --cap-add NET_ADMIN --security-opt apparmor=unconfined --device /dev/net/tun:/dev/net/tun `# Permissions required to create network setup` \ + --mount 
type=bind,src=$log_dir/$thread,dst=$docker_log_dir/$thread `# Bind logs inside docker container to the corresponding thread on the host` \ + -dt system_tests -f $docker_log_dir/$thread/tests.txt -L $thread `# Run tests from this thread's file and store the logs in the mounted directory` \ + $system_test_opts) # Copy the remaining options from the current system tests command + container_ids+=($container_id) + + # Print progress bar for this thread, unless test is run as GitHub Action + if [[ -z $GITHUB_ACTION ]]; then + echo -e "\tThread $i: $(progress_bar 0 ${assigned[$((i-1))]})\r" + fi + done + + # Report on progress of each thread + if [[ -z $GITHUB_ACTION ]]; then + monitor_thread_progress & + progress_pid=$! + fi + + exit_codes=( $(docker wait ${container_ids[@]}) ) # Each exit code represents the amount of failed tests in the corresponding container + + if [[ -z $GITHUB_ACTION ]]; then + kill $progress_pid # Stop monitoring progress after all containers have finished + fi + + # Print summary for each thread individually + for i in $(seq 0 $((n_threads-1))); do + test_summary=$(print_summary ${exit_codes[$i]} ${assigned[$i]}) + + if [[ -z $GITHUB_ACTION ]]; then + progress_bar=$(progress_bar ${assigned[$i]} ${assigned[$i]}) + log_parallel $i "$progress_bar - $test_summary" + else + echo -e "\tThread $((i+1)): $test_summary" + fi + done + + # Replace space delimiters by + and pipe into calculator + n_failed=$(echo ${exit_codes[@]} | tr " " + | bc) + n_tests=$(echo ${assigned[@]} | tr " " + | bc) + + # Log the containers' outputs + for i in $(seq 1 $n_threads); do + thread="thread$i" + id=${container_ids[$((i-1))]} + docker logs $id > $log_dir/$thread/cmd_output.txt + done + + # Containers are only used one time, now that they have finished running they can be removed + docker rm ${container_ids[@]} > /dev/null +else + # Create graphs for performance tests, if any were included + python3 visualize_performance_tests.py $log_dir +fi -# Create graphs for 
performance tests, if any were included -python3 visualize_performance_tests.py $log_dir \ No newline at end of file +print_summary $n_failed $n_tests +exit $n_failed \ No newline at end of file diff --git a/test_suite/test_client/main.go b/test_suite/test_client/main.go index e33f8c1..9464724 100644 --- a/test_suite/test_client/main.go +++ b/test_suite/test_client/main.go @@ -6,12 +6,6 @@ import ( "errors" "flag" "fmt" - "github.com/edup2p/common/ext_wg" - "github.com/edup2p/common/toversok" - "github.com/edup2p/common/types/dial" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/usrwg" - "golang.zx2c4.com/wireguard/wgctrl" "log/slog" "net/netip" "os" @@ -19,6 +13,14 @@ import ( "path/filepath" "strings" "syscall" + + "github.com/edup2p/common/extwg" + "github.com/edup2p/common/toversok" + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/dial" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/usrwg" + "golang.zx2c4.com/wireguard/wgctrl" ) // Flags @@ -72,9 +74,11 @@ func main() { flag.Parse() - var level = slog.LevelInfo + level := slog.LevelInfo switch logLevel { + case "trace": + level = types.LevelTrace case "debug": level = slog.LevelDebug case "info": @@ -94,17 +98,17 @@ func main() { if extPort < 0 || extPort > 65535 { slog.Error("external port out of range 0-65535, aborting", "ext-port", extPort) os.Exit(1) - } else { - engineExtPort = uint16(extPort) } + engineExtPort = uint16(extPort) + if controlPort < 0 || controlPort > 65535 { slog.Error("control port out of range 0-65535, aborting", "control-port", controlPort) os.Exit(1) - } else { - controlPort16 = uint16(controlPort) } + controlPort16 = uint16(controlPort) + var err error if parsedControlKey, err = parseControlKey(controlKeyStr); err != nil { @@ -123,7 +127,7 @@ func main() { } if controlHost != "" || controlPort != 0 || controlKeyStr != "" { - var mustWrite = false + mustWrite := false if config.ControlPort != controlPort16 { slog.Warn("config 
control port and given control port disagree, overwriting config", "config", config.ControlPort, "cli-given", controlPort16) @@ -179,7 +183,9 @@ func main() { os.Exit(1) } - if err = engine.Start(); err != nil { + var runningCtx context.Context + + if runningCtx, err = engine.Start(); err != nil { slog.Error("could not start engine", "err", err) os.Exit(1) } @@ -200,10 +206,11 @@ func main() { ccc(errors.New("interrupted")) }() - <-engine.Context().Done() + <-runningCtx.Done() + ccc(errors.New("stopping")) if !interrupted { - slog.Warn("engine exited with error", "err", engine.Context().Err()) + slog.Warn("stopped with error", "err", runningCtx.Err()) os.Exit(1) } } @@ -212,13 +219,12 @@ func parseControlKey(str string) (*key.ControlPublic, error) { if controlKeyStr == "" { return nil, nil } - - if p, err := key.UnmarshalControlPublic(str); err != nil { - + p, err := key.UnmarshalControlPublic(str) + if err != nil { return nil, fmt.Errorf("could not parse control key: %w", err) - } else { - return p, nil } + + return p, nil } func normalisePath(file string) (string, error) { @@ -249,7 +255,6 @@ func getOrGenerateConfig(file string) (*Config, error) { var c *Config data, err := os.ReadFile(file) - if err != nil { if os.IsNotExist(err) { slog.Info("config file does not exist, generating new config...", "file", file) @@ -293,7 +298,7 @@ func writeConfig(c *Config, file string) error { return fmt.Errorf("failed to marshal config: %w", err) } - if err := os.WriteFile(file, jsonData, 0644); err != nil { + if err := os.WriteFile(file, jsonData, 0o644); err != nil { return fmt.Errorf("failed to write config to file: %w", err) } @@ -302,17 +307,18 @@ func writeConfig(c *Config, file string) error { func getWireguardHost() (toversok.WireGuardHost, error) { if extWgDevice != "" { - if wg, err := getWgControl(extWgDevice); err != nil { + wg, err := getWgControl(extWgDevice) + if err != nil { return nil, fmt.Errorf("could not initialise external wireguard device: %w", err) - } 
else { - return wg, nil } - } else { - return usrwg.NewUsrWGHost(), nil + + return wg, nil } + + return usrwg.NewUsrWGHost(), nil } -func getWgControl(device string) (*ext_wg.WGCtrl, error) { +func getWgControl(device string) (*extwg.WGCtrl, error) { client, err := wgctrl.New() if err != nil { return nil, fmt.Errorf("could not initialise wgctrl: %w", err) @@ -322,7 +328,7 @@ func getWgControl(device string) (*ext_wg.WGCtrl, error) { return nil, fmt.Errorf("could not find/initialise wgctrl device: %w", err) } - return ext_wg.NewWGCtrl(client, device), nil + return extwg.NewWGCtrl(client, device), nil } // A dummy firewall diff --git a/test_suite/test_client/setup_client.sh b/test_suite/test_client/setup_client.sh index c2f320a..9238e23 100755 --- a/test_suite/test_client/setup_client.sh +++ b/test_suite/test_client/setup_client.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash usage_str=""" -Usage: ${0} [WIREGUARD INTERFACE] +Usage: ${0} [WIREGUARD INTERFACE] should be one of {trace|debug|info} (in order of most to least log messages), but can NOT be info if one if the peers is using userspace WireGuard (then IP of the other peer is not logged) @@ -27,8 +27,8 @@ done shift $((OPTIND-1)) # Make sure all required positional parameters have been passed -min_req=7 -max_req=8 +min_req=8 +max_req=9 if [[ $# < $min_req || $# > $max_req ]]; then exit_with_error "expected $min_req or $max_req positional parameters, but received $#" @@ -41,7 +41,8 @@ control_pub_key=$4 control_ip=$5 control_port=$6 log_lvl=$7 -wg_interface=$8 +log_dir=$8 +wg_interface=$9 # Create WireGuard interface if wg_interface is set if [[ -n $wg_interface ]]; then @@ -52,6 +53,7 @@ fi # Create temporary file to store test_client output out="test_client_out_${id}.txt" +touch $out # Run test_client and store its output in the temporary file (sudo ./test_client --control-host=$control_ip --control-port=$control_port --control-key=control:$control_pub_key --ext-wg-device=$wg_interface --log-level=$log_lvl 
--config=$id.json 2>&1 | tee $out &) @@ -62,10 +64,6 @@ function clean_exit() { # Remove temporary test_client output file sudo rm $out - # Remove http server output files if they exist - rm $http_ipv4_out &> /dev/null - rm $http_ipv6_out &> /dev/null - # Kill http servers if they are running kill $http_ipv4_pid &> /dev/null kill $http_ipv6_pid &> /dev/null @@ -79,7 +77,7 @@ trap "clean_exit 1" SIGTERM # Get own virtual IPs and peer's virtual IPs with external WireGuard if [[ -n $wg_interface ]]; then # Store virtual IPs as " " when they are logged - ips=$(timeout 10s tail -n +1 -f $out | sed -rn "/.*sudo ip address add (\S+) dev ${wg_interface}; sudo ip address add (\S+) dev ${wg_interface}.*/{s//\1 \2/p;q}") + ips=$(timeout 10s tail -n +1 -f -s 0.1 $out | sed -rn "/.*sudo ip address add (\S+) dev ${wg_interface}; sudo ip address add (\S+) dev ${wg_interface}.*/{s//\1 \2/p;q}") if [[ -z $ips ]]; then echo "TS_FAIL: could not find own virtual IPs in logs" @@ -103,7 +101,7 @@ if [[ -n $wg_interface ]]; then peer_ips=$(wg show $wg_interface allowed-ips | cut -d$'\t' -f2) # IPs are shown as "\t " while [[ -z $peer_ips ]]; do - sleep 1s + sleep 0.1s let "timeout--" if [[ $timeout -eq 0 ]]; then @@ -123,7 +121,7 @@ else timeout=10 while ! 
ip address show ts0 | grep -Eq "inet [0-9.]+"; do - sleep 1s + sleep 0.1s let "timeout--" if [[ $timeout -eq 0 ]]; then @@ -136,8 +134,8 @@ else ipv4=$(extract_ipv4 $peer_ns ts0) ipv6=$(extract_ipv6 $peer_ns ts0) - # Store peer IPs as " "" when they are logged - peer_ips=$(timeout 10s tail -n +1 -f $out | sed -rn "/.*IPv4:(\S+) IPv6:(\S+).*/{s//\1 \2/p;q}") + # Store peer IPs as " " when they are logged + peer_ips=$(timeout 10s tail -f -n +1 -s 0.1 $out | sed -rn "/.*IPv4:(\S+) IPv6:(\S+).*/{s//\1 \2/p;q}") if [[ -z $peer_ips ]]; then echo "TS_FAIL: could not find peer's virtual IPs in logs" @@ -149,15 +147,23 @@ else peer_ipv6=$(echo $peer_ips | cut -d ' ' -f2) fi +# Necessary to avoid failures with hairpinning tests, probably caused by delay in adding nftables rules to simulate hairpinning +sleep 0.5s + # Start HTTP servers on own virtual IPs for peer to access, and save their pids to kill them during cleanup -http_ipv4_out="http_ipv4_output_${id}.txt" +http_ipv4_out="$log_dir/${id}_http_ipv4.txt" python3 -m http.server -b $ipv4 80 &> $http_ipv4_out & http_ipv4_pid=$! -http_ipv6_out="http_ipv6_output_${id}.txt" +http_ipv6_out="$log_dir/${id}_http_ipv6.txt" python3 -m http.server -b $ipv6 80 &> $http_ipv6_out & http_ipv6_pid=$! 
+# Desynchronize peers to avoid error and subsequent recovery delay caused by handshake initation in both directions at same time +if [[ $id == "peer1" ]]; then + sleep 0.1s +fi + # Try connecting to peer's HTTP server hosted on IP addres function try_connect() { peer_addr=$1 @@ -172,14 +178,13 @@ try_connect "http://${peer_ipv4}" # Peers try to establish a direct connection after initial connection; if expecting a direct connection, give them some time to establish one if [[ $test_target == "TS_PASS_DIRECT" ]]; then - timeout 10s tail -f -n +1 $out | sed -n "/ESTABLISHED direct peer connection/q" + timeout 10s tail -f -n +1 -s 0.1 $out | sed -n "/ESTABLISHED direct peer connection/q" fi try_connect "http://[${peer_ipv6}]" -# Wait until timeout or until peer connected to server (peer's IP will appear in server output) -timeout 10s tail -f -n +1 $http_ipv4_out | sed -n "/${peer_ipv4}/q" -timeout 10s tail -f -n +1 $http_ipv6_out | sed -n "/${peer_ipv6}/q" +# Wait until timeout or until peer connected to second server (peer's IP will appear in server output) +timeout 10s tail -f -n +1 -s 0.1 $http_ipv6_out | sed -n "/${peer_ipv6}/q" echo "TS_PASS" clean_exit 0 \ No newline at end of file diff --git a/test_suite/visualize_performance_tests.py b/test_suite/visualize_performance_tests.py index fe35a3a..e6b0c5f 100644 --- a/test_suite/visualize_performance_tests.py +++ b/test_suite/visualize_performance_tests.py @@ -4,6 +4,7 @@ import re import sys import matplotlib.pyplot as plt +from itertools import groupby if len(sys.argv) != 2: print(f""" @@ -47,17 +48,41 @@ def test_iteration(): m = p.match(test_dir) test_var = m.group(1) - test_var_values, extracted_data = connection_iteration(test_path, test_var) + # Extract test variable values and corresponding measurements + test_var_values, extracted_data = repetition_iteration(test_path, test_var) + extracted_data = aggregate_repetitions(extracted_data) + + # Create dictionary containing test variable info + test_var_dict = 
TEST_VARS[test_var] + test_var_dict["values"] = test_var_values + + with open(f"{parent_path}/performance_test_data.json", 'w') as file: + # Delete transform key from bitrate metric, since it is not JSON serializable + del extracted_data["bitrate"]["transform"] + + # Delete json_key key from all metrics, since they are no longer needed + for k in extracted_data.keys(): + del extracted_data[k]["json_key"] + + # Merge test variable info and measurements into one dictionary + data_dict = { + "test_var": test_var_dict, + "measurements": extracted_data + } + + json.dump(data_dict, file, indent=4) for metric in extracted_data.keys(): - create_graph(test_var, test_var_values, metric, extracted_data, parent_path) + create_performance_graph(test_var, test_var_values, metric, extracted_data, parent_path) + + create_variance_grid(data_dict, parent_path) if n_tests > 0: plural = "s" if n_tests > 1 else "" print(f"Generated graphs to visualize {n_tests} performance test{plural}") -# Recursively iterate over all connection subdirectories (eduP2P/WireGuard/Direct) -def connection_iteration(test_path : str, test_var : str) -> dict: +# Recursively iterate over all repetition subdirectories +def repetition_iteration(test_path: str, test_var: str) -> tuple[list[float], dict]: extracted_data = { "bitrate" : { "label" : "Measured bitrate", @@ -86,7 +111,28 @@ def connection_iteration(test_path : str, test_var : str) -> dict: }, } - paths = Path(test_path).glob("*") + # Delay is not affected by the iperf3 target bitrate, so this data has not been measured + if(test_var == "bitrate"): + del extracted_data["delay"] + + paths = Path(test_path).rglob("repetition*") + + # Iterate over repetitions sorted from lowest to highest number (default sorting order is inconsistent) + for path in sorted(paths, key=lambda p: str(p)): + repetition_path = str(path) + repetition_id = repetition_path.split('/')[-1] + + # Initialize the dictionary of measurements for this repetition + for metric in 
extracted_data.keys(): + extracted_data[metric]["values"][repetition_id] = {} + + test_var_values, extracted_data = connection_iteration(repetition_path, repetition_id, test_var, extracted_data) + + return test_var_values, extracted_data + +# Recursively iterate over all connection subdirectories (eduP2P/WireGuard/Direct) +def connection_iteration(repetition_path: str, repetition_id: str, test_var: str, extracted_data: dict) -> tuple[list[float], dict]: + paths = Path(repetition_path).glob("*") for path in paths: connection_path = str(path) @@ -94,14 +140,14 @@ def connection_iteration(test_path : str, test_var : str) -> dict: # Initialize the lists of measurements for this connection type for metric in extracted_data.keys(): - extracted_data[metric]["values"][connection_type] = [] + extracted_data[metric]["values"][repetition_id][connection_type] = [] - test_var_values, extracted_data = file_iteration(connection_type, connection_path, test_var, extracted_data) + test_var_values, extracted_data = file_iteration(connection_type, connection_path, repetition_id, test_var, extracted_data) return test_var_values, extracted_data # Recursively iterate over all json files in the connection subdirectories (each file corresponds to one test variable value) -def file_iteration(connection_type : str, connection_path : str, test_var : str, extracted_data : dict) -> tuple[list[float], dict]: +def file_iteration(connection_type: str, connection_path: str, repetition_id: str, test_var: str, extracted_data: dict) -> tuple[list[float], dict]: test_var_values = [] paths = Path(connection_path).glob(f"{test_var}=*") @@ -118,18 +164,19 @@ def file_iteration(connection_type : str, connection_path : str, test_var : str, with open(path_str, 'r') as file: data = json.load(file) - extracted_data = extract_data(connection_type, data, extracted_data) + extracted_data = extract_data(connection_type, repetition_id, data, extracted_data) # Sort data sorted_indices=np.argsort(test_var_values) - 
test_var_values = np.array(test_var_values)[sorted_indices] + test_var_values = list(np.array(test_var_values)[sorted_indices]) for metric in extracted_data.keys(): - extracted_data[metric]["values"][connection_type] = np.array(extracted_data[metric]["values"][connection_type])[sorted_indices] + sorted_measurements = np.array(extracted_data[metric]["values"][repetition_id][connection_type])[sorted_indices] + extracted_data[metric]["values"][repetition_id][connection_type] = list(sorted_measurements) return test_var_values, extracted_data -def extract_data(connection_type : str, data : dict, extracted_data : dict) -> dict: +def extract_data(connection_type: str, repetition_id: str, data: dict, extracted_data: dict) -> dict: data = data["end"]["sum"] for metric in extracted_data.keys(): @@ -143,44 +190,116 @@ def extract_data(connection_type : str, data : dict, extracted_data : dict) -> d except KeyError: pass - extracted_data[metric]["values"][connection_type].append(measurement) + extracted_data[metric]["values"][repetition_id][connection_type].append(measurement) return extracted_data -def create_graph(test_var : str, test_var_values : list[float], metric : str, extracted_data : dict, save_path : str): - metric_data = extracted_data[metric] - connection_measurements = metric_data["values"] +def aggregate_repetitions(extracted_data: dict) -> dict: + for metric in extracted_data.keys(): + measurements = extracted_data[metric]["values"] + connection_dicts = [v for _, v in measurements.items()] + connection_measurement_pairs = [(k, v) for c in connection_dicts for k, v in list(c.items())] + aggregated_measurements = {} + + # Group the measurements by connection type to compute the average + for connection_type, groups in groupby(sorted(connection_measurement_pairs, reverse=True), key=lambda t: t[0]): + aggregated_measurements[connection_type] = list(np.array([group[1] for group in groups]).mean(axis=0)) + + measurements["average"] = aggregated_measurements + + return 
extracted_data +# Given a dictionary containing the label and unit of a metric, returns a string to describe the metric on a graph axis +def axis_label(label_unit_dict: dict) -> str: + return f"{label_unit_dict['label']} ({label_unit_dict['unit']})" + +# Graph to illustrate the performance of eduP2P, possibly by comparing against WireGuard and/or a direct connection +def create_performance_graph(test_var: str, test_var_values: list[float], metric: str, extracted_data: dict, save_path: str): + metric_data = extracted_data[metric] + connection_measurements = metric_data["values"]["average"] test_var_label = TEST_VARS[test_var]["label"] - test_var_unit = TEST_VARS[test_var]["unit"] metric_label = metric_data["label"] - metric_unit = metric_data["unit"] # Different line styles in case they overlap line_styles=["-", "--", ":"] line_widths=[4,3,2] for i, connection in enumerate(connection_measurements.keys()): + x = test_var_values y = connection_measurements[connection] ls=line_styles[i] lw=line_widths[i] + plt.plot(x, y, linestyle=ls, linewidth=lw, label=connection) - # Plot the measured independent variable values on the X axis instead of the target values, unless the measured values or the delay are plotted on the Y axis (delay is not affected by the iperf3 measured values) - if metric == test_var or metric == "delay": - plt.plot(test_var_values, y, linestyle=ls, linewidth=lw, label=connection) - x_label = test_var_label - else: - measured_test_var_values = sorted(extracted_data[test_var]["values"][connection]) - plt.plot(measured_test_var_values, y, linestyle=ls, linewidth=lw, label=connection) - x_label = extracted_data[test_var]["label"] - - plt.xlabel(f"{x_label} ({test_var_unit})") - plt.ylabel(f"{metric_label} ({metric_unit})") - plt.title(f"{metric_label} for varying {x_label}") + plt.xlabel(axis_label(TEST_VARS[test_var])) + plt.ylabel(axis_label(metric_data)) + plt.title(f"{metric_label} for varying {test_var_label}") plt.ticklabel_format(useOffset=False) 
plt.legend() - + plt.tight_layout() plt.savefig(f"{save_path}/performance_test_{metric}.png") plt.clf() +# Create an * grid of plots showing the variance in measurements across repetitions for each metric and connection type +def create_variance_grid(data_dict: dict, save_path: str): + test_var_info = data_dict["test_var"] + test_var_values = test_var_info["values"] + + measurements = data_dict["measurements"] + metrics = list(measurements.keys()) + n_metrics = len(metrics) + reps_and_avg = measurements[metrics[0]]["values"] + + # This indicates that reps_and_avg = ["repetition1", "average"], so only 1 repetition is performed + if len(reps_and_avg) == 2: + return + + connections = list(reps_and_avg["average"].keys()) + n_connections = len(connections) + fig, ax = plt.subplots(n_metrics, n_connections) + + # Iterate over rows in the grid + for i, metric in enumerate(metrics): + metric_dict = measurements[metric] + create_variance_col(ax[i], i, test_var_values, metric_dict["values"]) + + # Y label is the same for each row, so we only set it on the first column to save space + ax[i][0].set_ylabel(axis_label(metric_dict)) + + # X label is the same for each subplot, so we only set it on the bottom row to save space + for j in range(n_connections): + ax[n_metrics-1][j].set_xlabel(axis_label(test_var_info)) + + # Display legend shared between the subplots and save the figure + handles, labels = ax[0][0].get_legend_handles_labels() + fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.03), ncol=len(reps_and_avg)//2) + + # Place suptitle higher to free up space for legend + fig.suptitle("Variance of measurements over multiple repetitions", y=1.06) + + subplot_size = 4 + fig.set_figheight(n_metrics * subplot_size) + fig.set_figwidth(n_connections * subplot_size) + fig.tight_layout() + fig.savefig(f"{save_path}/performance_test_variance.png", bbox_inches="tight") # bbox_inches prevents suptitle and legend from being cropped + +# Fill one column of the 
variance grid with graphs +def create_variance_col(ax: np.ndarray[plt.Axes], i: int, test_var_values: list[float], measurements: dict): + for k, repetition_dict in measurements.items(): + for j, conn in enumerate(repetition_dict.keys()): + conn_measurements = repetition_dict[conn] + + # Make line representing the average stand out from lines representing individual repetitions + if k == "average": + ax[j].plot(test_var_values, conn_measurements, label=k, linestyle="-", linewidth=3, color="black") + else: + ax[j].plot(test_var_values, conn_measurements, label=k, linestyle="--", linewidth=1.5) + + # Each connection type takes up a separate column, so put the connection type above the top subplots + if(i == 0): + ax[j].set_title(conn) + + # On the Y axis, use scientific notation for numbers outside the range [1e-3, 1e4] to prevent them from crossing into other subplots + ax[j].ticklabel_format(axis='y', style='sci', scilimits=(-3,4)) + test_iteration() \ No newline at end of file diff --git a/toversok/actors/a_conn.go b/toversok/actors/a_conn.go index d4a7c92..e92cb2e 100644 --- a/toversok/actors/a_conn.go +++ b/toversok/actors/a_conn.go @@ -3,12 +3,15 @@ package actors import ( "context" "errors" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/msgactor" + "log/slog" "net" "net/netip" + "runtime/debug" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgactor" ) type OutConn struct { @@ -37,10 +40,10 @@ func MakeOutConn(udp types.UDPConn, peer key.NodePublic, homeRelay int64, s *Sta common := MakeCommon(s.Ctx, OutConnInboxChanBuffer) - return &OutConn{ + return assureClose(&OutConn{ ActorCommon: common, - sock: MakeSockRecv(udp, common.ctx), + sock: MakeSockRecv(common.ctx, udp), s: s, peer: peer, @@ -49,32 +52,31 @@ func MakeOutConn(udp types.UDPConn, peer key.NodePublic, homeRelay int64, s *Sta activityTimer: t, isActive: false, - } + }) } 
func (oc *OutConn) Run() { + if !oc.running.CheckOrMark() { + L(oc).Warn("tried to run agent, while already running") + return + } + + defer oc.Cancel() defer func() { if v := recover(); v != nil { - L(oc).Error("panicked", "panic", v) - oc.Cancel() + L(oc).Error("panicked", "panic", v, "stack", string(debug.Stack())) bail(oc.ctx, v) } }() - if !oc.running.CheckOrMark() { - L(oc).Warn("tried to run agent, while already running") - return - } - go oc.sock.Run() for { select { case <-oc.ctx.Done(): - oc.Close() return case <-oc.sock.ctx.Done(): - oc.Cancel() + return case <-oc.activityTimer.C: oc.UnBump() case msg := <-oc.inbox: @@ -96,8 +98,7 @@ func (oc *OutConn) Run() { // sock closed, the peer is dead // TODO: // trigger some kind of healing logic elsewhere? - oc.Cancel() - continue + return } if oc.useRelay { @@ -199,7 +200,7 @@ func MakeInConn(udp types.UDPConn, peer key.NodePublic, s *Stage) *InConn { t := time.NewTimer(60 * time.Second) t.Stop() - return &InConn{ + return assureClose(&InConn{ ActorCommon: MakeCommon(s.Ctx, -1), s: s, @@ -211,28 +212,26 @@ func MakeInConn(udp types.UDPConn, peer key.NodePublic, s *Stage) *InConn { pktCh: make(chan []byte, InConnFrameChanBuffer), peer: peer, - } + }) } func (ic *InConn) Run() { + if !ic.running.CheckOrMark() { + L(ic).Warn("tried to run agent, while already running") + return + } + + defer ic.Cancel() defer func() { if v := recover(); v != nil { - L(ic).Error("panicked", "panic", v) - ic.Cancel() - ic.Close() + L(ic).Error("panicked", "panic", v, "stack", string(debug.Stack())) bail(ic.ctx, v) } }() - if !ic.running.CheckOrMark() { - L(ic).Warn("tried to run agent, while already running") - return - } - for { select { case <-ic.ctx.Done(): - ic.Close() return case <-ic.activityTimer.C: ic.UnBump() @@ -240,7 +239,6 @@ func (ic *InConn) Run() { n, err := ic.udp.Write(frame) if err != nil { if errors.Is(err, net.ErrClosed) { - ic.Cancel() return } // TODO failsafe logic @@ -263,6 +261,9 @@ func (ic *InConn) 
Close() { Peer: ic.peer, IsIn: false, } + if err := ic.udp.Close(); err != nil { + slog.Error("failed to close inconn udp", "peer", ic.peer, "err", err) + } } func (ic *InConn) Ctx() context.Context { diff --git a/toversok/actors/a_conn_test.go b/toversok/actors/a_conn_test.go index 6e9dd0d..25aad77 100644 --- a/toversok/actors/a_conn_test.go +++ b/toversok/actors/a_conn_test.go @@ -27,10 +27,10 @@ func TestOutConn(t *testing.T) { wgConn := &MockUDPConn{ writeCh: make(chan []byte), - setReadDeadline: func(t time.Time) error { + setReadDeadline: func(time.Time) error { return nil }, - readFromUDPAddrPort: func(b []byte) (n int, addr netip.AddrPort, err error) { + readFromUDPAddrPort: func([]byte) (n int, addr netip.AddrPort, err error) { return 0, dummyAddrPort, nil }, } diff --git a/toversok/actors/a_direct.go b/toversok/actors/a_direct.go index 04ab40e..d7b02ad 100644 --- a/toversok/actors/a_direct.go +++ b/toversok/actors/a_direct.go @@ -2,15 +2,17 @@ package actors import ( "context" + "log/slog" + "net/netip" + "runtime" + "runtime/debug" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/ifaces" "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgactor" "github.com/edup2p/common/types/msgsess" "golang.org/x/exp/maps" - "log/slog" - "net/netip" - "runtime" ) type directWriteRequest struct { @@ -29,28 +31,28 @@ type DirectManager struct { func (s *Stage) makeDM(udpSocket types.UDPConn) *DirectManager { c := MakeCommon(s.Ctx, -1) - return &DirectManager{ + return assureClose(&DirectManager{ ActorCommon: c, - sock: MakeSockRecv(udpSocket, c.ctx), + sock: MakeSockRecv(c.ctx, udpSocket), s: s, writeCh: make(chan directWriteRequest, DirectManWriteChLen), - } + }) } func (dm *DirectManager) Run() { + if !dm.running.CheckOrMark() { + L(dm).Warn("tried to run agent, while already running") + return + } + + defer dm.Cancel() defer func() { if v := recover(); v != nil { - L(dm).Error("panicked", "panic", v) - dm.Cancel() + 
L(dm).Error("panicked", "panic", v, "stack", string(debug.Stack())) bail(dm.ctx, v) } }() - if !dm.running.CheckOrMark() { - L(dm).Warn("tried to run agent, while already running") - return - } - go dm.sock.Run() runtime.LockOSThread() @@ -58,15 +60,7 @@ func (dm *DirectManager) Run() { for { select { case <-dm.ctx.Done(): - dm.Close() return - //case msg := <-dm.inbox: - // switch m := msg.(type) { - // case *DManSetMTU: - // dm.SetMTUFor(m.forAddrPort, m.mtu) - // default: - // dm.logUnknownMessage(m) - // } case req := <-dm.writeCh: L(dm).Log(context.Background(), types.LevelTrace, "direct: writing") _, err := dm.sock.Conn.WriteToUDPAddrPort(req.pkt, req.to) @@ -103,25 +97,7 @@ func (dm *DirectManager) WriteTo(pkt []byte, addr netip.AddrPort) { } } -//// MTUFor gets the MTU for a netip.AddrPort pair, or default. -//func (dm *DirectManager) MTUFor(ap netip.AddrPort) uint16 { -// // TODO(jo): there is a small possibility that internal representation in -// // netip.AddrPort can differ, even though they'd be the same IP+Port pair. -// // I haven't found such a case, but it'S nagging in the back of my mind, -// // which is why this is a separate function, -// // so we can do any canonisation later. -// mtu, ok := dm.mtuFor[ap] -// if !ok { -// return DefaultSafeMTU -// } else { -// return mtu -// } -//} -// -//// SetMTUFor sets the MTU for a netip.AddrPort pair. -//func (dm *DirectManager) SetMTUFor(ap netip.AddrPort, mtu uint16) { -// dm.mtuFor[ap] = mtu -//} +// TODO: track when we last received a packet from AddrPair? 
type DirectRouter struct { *ActorCommon @@ -138,40 +114,40 @@ type DirectRouter struct { } func (s *Stage) makeDR() *DirectRouter { - return &DirectRouter{ + return assureClose(&DirectRouter{ ActorCommon: MakeCommon(s.Ctx, DirectRouterInboxChLen), s: s, aka: make(map[netip.AddrPort]key.NodePublic), stunEndpoints: make(map[netip.AddrPort]bool), frameCh: make(chan ifaces.DirectedPeerFrame, DirectRouterFrameChLen), - } + }) } func (dr *DirectRouter) Push(frame ifaces.DirectedPeerFrame) { - //go func() { + // go func() { dr.frameCh <- frame - //}() + // }() } func (dr *DirectRouter) Run() { - defer func() { - if v := recover(); v != nil { - // TODO logging - dr.Cancel() - } - }() - if !dr.running.CheckOrMark() { L(dr).Warn("tried to run agent, while already running") return } + defer dr.Cancel() + defer func() { + if v := recover(); v != nil { + L(dr).Error("panicked", "panic", v, "stack", string(debug.Stack())) + bail(dr.ctx, v) + } + }() + runtime.LockOSThread() for { select { case <-dr.ctx.Done(): - dr.Close() return case m := <-dr.inbox: switch m := m.(type) { @@ -230,7 +206,7 @@ func (dr *DirectRouter) peerAKA(ap netip.AddrPort) (peer key.NodePublic, ok bool peer, ok = dr.aka[nap] - //slog.Debug("dr: peerAKA", "ap", ap.String(), "nap", nap, "ok", ok) + slog.Log(context.Background(), types.LevelTrace, "dr: peerAKA", "ap", ap.String(), "nap", nap, "ok", ok) return } diff --git a/toversok/actors/a_direct_test.go b/toversok/actors/a_direct_test.go index c87c728..5fdf3b5 100644 --- a/toversok/actors/a_direct_test.go +++ b/toversok/actors/a_direct_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "slices" "testing" "time" @@ -22,13 +23,13 @@ func TestDirectManager(t *testing.T) { mockUDPConn := &MockUDPConn{ writeCh: make(chan []byte), - setReadDeadline: func(t time.Time) error { + setReadDeadline: func(time.Time) error { return nil }, - readFromUDPAddrPort: func(b []byte) (n int, addr netip.AddrPort, err error) { + readFromUDPAddrPort: func([]byte) (n int, 
addr netip.AddrPort, err error) { return 0, dummyAddrPort, nil }, - writeToUDPAddrPort: func(b []byte, addr netip.AddrPort) (int, error) { + writeToUDPAddrPort: func([]byte, netip.AddrPort) (int, error) { return 0, nil }, } @@ -111,7 +112,7 @@ func TestDirectRouter(t *testing.T) { assert.Equal(t, msgEM, &msgactor.EManSTUNResponse{Endpoint: frameEndpoint.SrcAddrPort, Packet: frameEndpoint.Pkt}, "EndpointManager did not receive the expected message") // Message that should be sent to SessionManager - sessionPkt := append(msgsess.MagicBytes, zeroBytes(56)...) + sessionPkt := slices.Concat(msgsess.MagicBytes, zeroBytes(56)) frameSession := ifaces.DirectedPeerFrame{ SrcAddrPort: dummyAddrPort, @@ -125,11 +126,11 @@ func TestDirectRouter(t *testing.T) { // For each peer: register peer in DirectRouter and Stage, and then send a message to their inConn for i, b := range []byte{1, 2} { ic := ics[i] - key := ic.peer + peer := ic.peer endpoint := peerEndpoints[i] - dr.setAKA(endpoint, key) - s.inConn[key] = ic + dr.setAKA(endpoint, peer) + s.inConn[peer] = ic frame := ifaces.DirectedPeerFrame{ SrcAddrPort: endpoint, diff --git a/toversok/actors/a_eman.go b/toversok/actors/a_eman.go index 4b5f934..a6d0963 100644 --- a/toversok/actors/a_eman.go +++ b/toversok/actors/a_eman.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "runtime/debug" "slices" "time" @@ -16,14 +17,14 @@ import ( ) func (s *Stage) makeEM() *EndpointManager { - em := &EndpointManager{ + em := assureClose(&EndpointManager{ ActorCommon: MakeCommon(s.Ctx, SessManInboxChLen), s: s, ticker: time.NewTicker(EManTickerInterval), stunTimeout: time.NewTimer(EManStunTimeout), relays: make(map[int64]relay.Information), - } + }) em.stunTimeout.Stop() @@ -61,36 +62,21 @@ type stunResponse struct { latency time.Duration } -// TODO -// - receive relay info -// - do STUN requests to each and resolve remote endpoints -// - maybe determine when symmetric nat / "varies" is happening -// - do latency determination -// - inform 
relay manager that results are ready -// - relay manager switches home relay and informs stage of that decision -// - collect local endpoints - // TODO future: // - UPnP? Other stuff? func (em *EndpointManager) Run() { + defer em.Cancel() defer func() { if v := recover(); v != nil { - L(em).Error("panicked", "panic", v) - em.Cancel() + L(em).Error("panicked", "panic", v, "stack", string(debug.Stack())) bail(em.ctx, v) } }() - if !em.running.CheckOrMark() { - L(em).Warn("tried to run agent, while already running") - return - } - for { select { case <-em.ctx.Done(): - em.Close() return case <-em.ticker.C: em.startSTUN() @@ -131,7 +117,7 @@ func (em *EndpointManager) startSTUN() { em.collectedResponse = make([]stunResponse, 0) - var stunReq = &msgactor.DRouterPushSTUN{Packets: make(map[netip.AddrPort][]byte)} + stunReq := &msgactor.DRouterPushSTUN{Packets: make(map[netip.AddrPort][]byte)} em.stunRequests = make(map[netip.AddrPort]stunRequest) @@ -266,96 +252,6 @@ func (em *EndpointManager) endpointToRelay(ap netip.AddrPort) *int64 { return nil } -//func (em *EndpointManager) updateEndpoints() { -// ep, err := em.doSTUN(EManStunTimeout) -// if err != nil { -// if ep != nil && len(ep) > 1 { -// L(em).Warn("STUN completed with error", "endpoints", ep, "err", err) -// } else { -// L(em).Warn("STUN failed with error", "err", err) -// } -// } -// if ep != nil && len(ep) > 1 { -// em.s.setSTUNEndpoints(ep) -// L(em).Info("STUN completed", "endpoints", ep) -// } else { -// L(em).Warn("STUN completed with no endpoints") -// } -//} - -//// Performs STUN on all known servers, returns all (deduplicated) results, and any error (if there is one). 
-//func (em *EndpointManager) doSTUN(timeout time.Duration) (responses []netip.AddrPort, err error) { -// var c *net.UDPConn -// -// c, err = net.ListenUDP("udp", nil) -// if err != nil { -// return nil, fmt.Errorf("failed to open UDP socket: %w", err) -// } -// -// requests := make(map[netip.AddrPort]stun.TxID) -// -// for _, ep := range em.collectSTUNEndpoints() { -// txID := stun.NewTxID() -// req := stun.Request(txID) -// -// _, err = c.WriteToUDP(req, net.UDPAddrFromAddrPort(ep)) -// if err != nil { -// return nil, fmt.Errorf("failed to write to %s: %w", ep, err) -// } -// -// requests[ep] = txID -// } -// -// if err := c.SetReadDeadline(time.Now().Add(timeout)); err != nil { -// return nil, fmt.Errorf("failed to set read deadline: %w", err) -// } -// -// var responseMap = make(map[netip.AddrPort]bool) -// -// for { -// if len(requests) == 0 { -// break -// } -// -// var buf [1024]byte -// var n int -// var raddr netip.AddrPort -// -// n, raddr, err = c.ReadFromUDPAddrPort(buf[:]) -// if err != nil { -// break -// } -// -// if raddr.Addr().Is4In6() { -// raddr = netip.AddrPortFrom(netip.AddrFrom4(raddr.Addr().As4()), raddr.Port()) -// } -// -// if _, ok := requests[raddr]; !ok { -// L(em).Warn("got response from unexpected raddr while doing STUN", "raddr", raddr) -// continue -// } -// -// tid, saddr, err := stun.ParseResponse(buf[:n]) -// if err != nil { -// L(em).Warn("got error when parsing STUN response from raddr", "raddr", raddr, "err", err) -// continue -// } -// if tid != requests[raddr] { -// L(em).Warn("received different TXID from raddr than expected", "raddr", raddr, "txid.expected", requests[raddr], "txid.got", tid) -// continue -// } -// -// responseMap[saddr] = true -// delete(requests, raddr) -// } -// -// for ep := range responseMap { -// responses = append(responses, ep) -// } -// -// return responses, err -//} - // Collects STUN endpoints from known relay definitions and Control itself func (em *EndpointManager) collectRelaySTUNEndpoints() 
map[netip.AddrPort]int64 { relayEndpoints := make(map[netip.AddrPort]int64) @@ -370,8 +266,6 @@ func (em *EndpointManager) collectRelaySTUNEndpoints() map[netip.AddrPort]int64 } func (em *EndpointManager) getLocalEndpoints() { - // TODO disregard own address, obviously - localEndpoints := em.collectLocalEndpoints() L(em).Debug("local endpoints collected", "endpoints", localEndpoints) @@ -392,6 +286,12 @@ func (em *EndpointManager) collectLocalEndpoints() []netip.Addr { // handle err for _, i := range ifaces { + + if i.Flags&net.FlagUp == 0 || i.Flags&net.FlagPointToPoint != 0 { + // Skip interfaces that are down, or are also PPP (such as tailscale) + continue + } + addrs, err := i.Addrs() if err != nil { L(em).Warn("collectLocalEndpoints: could not get addresses from interface", "error", err, "iface", i.Name) @@ -424,6 +324,6 @@ func (em *EndpointManager) collectLocalEndpoints() []netip.Addr { } func (em *EndpointManager) Close() { - //TODO implement me - panic("implement me") + em.ticker.Stop() + em.stunTimeout.Stop() } diff --git a/toversok/actors/a_eman_test.go b/toversok/actors/a_eman_test.go index 957c7b3..f511f9f 100644 --- a/toversok/actors/a_eman_test.go +++ b/toversok/actors/a_eman_test.go @@ -20,8 +20,9 @@ type MockControl struct { controlKey func() key.ControlPublic - ipv4 func() netip.Prefix - ipv6 func() netip.Prefix + ipv4 func() netip.Prefix + ipv6 func() netip.Prefix + expiry func() time.Time updateEndpoints func([]netip.AddrPort) error updateHomeRelay func(int64) error @@ -39,6 +40,10 @@ func (m *MockControl) IPv6() netip.Prefix { return m.ipv6() } +func (m *MockControl) Expiry() time.Time { + return m.expiry() +} + func (m *MockControl) UpdateEndpoints(endpoints []netip.AddrPort) error { m.endpoints = endpoints return m.updateEndpoints(endpoints) @@ -56,7 +61,7 @@ func TestEndpointManager(t *testing.T) { mockControl := &MockControl{ endpoints: make([]netip.AddrPort, 0), - updateEndpoints: func(eps []netip.AddrPort) error { return nil }, + 
updateEndpoints: func([]netip.AddrPort) error { return nil }, } s.control = mockControl diff --git a/toversok/actors/a_mman.go b/toversok/actors/a_mman.go new file mode 100644 index 0000000..d0b74f6 --- /dev/null +++ b/toversok/actors/a_mman.go @@ -0,0 +1,521 @@ +package actors + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "net" + "net/netip" + "runtime" + "runtime/debug" + "slices" + "syscall" + "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/msgactor" + "github.com/sethvargo/go-limiter" + "github.com/sethvargo/go-limiter/memorystore" + "golang.org/x/net/dns/dnsmessage" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +type MDNSManager struct { + *ActorCommon + s *Stage + + rlStore limiter.Store + + b4Sock *SockRecv + b6Sock *SockRecv + + u4Sock *SockRecv + u6Sock *SockRecv + + working bool +} + +func (s *Stage) makeMM() *MDNSManager { + c := MakeCommon(s.Ctx, MdnsManInboxChLen) + + store, err := memorystore.New(&memorystore.Config{ + // Number of tokens allowed per interval. + Tokens: 1, + + // Interval until tokens reset. 
+ Interval: 20 * time.Second, + + SweepInterval: 1 * time.Minute, + SweepMinTTL: 1 * time.Minute, + }) + if err != nil { + // memorystore does not return an error, so this is unexpected + panic(err) + } + + m := assureClose(&MDNSManager{ + ActorCommon: c, + s: s, + rlStore: store, + }) + + b4bind, err := m.makeMDNSv4Listener() + if err != nil { + L(m).Warn("MDNS ipv4 listener creation failed", "err", err) + } else { + m.b4Sock = MakeSockRecv(c.ctx, b4bind) + } + + b6bind, err := m.makeMDNSv6Listener() + if err != nil { + L(m).Warn("MDNS ipv6 listener creation failed", "err", err) + } else { + m.b6Sock = MakeSockRecv(c.ctx, b6bind) + } + + if m.b4Sock == nil && m.b6Sock == nil { + L(m).Error("could not start MDNS Manager; creating both MDNS broadcast sockets failed") + + return m + } + + u4bind, err := m.makeIPv4UnicastListener() + if err != nil { + L(m).Warn("MDNS ipv4 sender creation failed", "err", err) + } else { + m.u4Sock = MakeSockRecv(c.ctx, u4bind) + } + + u6bind, err := m.makeIPv6UnicastListener() + if err != nil { + L(m).Warn("MDNS ipv4 sender creation failed", "err", err) + } else { + m.u6Sock = MakeSockRecv(c.ctx, u6bind) + } + + if m.u4Sock == nil && m.u6Sock == nil { + L(m).Error("could not start MDNS Manager; creating both MDNS unicast sockets failed") + + return m + } + + m.working = true + + return m +} + +var ( + MDNSPort uint16 = 5353 + ip4MDNSBroadcastBare = netip.MustParseAddr("224.0.0.251") + ip6MDNSBroadcastBare = netip.MustParseAddr("ff02::fb") + + ip4MDNSBroadcastAP = netip.AddrPortFrom(ip4MDNSBroadcastBare, MDNSPort) + ip6MDNSBroadcastAP = netip.AddrPortFrom(ip6MDNSBroadcastBare, MDNSPort) + + ip4MDNSLoopBackAP = netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), MDNSPort) + ip6MDNSLoopBackAP = netip.AddrPortFrom(netip.IPv6Loopback(), MDNSPort) +) + +func getLoopBackInterface() (*net.Interface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("could not list network interfaces: %w", err) + } + + for 
_, iface := range ifaces { + if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback != 0 { + return &iface, nil + } + } + + return nil, fmt.Errorf("no loopback interface found") +} + +func (mm *MDNSManager) makeMDNSv4Listener() (types.UDPConn, error) { + ua := net.UDPAddrFromAddrPort(ip4MDNSBroadcastAP) + + conn, err := net.ListenUDP("udp4", ua) + if err != nil { + return nil, fmt.Errorf("ListenUDP error: %w", err) + } + + pc4 := ipv4.NewPacketConn(conn) + + ift, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("cannot get interfaces: %w", err) + } + for _, ifi := range ift { + if ifi.Flags&net.FlagUp != 0 && ifi.Flags&net.FlagPointToPoint == 0 { + if err := pc4.JoinGroup(&ifi, &net.UDPAddr{IP: ip4MDNSBroadcastBare.AsSlice()}); err != nil { + L(mm).Warn("pc4 Multicast JoinGroup failed", "err", err, "iface", ifi.Name) + } + } + } + + if loop, err := pc4.MulticastLoopback(); err == nil { + if !loop { + if err := pc4.SetMulticastLoopback(true); err != nil { + return nil, fmt.Errorf("cannot set multicast loopback: %w", err) + } + } + } + + lo, err := getLoopBackInterface() + if err != nil { + return nil, fmt.Errorf("cannot get loopback interface: %w", err) + } + + if err := pc4.SetMulticastInterface(lo); err != nil { + return nil, fmt.Errorf("cannot set multicast interface: %w", err) + } + + if err := pc4.SetTTL(255); err != nil { + return nil, fmt.Errorf("cannot set TTL: %w", err) + } + if err := pc4.SetMulticastTTL(255); err != nil { + return nil, fmt.Errorf("cannot set Multicast TTL: %w", err) + } + + return conn, nil +} + +func (mm *MDNSManager) makeMDNSv6Listener() (types.UDPConn, error) { + ua := net.UDPAddrFromAddrPort(ip6MDNSBroadcastAP) + + conn, err := net.ListenUDP("udp6", ua) + if err != nil { + return nil, fmt.Errorf("ListenUDP error: %w", err) + } + + pc6 := ipv6.NewPacketConn(conn) + + ift, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("cannot get interfaces: %w", err) + } + for _, ifi := range ift { + if 
ifi.Flags&net.FlagUp != 0 && ifi.Flags&net.FlagPointToPoint == 0 { + if err := pc6.JoinGroup(&ifi, &net.UDPAddr{IP: ip6MDNSBroadcastBare.AsSlice()}); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + L(mm).Warn("pc6 Multicast JoinGroup failed", "err", err, "iface", ifi.Name) + } + } + } + + if loop, err := pc6.MulticastLoopback(); err == nil { + if !loop { + if err := pc6.SetMulticastLoopback(true); err != nil { + return nil, fmt.Errorf("cannot set multicast loopback: %w", err) + } + } + } + + lo, err := getLoopBackInterface() + if err != nil { + return nil, fmt.Errorf("cannot get loopback interface: %w", err) + } + + if err := pc6.SetMulticastInterface(lo); err != nil { + return nil, fmt.Errorf("cannot set multicast interface: %w", err) + } + + return conn, nil +} + +func (mm *MDNSManager) makeIPv4UnicastListener() (types.UDPConn, error) { + var laddr *net.UDPAddr + addr := ip4MDNSLoopBackAP + + if runtime.GOOS == "windows" { + laddr = net.UDPAddrFromAddrPort( + netip.AddrPortFrom(mm.s.control.IPv4().Addr(), 0), + ) + addr = ip4MDNSBroadcastAP + } + + return net.DialUDP("udp4", laddr, net.UDPAddrFromAddrPort(addr)) +} + +func (mm *MDNSManager) makeIPv6UnicastListener() (types.UDPConn, error) { + var laddr *net.UDPAddr + addr := ip6MDNSLoopBackAP + + if runtime.GOOS == "windows" { + laddr = net.UDPAddrFromAddrPort( + netip.AddrPortFrom(mm.s.control.IPv6().Addr(), 0), + ) + addr = ip6MDNSBroadcastAP + } + + return net.DialUDP("udp6", laddr, net.UDPAddrFromAddrPort(addr)) +} + +func dataToB64Hash(b []byte) string { + h := sha256.Sum256(b) + + return base64.StdEncoding.EncodeToString(h[:]) +} + +func (mm *MDNSManager) Run() { + if !mm.running.CheckOrMark() { + L(mm).Warn("tried to run agent, while already running") + return + } + + defer mm.Cancel() + defer func() { + if v := recover(); v != nil { + L(mm).Error("panicked", "panic", v, "stack", string(debug.Stack())) + bail(mm.ctx, v) + } + }() + + if !mm.working { + mm.deadRun() + return + } + + go 
mm.b4Sock.Run() + go mm.u4Sock.Run() + + for { + select { + case msg := <-mm.inbox: + // got MDNS message from external; inject + switch msg := msg.(type) { + case *msgactor.MManReceivedPacket: + mm.handleReceivedPacket(msg) + default: + mm.logUnknownMessage(msg) + } + case frame := <-mm.b4Sock.outCh: + mm.handleSystemFrame(frame) + case frame := <-mm.b6Sock.outCh: + mm.handleSystemFrame(frame) + case frame := <-mm.u4Sock.outCh: + mm.handleSystemFrame(frame) + case frame := <-mm.u6Sock.outCh: + mm.handleSystemFrame(frame) + case <-mm.ctx.Done(): + return + } + } +} + +func (mm *MDNSManager) handleReceivedPacket(msg *msgactor.MManReceivedPacket) { + pi := mm.s.GetPeerInfo(msg.From) + if pi == nil { + L(mm).Warn("ignoring MDNS packet due to nonexistent peerinfo", "from", msg.From.Debug()) + return + } + + extra := "ip4" + if msg.IP6 { + extra = "ip6" + } + + if _, _, _, ok, _ := mm.rlStore.Take(context.Background(), dataToB64Hash(msg.Data)+extra); !ok { + // some rudimentary filtering to prevent true loop storms + return + } + + L(mm).Debug("processing external MDNS packet", "len", len(msg.Data), "from", msg.From.Debug()) + + pkt := mm.processMDNS(msg.Data, false) + + var err error + + // TODO process external mDNS packet + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + // On macOS, we can't use the broadsock's WriteTo, since it just doesn't generate a packet. + // However, we can use our specialised query sock to poke responses in unicast, even if they're QM. 
+ if msg.IP6 { + _, err = mm.u6Sock.Conn.Write(pkt) + } else { + _, err = mm.u4Sock.Conn.Write(pkt) + } + } else { + if msg.IP6 { + _, err = mm.b6Sock.Conn.WriteToUDPAddrPort(pkt, ip6MDNSBroadcastAP) + } else { + _, err = mm.b4Sock.Conn.WriteToUDPAddrPort(pkt, ip4MDNSBroadcastAP) + } + } + if err != nil { + L(mm).Warn("failed to process external MDNS packet", "err", err) + } +} + +func (mm *MDNSManager) handleSystemFrame(frame RecvFrame) { + // got MDNS message from system; forward + + nap := types.NormaliseAddr(frame.src.Addr()) + + extra := "ip4" + if nap.Is6() { + extra = "ip6" + } + + if _, _, _, ok, _ := mm.rlStore.Take(context.Background(), dataToB64Hash(frame.pkt)+extra); !ok { + // some rudimentary filtering to prevent true loop storms + return + } + + if !mm.isSelf(nap) { + L(mm).Log(context.Background(), types.LevelTrace, "dropping mDNS packet due to non-local origin", "from", frame.src) + return + } + + // TODO proper in-depth filtering + + L(mm).Debug("spreading local MDNS packet to peers", "len", len(frame.pkt), "from", frame.src.String()) + + pkt := mm.processMDNS(frame.pkt, true) + + SendMessage(mm.s.TMan.Inbox(), &msgactor.TManSpreadMDNSPacket{Pkt: pkt, IP6: nap.Is6()}) +} + +func (mm *MDNSManager) debugMDNS(msg *dnsmessage.Message) { + L(mm).Debug("debugMDNS: TXID", "txid", msg.ID) + + for _, q := range msg.Questions { + L(mm).Debug( + "debugMDNS: Q", + "txid", msg.ID, + "name", q.Name, + "type", q.Type.GoString(), + "class", q.Class.GoString(), + ) + } + for _, a := range msg.Answers { + L(mm).Debug( + "debugMDNS: A", + "txid", msg.ID, + "header", a.Header.GoString(), + "body", a.Body.GoString(), + ) + } +} + +func (mm *MDNSManager) fixResource(res *dnsmessage.Resource) (dirty bool) { + switch res.Header.Type { + case dnsmessage.TypeA: + ar := res.Body.(*dnsmessage.AResource) + if mm.isLocal(netip.AddrFrom4(ar.A)) { + ar.A = mm.s.control.IPv4().Addr().As4() + res.Header.Class |= 32768 + dirty = true + } + case dnsmessage.TypeAAAA: + a4r := 
res.Body.(*dnsmessage.AAAAResource) + + if mm.isLocal(netip.AddrFrom16(a4r.AAAA)) { + a4r.AAAA = mm.s.control.IPv6().Addr().As16() + res.Header.Class |= 32768 + dirty = true + } + } + + return +} + +func (mm *MDNSManager) isLocal(addr netip.Addr) bool { + return addr.IsLoopback() || slices.IndexFunc(mm.s.getLocalEndpoints(), func(cAddr netip.Addr) bool { + return cAddr == addr + }) != -1 +} + +func (mm *MDNSManager) isSelf(addr netip.Addr) bool { + return mm.isLocal(addr) || addr == mm.s.control.IPv4().Addr() || addr == mm.s.control.IPv6().Addr() +} + +func (mm *MDNSManager) processMDNS(pkt []byte, local bool) []byte { + msg := dnsmessage.Message{} + if err := msg.Unpack(pkt); err != nil { + L(mm).Warn("failed to unpack MDNS packet", "err", err) + return pkt + } + + mm.debugMDNS(&msg) + + var dirty bool + + if local { + for _, ans := range msg.Answers { + if mm.fixResource(&ans) { + dirty = true + } + } + + for _, add := range msg.Additionals { + if mm.fixResource(&add) { + dirty = true + } + } + } else if msg.Response { + // RFC 6762: + // Multicast DNS responses MUST NOT contain any questions in the + // Question Section. Any questions in the Question Section of a + // received Multicast DNS response MUST be silently ignored. Multicast + // DNS queriers receiving Multicast DNS responses do not care what + // question elicited the response; they care only that the information + // in the response is true and accurate. + // + // f.e. avahi doesn't properly work if the questions section is filled out, so we need to process that. + // + // The likes of Apple's mDNSResponder haven't gotten this above message, so we need to check for this. + // + // TODO this may be because we're regarding it via unicast DNS, and then this is the fallback described in + // section 6.7? 
+ // If the source UDP port in a received Multicast DNS query is not port + // 5353, this indicates that the querier originating the query is a + // simple resolver such as described in Section 5.1, "One-Shot Multicast + // DNS Queries", which does not fully implement all of Multicast DNS. + // In this case, the Multicast DNS responder MUST send a UDP response + // directly back to the querier, via unicast, to the query packet's + // source IP address and port. This unicast response MUST be a + // conventional unicast response as would be generated by a conventional + // Unicast DNS server; for example, it MUST repeat the query ID and the + // question given in the query message. In addition, the cache-flush + // bit described in Section 10.2, "Announcements to Flush Outdated Cache + // Entries", MUST NOT be set in legacy unicast responses. + if len(msg.Questions) != 0 { + msg.Questions = []dnsmessage.Question{} + dirty = true + } + } + + if dirty { + L(mm).Debug("processMDNS: rewritten request") + + mm.debugMDNS(&msg) + + ret, err := msg.Pack() + if err != nil { + L(mm).Warn("failed to pack MDNS packet", "err", err) + return pkt + } + + return ret + } + + return pkt +} + +func (mm *MDNSManager) deadRun() { + for { + select { + case <-mm.inbox: + case <-mm.ctx.Done(): + return + } + } +} + +func (mm *MDNSManager) Close() { + mm.rlStore.Close(context.Background()) +} diff --git a/toversok/actors/a_relay.go b/toversok/actors/a_relay.go index b3ba5a6..dc9b8bf 100644 --- a/toversok/actors/a_relay.go +++ b/toversok/actors/a_relay.go @@ -2,7 +2,13 @@ package actors import ( "context" + "errors" "fmt" + "log/slog" + "runtime" + "runtime/debug" + "time" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/dial" "github.com/edup2p/common/types/ifaces" @@ -10,10 +16,6 @@ import ( "github.com/edup2p/common/types/msgactor" "github.com/edup2p/common/types/msgsess" "github.com/edup2p/common/types/relay" - "github.com/edup2p/common/types/relay/relayhttp" - 
"log/slog" - "runtime" - "time" ) // RestartableRelayConn is a Relay connection that will automatically reconnect, @@ -25,7 +27,7 @@ type RestartableRelayConn struct { config relay.Information - client *relay.Client + client relay.Client stay bool @@ -52,6 +54,8 @@ func (c *RestartableRelayConn) Poke() { } func (c *RestartableRelayConn) Run() { + defer c.Cancel() + for { if c.shouldIdle() { select { @@ -86,11 +90,8 @@ func (c *RestartableRelayConn) Run() { c.connected = false // Possibly the client exited because the relayConn is being closed, check for that first - select { - case <-c.ctx.Done(): + if c.ctx.Err() != nil { return - default: - // fallthrough } if err != nil { c.L().Warn("relay client exited", "error", err) @@ -119,7 +120,7 @@ func (c *RestartableRelayConn) establish() (success bool) { } var err error - c.client, err = relayhttp.Dial(c.ctx, dial.Opts{ + c.client, err = c.man.s.dialRelayFunc(c.ctx, dial.Opts{ Domain: c.config.Domain, Addrs: types.SliceOrNil(c.config.IPs), Port: port, @@ -137,8 +138,7 @@ func (c *RestartableRelayConn) establish() (success bool) { return false } - go c.client.RunSend() - go c.client.RunReceive() + go c.client.Run() c.L().Debug("established") @@ -161,7 +161,7 @@ func (c *RestartableRelayConn) loop() error { case <-checker.C: if c.shouldIdle() { - c.client.Close() + c.client.Cancel(errors.New("should idle")) return nil } @@ -186,7 +186,7 @@ func (c *RestartableRelayConn) loop() error { } func (c *RestartableRelayConn) Close() { - c.ctxCan() + // TODO nothing much to close? 
} // Queue queues the pkt for dst in a non-blocking fashion @@ -212,7 +212,7 @@ func (c *RestartableRelayConn) Update(info relay.Information) { // Close the client to trigger a reconnect if c.client != nil { - c.client.Close() + c.client.Cancel(errors.New("relay client exited")) } } @@ -268,7 +268,7 @@ type RelayManager struct { const HomeRelayChangeInterval = time.Minute * 5 func (s *Stage) makeRM() *RelayManager { - return &RelayManager{ + return assureClose(&RelayManager{ ActorCommon: MakeCommon(s.Ctx, RelayManInboxChLen), s: s, homeRelay: 0, @@ -276,29 +276,28 @@ func (s *Stage) makeRM() *RelayManager { relays: make(map[int64]RelayConnActor), inCh: make(chan ifaces.RelayedPeerFrame, RelayManFrameChLen), writeCh: make(chan relayWriteRequest, RelayManWriteChLen), - } + }) } func (rm *RelayManager) Run() { + if !rm.running.CheckOrMark() { + L(rm).Warn("tried to run agent, while already running") + return + } + + defer rm.Cancel() defer func() { if v := recover(); v != nil { - L(rm).Error("panicked", "panic", v) - rm.Cancel() + L(rm).Error("panicked", "panic", v, "stack", string(debug.Stack())) bail(rm.ctx, v) } }() - if !rm.running.CheckOrMark() { - L(rm).Warn("tried to run agent, while already running") - return - } - runtime.LockOSThread() for { select { case <-rm.ctx.Done(): - rm.Close() return case m := <-rm.inbox: switch m := m.(type) { @@ -364,16 +363,15 @@ func (rm *RelayManager) Run() { rm.s.RRouter.Push(frame) } } - } func (rm *RelayManager) Close() { - rm.ctxCan() + // TODO nothing much to close? 
} func (rm *RelayManager) selectRelay(latencies map[int64]time.Duration) int64 { - var srid int64 = 0 - var slat = 60 * time.Second + var srid int64 + slat := 60 * time.Second L(rm).Debug("selectRelay: starting latency check") @@ -403,9 +401,10 @@ func (rm *RelayManager) getConn(id int64) RelayConnActor { func (rm *RelayManager) update(info relay.Information) { if r, ok := rm.relays[info.ID]; ok { r.Update(info) + return } - r := &RestartableRelayConn{ + r := assureClose(&RestartableRelayConn{ ActorCommon: MakeCommon(rm.ctx, -1), man: rm, config: info, @@ -413,7 +412,7 @@ func (rm *RelayManager) update(info relay.Information) { stay: info.ID == rm.homeRelay, bufferCh: make(chan relay.SendPacket, RelayConnSendBufferSize), pokeCh: make(chan interface{}, 1), - } + }) go r.Run() @@ -442,11 +441,11 @@ type RelayRouter struct { } func (s *Stage) makeRR() *RelayRouter { - return &RelayRouter{ + return assureClose(&RelayRouter{ ActorCommon: MakeCommon(s.Ctx, -1), s: s, frameCh: make(chan ifaces.RelayedPeerFrame, RelayRouterFrameChLen), - } + }) } func (rr *RelayRouter) Push(frame ifaces.RelayedPeerFrame) { @@ -456,25 +455,24 @@ func (rr *RelayRouter) Push(frame ifaces.RelayedPeerFrame) { } func (rr *RelayRouter) Run() { + if !rr.running.CheckOrMark() { + L(rr).Warn("tried to run agent, while already running") + return + } + + defer rr.Cancel() defer func() { if v := recover(); v != nil { - L(rr).Warn("panicked", "error", v) - rr.Cancel() + L(rr).Warn("panicked", "error", v, "stack", string(debug.Stack())) bail(rr.ctx, v) } }() - if !rr.running.CheckOrMark() { - L(rr).Warn("tried to run agent, while already running") - return - } - runtime.LockOSThread() for { select { case <-rr.ctx.Done(): - rr.Close() return case frame := <-rr.frameCh: if msgsess.LooksLikeSessionWireMessage(frame.Pkt) { @@ -490,7 +488,11 @@ func (rr *RelayRouter) Run() { in := rr.s.InConnFor(frame.SrcPeer) if in == nil { - // todo log? metric? 
+ L(rr).Debug( + "received incoming relay frame from peer that we don't know about (yet)", + "from-peer", frame.SrcPeer.Debug(), + "from-relay", frame.SrcRelay, + ) continue } diff --git a/toversok/actors/a_relay_test.go b/toversok/actors/a_relay_test.go index d8e640b..3250907 100644 --- a/toversok/actors/a_relay_test.go +++ b/toversok/actors/a_relay_test.go @@ -3,6 +3,7 @@ package actors import ( "context" "fmt" + "slices" "testing" "github.com/edup2p/common/types/ifaces" @@ -19,9 +20,10 @@ func TestRelayManager(t *testing.T) { frameCh: make(chan ifaces.RelayedPeerFrame, RelayRouterFrameChLen), } - var relayID int64 = 0 + const RelayID int64 = 0 + homeRelay := &RestartableRelayConn{ - config: relay.Information{ID: relayID}, + config: relay.Information{ID: RelayID}, bufferCh: make(chan relay.SendPacket), } @@ -32,7 +34,7 @@ func TestRelayManager(t *testing.T) { // Make and run RelayManager rm := s.makeRM() - rm.relays[relayID] = homeRelay + rm.relays[RelayID] = homeRelay go rm.Run() // Message that should be sent to the relay @@ -89,7 +91,7 @@ func TestRelayRouter(t *testing.T) { go rr.Run() // Message that should be sent to SessionManager - sessionPkt := append(msgsess.MagicBytes, zeroBytes(56)...) 
+ sessionPkt := slices.Concat(msgsess.MagicBytes, zeroBytes(56)) frameSession := ifaces.RelayedPeerFrame{ SrcRelay: 0, @@ -104,13 +106,13 @@ func TestRelayRouter(t *testing.T) { // For each peer: register peer in RelayRouter and Stage, and then send a message to their inConn for i, b := range []byte{1, 2} { ic := ics[i] - key := ic.peer + peer := ic.peer - s.inConn[key] = ic + s.inConn[peer] = ic frame := ifaces.RelayedPeerFrame{ SrcRelay: 0, - SrcPeer: key, + SrcPeer: peer, Pkt: []byte{b}, } diff --git a/toversok/actors/a_sman.go b/toversok/actors/a_sman.go index 9497164..c8acdaf 100644 --- a/toversok/actors/a_sman.go +++ b/toversok/actors/a_sman.go @@ -2,10 +2,12 @@ package actors import ( "fmt" + "runtime/debug" + "slices" + "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgactor" - msg2 "github.com/edup2p/common/types/msgsess" - "slices" + "github.com/edup2p/common/types/msgsess" ) // SessionManager receives frames from routers and decrypts them, @@ -23,11 +25,11 @@ type SessionManager struct { var DebugSManTakeNodeAsSession = false func (s *Stage) makeSM(priv func() *key.SessionPrivate) *SessionManager { - sm := &SessionManager{ + sm := assureClose(&SessionManager{ ActorCommon: MakeCommon(s.Ctx, SessManInboxChLen), s: s, session: priv, - } + }) L(sm).Debug("sman with session key", "sess", priv().Public().Debug()) @@ -35,23 +37,22 @@ func (s *Stage) makeSM(priv func() *key.SessionPrivate) *SessionManager { } func (sm *SessionManager) Run() { + if !sm.running.CheckOrMark() { + L(sm).Warn("tried to run agent, while already running") + return + } + + defer sm.Cancel() defer func() { if v := recover(); v != nil { - L(sm).Error("panicked", "panic", v) - sm.Cancel() + L(sm).Error("panicked", "panic", v, "stack", string(debug.Stack())) bail(sm.ctx, v) } }() - if !sm.running.CheckOrMark() { - L(sm).Warn("tried to run agent, while already running") - return - } - for { select { case <-sm.ctx.Done(): - sm.Close() return case inMsg := <-sm.inbox: 
sm.Handle(inMsg) @@ -102,12 +103,13 @@ func (sm *SessionManager) Handle(msg msgactor.ActorMessage) { } } -func (sm *SessionManager) Unpack(frameWithMagic []byte) (*msg2.ClearMessage, error) { - if string(frameWithMagic[:len(msg2.Magic)]) != msg2.Magic { +func (sm *SessionManager) Unpack(frameWithMagic []byte) (*msgsess.ClearMessage, error) { + if string(frameWithMagic[:len(msgsess.Magic)]) != msgsess.Magic { + // We check these messages further up, so while this is a safety check, it shouldn't be triggered panic("Somehow received non-session message in unpack") } - b := frameWithMagic[len(msg2.Magic):] + b := frameWithMagic[len(msgsess.Magic):] sessionKey := key.MakeSessionPublic([key.Len]byte(b[:key.Len])) @@ -119,24 +121,23 @@ func (sm *SessionManager) Unpack(frameWithMagic []byte) (*msg2.ClearMessage, err return nil, fmt.Errorf("could not decrypt session message") } - sMsg, err := msg2.ParseSessionMessage(clearBytes) - + sMsg, err := msgsess.ParseSessionMessage(clearBytes) if err != nil { return nil, fmt.Errorf("could not parse session message: %s", err) } - return &msg2.ClearMessage{ + return &msgsess.ClearMessage{ Session: sessionKey, Message: sMsg, }, nil } -func (sm *SessionManager) Pack(sMsg msg2.SessionMessage, toSession key.SessionPublic) []byte { - clearBytes := sMsg.MarshalSessionMessage() +func (sm *SessionManager) Pack(sMsg msgsess.SessionMessage, toSession key.SessionPublic) []byte { + clearBytes := sMsg.Marshal() cipherBytes := sm.session().Shared(toSession).Seal(clearBytes) - return slices.Concat(msg2.MagicBytes, sm.session().Public().ToByteSlice(), cipherBytes) + return slices.Concat(msgsess.MagicBytes, sm.session().Public().ToByteSlice(), cipherBytes) } func (sm *SessionManager) Session() key.SessionPublic { diff --git a/toversok/actors/a_sman_test.go b/toversok/actors/a_sman_test.go index 2896361..6ea6808 100644 --- a/toversok/actors/a_sman_test.go +++ b/toversok/actors/a_sman_test.go @@ -6,25 +6,28 @@ import ( 
"github.com/edup2p/common/types/msgactor" "github.com/edup2p/common/types/msgsess" - msg2 "github.com/edup2p/common/types/msgsess" "github.com/stretchr/testify/assert" ) // Mock Session Message used in this test type MockSessionMessage struct { - marshalSessionMessage func() []byte - debug func() string + marshal func() []byte + debug func() string } -func (m *MockSessionMessage) MarshalSessionMessage() []byte { - return m.marshalSessionMessage() +func (m *MockSessionMessage) Marshal() []byte { + return m.marshal() } func (m *MockSessionMessage) Debug() string { return m.debug() } -func assertEncryptedPacket(t *testing.T, pkt []byte, sm *SessionManager, expectedDecryption *msg2.ClearMessage, failMsg string) { +func (m *MockSessionMessage) Parse([]byte) error { + panic("implement me") +} + +func assertEncryptedPacket(t *testing.T, pkt []byte, sm *SessionManager, expectedDecryption *msgsess.ClearMessage, failMsg string) { // We cannot predict the encryption with a random nonce, so we unpack the packet in receivedReq to test if it is correct unpacked, ok := sm.Unpack(pkt) assert.Nil(t, ok, "Decryption of packet in received directWriteRequest failed") @@ -53,21 +56,21 @@ func TestSessionManager(t *testing.T) { // Create a test ping message txID := [12]byte{42} pingBytes := append(txID[:], dummyKey[:]...) - clearBytes := append([]byte{1, 0}, pingBytes[:]...) // 1 is version nr, 0 is Ping message + clearBytes := append([]byte{1, 0}, pingBytes...) 
// 1 is version nr, 0 is Ping message pingMsg := &msgsess.Ping{ TxID: txID, NodeKey: dummyKey, } - clearMsg := &msg2.ClearMessage{ + clearMsg := &msgsess.ClearMessage{ Session: testPub, Message: pingMsg, } // Pack the test ping message mockSessionMsg := &MockSessionMessage{ - marshalSessionMessage: func() []byte { + marshal: func() []byte { return clearBytes }, } @@ -93,7 +96,7 @@ func TestSessionManager(t *testing.T) { assert.Equal(t, expectedFromRelay, receivedFromRelay, "TrafficManager did not receive expected message when sending frame from an address-port pair to SessionManager") - //Test Handle on frame from addrport + // Test Handle on frame from addrport frameFromAddrPort := &msgactor.SManSessionFrameFromAddrPort{ AddrPort: dummyAddrPort, FrameWithMagic: packedBytes, diff --git a/toversok/actors/a_sockrecv.go b/toversok/actors/a_sockrecv.go index 28379b9..374da4a 100644 --- a/toversok/actors/a_sockrecv.go +++ b/toversok/actors/a_sockrecv.go @@ -3,12 +3,14 @@ package actors import ( "context" "errors" - "fmt" - "github.com/edup2p/common/types" + "log/slog" "net" "net/netip" + "runtime/debug" "slices" "time" + + "github.com/edup2p/common/types" ) type RecvFrame struct { @@ -27,40 +29,40 @@ type SockRecv struct { outCh chan RecvFrame } -func MakeSockRecv(udp types.UDPConn, pCtx context.Context) *SockRecv { - - return &SockRecv{ +func MakeSockRecv(ctx context.Context, udp types.UDPConn) *SockRecv { + return assureClose(&SockRecv{ Conn: udp, outCh: make(chan RecvFrame, SockRecvFrameChanBuffer), - ActorCommon: MakeCommon(pCtx, -1), - } + ActorCommon: MakeCommon(ctx, -1), + }) } func (r *SockRecv) Run() { + if !r.running.CheckOrMark() { + L(r).Warn("tried to run agent, while already running") + return + } + + defer r.Cancel() defer func() { if v := recover(); v != nil { - L(r).Error("panicked", "err", v) - r.Cancel() + L(r).Error("panicked", "err", v, "stack", string(debug.Stack())) bail(r.ctx, v) } }() - if !r.running.CheckOrMark() { - L(r).Warn("tried to run 
agent, while already running") - return - } - - var buf = make([]byte, 1<<16) + buf := make([]byte, 1<<16) for { - if context.Cause(r.ctx) != nil { + if r.ctx.Err() != nil { return } err := r.Conn.SetReadDeadline(time.Now().Add(SockRecvReadTimeout)) if err != nil { - panic(fmt.Sprint("Error when setting read deadline:", err)) + L(r).Error("failed to set read deadline", "err", err) + return } n, ap, err := r.Conn.ReadFromUDPAddrPort(buf) @@ -74,17 +76,16 @@ func (r *SockRecv) Run() { // unsure what to do here, as this might be a permanent error of the socket? // would this result in the closing of the channel? if so, wouldnt the corresponding outconn also die? // if so, then who detects the death of the actor and recreates it like that? - if context.Cause(r.ctx) != nil { + if r.ctx.Err() != nil { // we're closing anyways, just return return } - if errors.Is(err, net.ErrClosed) { - r.Cancel() - return + if !errors.Is(err, net.ErrClosed) { + L(r).Error("failed to read packet", "err", err) } - panic(err) + return } if n == 0 { @@ -93,29 +94,26 @@ func (r *SockRecv) Run() { pkt := slices.Clone(buf[:n]) - if context.Cause(r.ctx) != nil { + if r.ctx.Err() != nil { return } select { - case <-r.ctx.Done(): - r.Close() - return case r.outCh <- RecvFrame{ pkt: pkt, ts: ts, src: ap, }: // fallthrough continue + case <-r.ctx.Done(): + return } } } func (r *SockRecv) Close() { - if context.Cause(r.ctx) == nil { - r.Conn.Close() - close(r.outCh) - r.ctxCan() - return + if err := r.Conn.Close(); err != nil { + slog.Error("failed to close connection for sockrecv", "err", err) } + close(r.outCh) } diff --git a/toversok/actors/a_tman.go b/toversok/actors/a_tman.go index e63b2ab..a749f8a 100644 --- a/toversok/actors/a_tman.go +++ b/toversok/actors/a_tman.go @@ -1,17 +1,20 @@ package actors import ( - "github.com/edup2p/common/toversok/actors/peer_state" + "context" + "maps" + "net/netip" + "runtime/debug" + "time" + + "github.com/edup2p/common/toversok/actors/peerstate" 
"github.com/edup2p/common/types" "github.com/edup2p/common/types/ifaces" "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgactor" "github.com/edup2p/common/types/msgsess" "github.com/edup2p/common/types/stage" - maps2 "golang.org/x/exp/maps" - "maps" - "net/netip" - "time" + xmaps "golang.org/x/exp/maps" ) type TrafficManager struct { @@ -21,7 +24,7 @@ type TrafficManager struct { ticker *time.Ticker // 250ms poke chan interface{} // len 1 - peerState map[key.NodePublic]peer_state.PeerState + peerState map[key.NodePublic]peerstate.PeerState pings map[msgsess.TxID]*stage.SentPing activeOut map[key.NodePublic]bool @@ -32,41 +35,39 @@ type TrafficManager struct { } func (s *Stage) makeTM() *TrafficManager { - return &TrafficManager{ + return assureClose(&TrafficManager{ ActorCommon: MakeCommon(s.Ctx, TrafficManInboxChLen), s: s, ticker: time.NewTicker(TManTickerInterval), poke: make(chan interface{}, 1), - peerState: make(map[key.NodePublic]peer_state.PeerState), + peerState: make(map[key.NodePublic]peerstate.PeerState), pings: make(map[msgsess.TxID]*stage.SentPing), activeOut: make(map[key.NodePublic]bool), activeIn: make(map[key.NodePublic]bool), sessMap: make(map[key.SessionPublic]key.NodePublic), - } + }) } func (tm *TrafficManager) Run() { + if !tm.running.CheckOrMark() { + L(tm).Warn("tried to run agent, while already running") + return + } + + defer tm.Cancel() defer func() { if v := recover(); v != nil { - L(tm).Error("panicked", "error", v) - tm.Cancel() - tm.Close() + L(tm).Error("panicked", "error", v, "stack", string(debug.Stack())) bail(tm.ctx, v) } }() - if !tm.running.CheckOrMark() { - L(tm).Warn("tried to run agent, while already running") - return - } - for { select { case <-tm.ctx.Done(): - tm.Close() return case <-tm.ticker.C: // Run periodic before inbox, as inbox can get backed up, and ping + path management would get delayed. 
@@ -109,13 +110,25 @@ func (tm *TrafficManager) Handle(m msgactor.ActorMessage) { case *msgactor.TManSessionMessageFromDirect: n := tm.NodeForSess(m.Msg.Session) - if n != nil { - tm.forState(*n, func(s peer_state.PeerState) peer_state.PeerState { - return s.OnDirect(types.NormaliseAddrPort(m.AddrPort), m.Msg) - }) - } else { + if n == nil { L(tm).Warn("got message from direct for unknown session", "session", m.Msg.Session.Debug()) + return + } + + node := *n + + if ok, ip6 := tm.isMDNS(m.Msg); ok { + if !tm.mdnsAllowed(node) { + L(tm).Warn("got direct MDNS packet from peer where it is not allowed", "peer", node.Debug()) + return + } + tm.sendMDNS(node, m.Msg, ip6) + return } + + tm.forState(node, func(s peerstate.PeerState) peerstate.PeerState { + return s.OnDirect(types.NormaliseAddrPort(m.AddrPort), m.Msg) + }) case *msgactor.TManSessionMessageFromRelay: if !tm.ValidKeys(m.Peer, m.Msg.Session) { L(tm).Warn("got message from relay for peer with incorrect session", @@ -123,7 +136,16 @@ func (tm *TrafficManager) Handle(m msgactor.ActorMessage) { return } - tm.forState(m.Peer, func(s peer_state.PeerState) peer_state.PeerState { + if ok, ip6 := tm.isMDNS(m.Msg); ok { + if !tm.mdnsAllowed(m.Peer) { + L(tm).Warn("got relay MDNS packet from peer where it is not allowed", "peer", m.Peer.Debug()) + return + } + tm.sendMDNS(m.Peer, m.Msg, ip6) + return + } + + tm.forState(m.Peer, func(s peerstate.PeerState) peerstate.PeerState { return s.OnRelay(m.Relay, m.Peer, m.Msg) }) case *msgactor.SyncPeerInfo: @@ -144,16 +166,94 @@ func (tm *TrafficManager) Handle(m msgactor.ActorMessage) { return key == m.Peer }) } + case *msgactor.TManSpreadMDNSPacket: + tm.spreadMDNS(m.Pkt, m.IP6) default: tm.logUnknownMessage(m) } } +func (tm *TrafficManager) isMDNS(msg *msgsess.ClearMessage) (isMDNS bool, ip6 bool) { + sbd, ok := msg.Message.(*msgsess.SideBandData) + + if ok { + return sbd.Type == msgsess.MDNSv4Type || sbd.Type == msgsess.MDNSv6Type, sbd.Type == msgsess.MDNSv6Type + } + + 
return false, false +} + +func (tm *TrafficManager) mdnsAllowed(node key.NodePublic) bool { + pi := tm.s.GetPeerInfo(node) + + if pi == nil { + return false + } + + return pi.MDNS +} + +func (tm *TrafficManager) sendMDNS(peer key.NodePublic, msg *msgsess.ClearMessage, ip6 bool) { + sbd := msg.Message.(*msgsess.SideBandData) + + go SendMessage(tm.s.MMan.Inbox(), &msgactor.MManReceivedPacket{ + From: peer, + Data: sbd.Data, + IP6: ip6, + }) +} + +func (tm *TrafficManager) spreadMDNS(pkt []byte, ip6 bool) { + peers := tm.s.GetPeersWhere(func(_ key.NodePublic, info *stage.PeerInfo) bool { + return info.MDNS + }) + + peersDebug := types.Map(peers, func(t key.NodePublic) string { + return t.Debug() + }) + L(tm).Log(context.Background(), types.LevelTrace, "sending mdns packet to peers", "peers", peersDebug) + + var t msgsess.SideBandDataType + if ip6 { + t = msgsess.MDNSv6Type + } else { + t = msgsess.MDNSv4Type + } + + for _, peer := range peers { + tm.opportunisticSendTo(peer, &msgsess.SideBandData{ + Type: t, + Data: pkt, + }) + } +} + +func (tm *TrafficManager) opportunisticSendTo(to key.NodePublic, msg msgsess.SessionMessage) { + pi := tm.s.GetPeerInfo(to) + + if pi == nil { + L(tm).Warn("trying to send an opportunistic session message to a node without peerinfo", "to", to.Debug()) + return + } + + tm.forState(to, func(s peerstate.PeerState) peerstate.PeerState { + L(tm).Log(context.Background(), types.LevelTrace, "sending opportunistic session message to peer", "peer", to.Debug()) + + if e, ok := s.(*peerstate.Established); ok { + tm.SendMsgToDirect(e.GetEndpoint(), pi.Session, msg) + } else { + tm.SendMsgToRelay(pi.HomeRelay, to, pi.Session, msg) + } + + return nil + }) +} + func (tm *TrafficManager) DoStateTick() { // We explicitly range over a slice of the keys we already got, // since golang likes to complain when we mutate while we iterate. 
- for _, peer := range maps2.Keys(tm.peerState) { - tm.forState(peer, func(s peer_state.PeerState) peer_state.PeerState { + for _, peer := range xmaps.Keys(tm.peerState) { + tm.forState(peer, func(s peerstate.PeerState) peerstate.PeerState { return s.OnTick() }) } @@ -192,6 +292,7 @@ func (tm *TrafficManager) Poke() { } } +//nolint:unused func (tm *TrafficManager) isConnActive(peer key.NodePublic) bool { return tm.activeOut[peer] || tm.activeIn[peer] } @@ -222,7 +323,7 @@ func (tm *TrafficManager) ensurePeerState(peer key.NodePublic) { s, ok := tm.peerState[peer] if !ok { - tm.peerState[peer] = peer_state.MakeWaiting(tm, peer) + tm.peerState[peer] = peerstate.MakeWaiting(tm, peer) tm.Poke() return } @@ -230,7 +331,7 @@ func (tm *TrafficManager) ensurePeerState(peer key.NodePublic) { if s == nil { // !! this should never happen, but we recover regardless L(tm).Warn("found nil state for peer, restarting state with Waiting", "peer", peer.Debug()) - tm.peerState[peer] = peer_state.MakeWaiting(tm, peer) + tm.peerState[peer] = peerstate.MakeWaiting(tm, peer) tm.Poke() } } @@ -239,30 +340,31 @@ func (tm *TrafficManager) Close() { tm.ticker.Stop() } +const PingReapTimeout = 10 * time.Minute + func (tm *TrafficManager) doPingManagement() { - // TODO - // - expire old pings + var oldPings []msgsess.TxID + + for txid, ping := range tm.pings { + if ping.At.Add(PingReapTimeout).Before(time.Now()) { + oldPings = append(oldPings, txid) + } + } + + for _, txid := range oldPings { + delete(tm.pings, txid) + } } -type StateForState func(state peer_state.PeerState) peer_state.PeerState +type StateForState func(state peerstate.PeerState) peerstate.PeerState func (tm *TrafficManager) forState(peer key.NodePublic, fn StateForState) { // A state for a state, perfectly balanced, as all things should be. // - Thanos, while writing this code. 
- state, ok := tm.peerState[peer] - - if !ok { - return - } - - if state == nil { - L(tm).Error("found nil state when running update for peer, recovering...", "peer", peer.Debug()) - tm.ensurePeerState(peer) - state = tm.peerState[peer] - } + tm.ensurePeerState(peer) - newState := fn(state) + newState := fn(tm.peerState[peer]) if newState != nil { // state transitions have happened, store the new state @@ -270,12 +372,6 @@ func (tm *TrafficManager) forState(peer key.NodePublic, fn StateForState) { } } -// TODO see if these correspond in peer_state package -//const EstablishmentTimeout = time.Second * 10 -//const EstablishmentRetry = time.Second * 40 -// -//const EstablishedPingTimeout = time.Second * 5 - func (tm *TrafficManager) DManClearAKA(peer key.NodePublic) { SendMessage(tm.s.DRouter.Inbox(), &msgactor.DRouterPeerClearKnownAs{ Peer: peer, @@ -337,8 +433,10 @@ func (tm *TrafficManager) ValidKeys(peer key.NodePublic, session key.SessionPubl } func (tm *TrafficManager) SendPingDirect(endpoint netip.AddrPort, peer key.NodePublic, session key.SessionPublic) { - txid := msgsess.NewTxID() + tm.SendPingDirectWithID(endpoint, peer, session, msgsess.NewTxID()) +} +func (tm *TrafficManager) SendPingDirectWithID(endpoint netip.AddrPort, peer key.NodePublic, session key.SessionPublic, txid msgsess.TxID) { nep := types.NormaliseAddrPort(endpoint) tm.SendMsgToDirect(nep, session, &msgsess.Ping{ diff --git a/toversok/actors/a_tman_test.go b/toversok/actors/a_tman_test.go index dcda7b8..f912019 100644 --- a/toversok/actors/a_tman_test.go +++ b/toversok/actors/a_tman_test.go @@ -19,9 +19,9 @@ func TestTrafficManager(t *testing.T) { wgConn := &MockUDPConn{} oc := MakeOutConn(wgConn, dummyKey, 0, s) - s.outConn = make(map[key.NodePublic]OutConnActor, 0) + s.outConn = make(map[key.NodePublic]OutConnActor) s.outConn[dummyKey] = oc - s.peerInfo = make(map[key.NodePublic]*stage.PeerInfo, 0) + s.peerInfo = make(map[key.NodePublic]*stage.PeerInfo) s.peerInfo[dummyKey] = 
&stage.PeerInfo{Session: testPub} // Run TrafficManager diff --git a/toversok/actors/common.go b/toversok/actors/common.go index a5cbb19..a3fc704 100644 --- a/toversok/actors/common.go +++ b/toversok/actors/common.go @@ -2,8 +2,9 @@ package actors import ( "context" - "github.com/edup2p/common/types/msgactor" "log/slog" + + "github.com/edup2p/common/types/msgactor" ) type ActorCommon struct { @@ -16,7 +17,7 @@ type ActorCommon struct { func MakeCommon(pCtx context.Context, chLen int) *ActorCommon { ctx, ctxCan := context.WithCancel(pCtx) - var inbox chan msgactor.ActorMessage = nil + var inbox chan msgactor.ActorMessage if chLen >= 0 { inbox = make(chan msgactor.ActorMessage, chLen) @@ -38,6 +39,10 @@ func (ac *ActorCommon) Cancel() { ac.ctxCan() } +func (ac *ActorCommon) Ctx() context.Context { + return ac.ctx +} + func (ac *ActorCommon) logUnknownMessage(am msgactor.ActorMessage) { // TODO make better; somehow get actor name in there slog.Error("got unknown message", "ac", ac, "am", am) diff --git a/toversok/actors/consts.go b/toversok/actors/consts.go index 8bb4ac5..2e6aceb 100644 --- a/toversok/actors/consts.go +++ b/toversok/actors/consts.go @@ -15,6 +15,7 @@ const ( TrafficManInboxChLen = 16 RelayManInboxChLen = 4 DirectRouterInboxChLen = 4 + MdnsManInboxChLen = 32 // Frame SockRecvFrameChanBuffer = 256 diff --git a/toversok/actors/peer_state/common.go b/toversok/actors/peerstate/common.go similarity index 61% rename from toversok/actors/peer_state/common.go rename to toversok/actors/peerstate/common.go index 8b51f33..e814c22 100644 --- a/toversok/actors/peer_state/common.go +++ b/toversok/actors/peerstate/common.go @@ -1,15 +1,17 @@ -package peer_state +package peerstate import ( "context" + "errors" + "log/slog" + "net/netip" + "time" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/ifaces" "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgsess" "github.com/edup2p/common/types/stage" - "log/slog" - "net/netip" - 
"time" ) const ( @@ -30,7 +32,7 @@ func (sc *StateCommon) Peer() key.NodePublic { return sc.peer } -func (sc *StateCommon) pingDirectValid(ap netip.AddrPort, sess key.SessionPublic, ping *msgsess.Ping) bool { +func (sc *StateCommon) pingDirectValid(_ netip.AddrPort, sess key.SessionPublic, ping *msgsess.Ping) bool { return sc.tm.ValidKeys(ping.NodeKey, sess) } @@ -41,7 +43,8 @@ func (sc *StateCommon) replyWithPongDirect(ap netip.AddrPort, sess key.SessionPu }) } -func (sc *StateCommon) pingRelayValid(relay int64, node key.NodePublic, sess key.SessionPublic, ping *msgsess.Ping) bool { +//nolint:unused +func (sc *StateCommon) pingRelayValid(_ int64, _ key.NodePublic, sess key.SessionPublic, ping *msgsess.Ping) bool { return sc.tm.ValidKeys(ping.NodeKey, sess) } @@ -51,63 +54,114 @@ func (sc *StateCommon) replyWithPongRelay(relay int64, node key.NodePublic, sess }) } -// TODO add bool here and checks by callers -func (sc *StateCommon) ackPongDirect(ap netip.AddrPort, sess key.SessionPublic, pong *msgsess.Pong) { +func (sc *StateCommon) pongDirectValid(ap netip.AddrPort, sess key.SessionPublic, pong *msgsess.Pong) error { sent, ok := sc.tm.Pings()[pong.TxID] if !ok { - // TODO log: Got pong for unknown ping - return + slog.Warn( + "got pong for unknown ping", + "from-ap", ap, + "txid", pong.TxID, + "sess", sess, + ) + return errors.New("pong txid does not correspond to any sent ping") } if sent.ToRelay { - // TODO log: got direct pong to relay ping - return + slog.Warn( + "got direct pong to relay ping", + "from-ap", ap, + "txid", pong.TxID, + "ping-to", sent.To.Debug(), + "to-relay", sent.RelayID, + "sess", sess, + ) + return errors.New("direct pong is reply to relay ping") } if !sc.tm.ValidKeys(sc.peer, sess) { // ?? Somehow the pong is for a valid ping to a node that no longer has this session key? // Might happen between restarts, log and ignore. 
- // TODO log - return + slog.Warn( + "received valid pong for unexpected remote session", + "from-ap", ap, + "txid", pong.TxID, + "sess", sess, + ) + return errors.New("got pong from invalid session") } // TODO more checks? (permissive, but log) + return nil +} + +func (sc *StateCommon) clearPongDirect(_ netip.AddrPort, _ key.SessionPublic, pong *msgsess.Pong) { delete(sc.tm.Pings(), pong.TxID) } // TODO add bool here and checks by callers -func (sc *StateCommon) ackPongRelay(relay int64, node key.NodePublic, sess key.SessionPublic, pong *msgsess.Pong) { - +func (sc *StateCommon) ackPongRelay(relayID int64, node key.NodePublic, sess key.SessionPublic, pong *msgsess.Pong) { // Relay pongs should come in response to relay pings, note if it is different. sent, ok := sc.tm.Pings()[pong.TxID] if !ok { - // TODO log: Got pong for unknown ping + slog.Warn( + "got pong for unknown ping", + "from-relay", relayID, + "txid", pong.TxID, + "sess", sess, + ) return } if !sent.ToRelay { - // TODO log: got relay reply to direct ping + slog.Warn( + "got relay pong to direct ping", + "from-relay", relayID, + "txid", pong.TxID, + "ping-to", sent.To.Debug(), + "to-relay", sent.RelayID, + "sess", sess, + ) return } - if !sc.tm.ValidKeys(node, sess) { - // TODO log + if node != sent.To { + slog.Warn( + "received pong to ping (with same TXID) from a different peer than we sent it to, possible collision", + "to-peer", sent.To.Debug(), + "from-peer", node.Debug(), + "from-relay", relayID, + "txid", pong.TxID, + "sess", sess, + ) return } if !sc.tm.ValidKeys(sent.To, sess) { // ?? Somehow the pong is for a valid ping to a node that no longer has this session key? // Might happen between restarts, log and ignore. 
- // TODO log + slog.Warn( + "received valid pong for unexpected remote session", + "from-relay", relayID, + "txid", pong.TxID, + "sess", sess, + ) return } + if sent.RelayID != relayID { + slog.Debug( + "received relay pong to relay ping from other relay, ignoring...", + "to-relay", sent.RelayID, + "from-relay", relayID, + "txid", pong.TxID, + ) + } + // TODO more checks? (permissive, but log) delete(sc.tm.Pings(), pong.TxID) - } func (sc *StateCommon) getPeerInfo() *stage.PeerInfo { @@ -122,14 +176,26 @@ type EstablishingCommon struct { lastPing time.Time pingCount uint + + tracker *PingTracker } func mkEstComm(sc *StateCommon, attempts int) *EstablishingCommon { - ec := &EstablishingCommon{StateCommon: sc, attempt: attempts + 1} + ec := &EstablishingCommon{ + StateCommon: sc, + attempt: attempts + 1, + tracker: NewPingTracker(), + } ec.resetDeadline() return ec } +func (ec *EstablishingCommon) ackPongDirect(ap netip.AddrPort, sess key.SessionPublic, pong *msgsess.Pong) { + ec.tracker.GotPong(ap) + + ec.clearPongDirect(ap, sess, pong) +} + func (ec *EstablishingCommon) resetDeadline() { ec.deadline = time.Now().Add(EstablishmentTimeout) } diff --git a/toversok/actors/peer_state/e_half.go b/toversok/actors/peerstate/e_half.go similarity index 50% rename from toversok/actors/peer_state/e_half.go rename to toversok/actors/peerstate/e_half.go index 25b9533..c117c51 100644 --- a/toversok/actors/peer_state/e_half.go +++ b/toversok/actors/peerstate/e_half.go @@ -1,10 +1,11 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" "net/netip" "time" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type EstHalf struct { @@ -28,64 +29,67 @@ func (e *EstHalf) OnTick() PeerState { return nil } -func (e *EstHalf) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { - if s := cascadeDirect(e, ap, clear); s != nil { +func (e *EstHalf) OnDirect(ap 
netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeDirect(e, ap, clearMsg); s != nil { return s } - LogDirectMessage(e, ap, clear) + LogDirectMessage(e, ap, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - if !e.pingDirectValid(ap, clear.Session, m) { + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + if !e.pingDirectValid(ap, clearMsg.Session, m) { L(e).Warn("dropping invalid ping", "ap", ap.String()) return nil } - e.replyWithPongDirect(ap, clear.Session, m) + e.replyWithPongDirect(ap, clearMsg.Session, m) // Send one as a hail-mary, for if another got lost - e.tm.SendPingDirect(ap, e.peer, clear.Session) + e.tm.SendPingDirect(ap, e.peer, clearMsg.Session) e.lastPing = time.Now() return nil - case *msg2.Pong: + case *msgsess.Pong: + if err := e.pongDirectValid(ap, clearMsg.Session, m); err != nil { + L(e).Warn("dropping invalid pong", "ap", ap.String(), "err", err) + return nil + } + e.tm.Poke() return LogTransition(e, &Finalizing{ EstablishingCommon: e.EstablishingCommon, ap: ap, - sess: clear.Session, + sess: clearMsg.Session, pong: m, }) - //case *msg.Rendezvous: default: L(e).Warn("ignoring direct session message", "ap", ap, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } } -func (e *EstHalf) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { - if s := cascadeRelay(e, relay, peer, clear); s != nil { +func (e *EstHalf) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeRelay(e, relay, peer, clearMsg); s != nil { return s } - LogRelayMessage(e, relay, peer, clear) + LogRelayMessage(e, relay, peer, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - e.replyWithPongRelay(relay, peer, clear.Session, m) + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + e.replyWithPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Pong: - 
e.ackPongRelay(relay, peer, clear.Session, m) + case *msgsess.Pong: + e.ackPongRelay(relay, peer, clearMsg.Session, m) return nil - //case *msg.Rendezvous: default: L(e).Warn("ignoring relay session message", "relay", relay, "peer", peer, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } diff --git a/toversok/actors/peer_state/e_rendez.go b/toversok/actors/peerstate/e_rendez.go similarity index 53% rename from toversok/actors/peer_state/e_rendez.go rename to toversok/actors/peerstate/e_rendez.go index 7b08253..8fbd8ac 100644 --- a/toversok/actors/peer_state/e_rendez.go +++ b/toversok/actors/peerstate/e_rendez.go @@ -1,9 +1,10 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" "net/netip" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type EstRendezAck struct { @@ -27,16 +28,16 @@ func (e *EstRendezAck) OnTick() PeerState { return nil } -func (e *EstRendezAck) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { - if s := cascadeDirect(e, ap, clear); s != nil { +func (e *EstRendezAck) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeDirect(e, ap, clearMsg); s != nil { return s } - LogDirectMessage(e, ap, clear) + LogDirectMessage(e, ap, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - if !e.pingDirectValid(ap, clear.Session, m) { + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + if !e.pingDirectValid(ap, clearMsg.Session, m) { L(e).Warn("dropping invalid ping", "ap", ap.String()) return nil } @@ -45,47 +46,50 @@ func (e *EstRendezAck) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) Pee return LogTransition(e, &EstHalfIng{ EstablishingCommon: e.EstablishingCommon, ap: ap, - sess: clear.Session, + sess: clearMsg.Session, ping: m, }) - case *msg2.Pong: + case *msgsess.Pong: + if err := e.pongDirectValid(ap, 
clearMsg.Session, m); err != nil { + L(e).Warn("dropping invalid pong", "ap", ap.String(), "err", err) + return nil + } + e.tm.Poke() return LogTransition(e, &Finalizing{ EstablishingCommon: e.EstablishingCommon, ap: ap, - sess: clear.Session, + sess: clearMsg.Session, pong: m, }) - //case *msg.Rendezvous: default: L(e).Warn("ignoring direct session message", "ap", ap, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } } -func (e *EstRendezAck) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { - if s := cascadeRelay(e, relay, peer, clear); s != nil { +func (e *EstRendezAck) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeRelay(e, relay, peer, clearMsg); s != nil { return s } - LogRelayMessage(e, relay, peer, clear) + LogRelayMessage(e, relay, peer, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - e.replyWithPongRelay(relay, peer, clear.Session, m) + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + e.replyWithPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Pong: - e.ackPongRelay(relay, peer, clear.Session, m) + case *msgsess.Pong: + e.ackPongRelay(relay, peer, clearMsg.Session, m) return nil - //case *msg.Rendezvous: default: L(e).Warn("ignoring relay session message", "relay", relay, "peer", peer, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } diff --git a/toversok/actors/peer_state/e_t_finalising.go b/toversok/actors/peerstate/e_t_finalising.go similarity index 54% rename from toversok/actors/peer_state/e_t_finalising.go rename to toversok/actors/peerstate/e_t_finalising.go index 9d00764..f5a1f92 100644 --- a/toversok/actors/peer_state/e_t_finalising.go +++ b/toversok/actors/peerstate/e_t_finalising.go @@ -1,9 +1,10 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - msg2 
"github.com/edup2p/common/types/msgsess" "net/netip" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type Finalizing struct { @@ -11,7 +12,7 @@ type Finalizing struct { ap netip.AddrPort sess key.SessionPublic - pong *msg2.Pong + pong *msgsess.Pong } func (f *Finalizing) Name() string { @@ -21,18 +22,25 @@ func (f *Finalizing) Name() string { func (f *Finalizing) OnTick() PeerState { f.ackPongDirect(f.ap, f.sess, f.pong) + bap, err := f.tracker.BestAddrPort() + if err != nil { + // We just acked a pong, so there should at least be 1 pair in there, so panic + panic(err) + } + return LogTransition(f, &Booting{ StateCommon: f.StateCommon, - ap: f.ap, + tracker: f.tracker, + ap: bap, }) } -func (f *Finalizing) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { +func (f *Finalizing) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeDirect(f, ap, clear) + return cascadeDirect(f, ap, clearMsg) } -func (f *Finalizing) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { +func (f *Finalizing) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeRelay(f, relay, peer, clear) + return cascadeRelay(f, relay, peer, clearMsg) } diff --git a/toversok/actors/peer_state/e_t_half.go b/toversok/actors/peerstate/e_t_half.go similarity index 69% rename from toversok/actors/peer_state/e_t_half.go rename to toversok/actors/peerstate/e_t_half.go index 464771f..0b4ef65 100644 --- a/toversok/actors/peer_state/e_t_half.go +++ b/toversok/actors/peerstate/e_t_half.go @@ -1,10 +1,11 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" "net/netip" "time" + + 
"github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type EstHalfIng struct { @@ -12,7 +13,7 @@ type EstHalfIng struct { ap netip.AddrPort sess key.SessionPublic - ping *msg2.Ping + ping *msgsess.Ping } func (e *EstHalfIng) Name() string { @@ -28,12 +29,12 @@ func (e *EstHalfIng) OnTick() PeerState { return LogTransition(e, &EstHalf{EstablishingCommon: e.EstablishingCommon}) } -func (e *EstHalfIng) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { +func (e *EstHalfIng) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeDirect(e, ap, clear) + return cascadeDirect(e, ap, clearMsg) } -func (e *EstHalfIng) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { +func (e *EstHalfIng) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeRelay(e, relay, peer, clear) + return cascadeRelay(e, relay, peer, clearMsg) } diff --git a/toversok/actors/peer_state/e_t_pretransmit.go b/toversok/actors/peerstate/e_t_pretransmit.go similarity index 71% rename from toversok/actors/peer_state/e_t_pretransmit.go rename to toversok/actors/peerstate/e_t_pretransmit.go index 1ef3481..8299701 100644 --- a/toversok/actors/peer_state/e_t_pretransmit.go +++ b/toversok/actors/peerstate/e_t_pretransmit.go @@ -1,9 +1,10 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" "net/netip" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type EstPreTransmit struct { @@ -27,19 +28,19 @@ func (e *EstPreTransmit) OnTick() PeerState { if len(endpoints) > 0 { e.tm.SendMsgToRelay( pi.HomeRelay, e.peer, pi.Session, - &msg2.Rendezvous{MyAddresses: endpoints}, + 
&msgsess.Rendezvous{MyAddresses: endpoints}, ) } return LogTransition(e, &EstTransmitting{EstablishingCommon: e.EstablishingCommon}) } -func (e *EstPreTransmit) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { +func (e *EstPreTransmit) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeDirect(e, ap, clear) + return cascadeDirect(e, ap, clearMsg) } -func (e *EstPreTransmit) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { +func (e *EstPreTransmit) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeRelay(e, relay, peer, clear) + return cascadeRelay(e, relay, peer, clearMsg) } diff --git a/toversok/actors/peer_state/e_t_rendez.go b/toversok/actors/peerstate/e_t_rendez.go similarity index 76% rename from toversok/actors/peer_state/e_t_rendez.go rename to toversok/actors/peerstate/e_t_rendez.go index d7246b8..59563ce 100644 --- a/toversok/actors/peer_state/e_t_rendez.go +++ b/toversok/actors/peerstate/e_t_rendez.go @@ -1,18 +1,19 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" "net/netip" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) // EstRendezGot is a transient state that immediately transitions to EstRendezAck after the first OnTick type EstRendezGot struct { *EstablishingCommon - m *msg2.Rendezvous + m *msgsess.Rendezvous } func (e *EstRendezGot) Name() string { @@ -39,12 +40,12 @@ func (e *EstRendezGot) OnTick() PeerState { return LogTransition(e, &EstRendezAck{EstablishingCommon: e.EstablishingCommon}) } -func (e *EstRendezGot) OnDirect(ap netip.AddrPort, clear 
*msg2.ClearMessage) PeerState { +func (e *EstRendezGot) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeDirect(e, ap, clear) + return cascadeDirect(e, ap, clearMsg) } -func (e *EstRendezGot) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { +func (e *EstRendezGot) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeRelay(e, relay, peer, clear) + return cascadeRelay(e, relay, peer, clearMsg) } diff --git a/toversok/actors/peer_state/e_transmitting.go b/toversok/actors/peerstate/e_transmitting.go similarity index 67% rename from toversok/actors/peer_state/e_transmitting.go rename to toversok/actors/peerstate/e_transmitting.go index 8df50e6..a35b999 100644 --- a/toversok/actors/peer_state/e_transmitting.go +++ b/toversok/actors/peerstate/e_transmitting.go @@ -1,9 +1,10 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" "net/netip" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type EstTransmitting struct { @@ -27,16 +28,16 @@ func (e *EstTransmitting) OnTick() PeerState { return nil } -func (e *EstTransmitting) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { - if s := cascadeDirect(e, ap, clear); s != nil { +func (e *EstTransmitting) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeDirect(e, ap, clearMsg); s != nil { return s } - LogDirectMessage(e, ap, clear) + LogDirectMessage(e, ap, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - if !e.pingDirectValid(ap, clear.Session, m) { + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + if !e.pingDirectValid(ap, clearMsg.Session, m) { 
L(e).Warn("dropping invalid ping", "ap", ap.String()) return nil } @@ -45,33 +46,37 @@ func (e *EstTransmitting) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) return LogTransition(e, &EstHalfIng{ EstablishingCommon: e.EstablishingCommon, ap: ap, - sess: clear.Session, + sess: clearMsg.Session, ping: m, }) - case *msg2.Pong: + case *msgsess.Pong: + if err := e.pongDirectValid(ap, clearMsg.Session, m); err != nil { + L(e).Warn("dropping invalid pong", "ap", ap.String(), "err", err) + return nil + } + e.tm.Poke() return LogTransition(e, &Finalizing{ EstablishingCommon: e.EstablishingCommon, ap: ap, - sess: clear.Session, + sess: clearMsg.Session, pong: m, }) - //case *msg.Rendezvous: default: L(e).Warn("ignoring direct session message", "ap", ap, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } } -func (e *EstTransmitting) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { - if s := cascadeRelay(e, relay, peer, clear); s != nil { +func (e *EstTransmitting) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeRelay(e, relay, peer, clearMsg); s != nil { return s } - LogRelayMessage(e, relay, peer, clear) + LogRelayMessage(e, relay, peer, clearMsg) // NOTE: There an edgecase that can happen here: // @@ -93,21 +98,21 @@ func (e *EstTransmitting) OnRelay(relay int64, peer key.NodePublic, clear *msg2. // // This is harmless, as the state diagram permits for it, but its worth noting. 
- switch m := clear.Message.(type) { - case *msg2.Ping: - e.replyWithPongRelay(relay, peer, clear.Session, m) + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + e.replyWithPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Pong: - e.ackPongRelay(relay, peer, clear.Session, m) + case *msgsess.Pong: + e.ackPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Rendezvous: + case *msgsess.Rendezvous: e.tm.Poke() return LogTransition(e, &EstRendezGot{EstablishingCommon: e.EstablishingCommon, m: m}) default: L(e).Warn("ignoring relay session message", "relay", relay, "peer", peer, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } diff --git a/toversok/actors/peer_state/iface.go b/toversok/actors/peerstate/iface.go similarity index 79% rename from toversok/actors/peer_state/iface.go rename to toversok/actors/peerstate/iface.go index ab2b87c..ff74157 100644 --- a/toversok/actors/peer_state/iface.go +++ b/toversok/actors/peerstate/iface.go @@ -1,9 +1,10 @@ -package peer_state +package peerstate import ( + "net/netip" + "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgsess" - "net/netip" ) // This state pattern was inspired by https://refactoring.guru/design-patterns/state/go/example @@ -14,8 +15,8 @@ import ( // If it's non-nil, replace the state for the peer with the state returned. type PeerState interface { OnTick() PeerState - OnDirect(ap netip.AddrPort, clear *msgsess.ClearMessage) PeerState - OnRelay(relay int64, peer key.NodePublic, clear *msgsess.ClearMessage) PeerState + OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState + OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState // Name returns a lower-case name to be used in logging. 
Name() string diff --git a/toversok/actors/peer_state/peer_state.go b/toversok/actors/peerstate/peer_state.go similarity index 96% rename from toversok/actors/peer_state/peer_state.go rename to toversok/actors/peerstate/peer_state.go index 63ca4ab..00f3bed 100644 --- a/toversok/actors/peer_state/peer_state.go +++ b/toversok/actors/peerstate/peer_state.go @@ -5,4 +5,4 @@ // pongs haven't been received for 5 seconds (with pings at 2 second intervals). // // See [peer_state.mermaid] for a primary reference of this state machine. -package peer_state +package peerstate diff --git a/toversok/actors/peer_state/peer_state.mermaid b/toversok/actors/peerstate/peer_state.mermaid similarity index 100% rename from toversok/actors/peer_state/peer_state.mermaid rename to toversok/actors/peerstate/peer_state.mermaid diff --git a/toversok/actors/peerstate/pingtracker.go b/toversok/actors/peerstate/pingtracker.go new file mode 100644 index 0000000..3810581 --- /dev/null +++ b/toversok/actors/peerstate/pingtracker.go @@ -0,0 +1,110 @@ +package peerstate + +import ( + "errors" + "net/netip" + "slices" + "sync" + + "github.com/edup2p/common/types" +) + +type PingTracker struct { + rw sync.RWMutex + gotPong map[netip.AddrPort]bool +} + +func NewPingTracker() *PingTracker { + return &PingTracker{ + gotPong: make(map[netip.AddrPort]bool), + } +} + +func (pt *PingTracker) validAPs() []netip.AddrPort { + var aps []netip.AddrPort + + for ap, gotPong := range pt.gotPong { + if gotPong { + aps = append(aps, ap) + } + } + + return aps +} + +func (pt *PingTracker) GotPong(ap netip.AddrPort) { + pt.rw.Lock() + defer pt.rw.Unlock() + + nap := types.NormaliseAddrPort(ap) + pt.gotPong[nap] = true +} + +func (pt *PingTracker) Has(ap netip.AddrPort) bool { + pt.rw.Lock() + defer pt.rw.Unlock() + + nap := types.NormaliseAddrPort(ap) + return pt.gotPong[nap] +} + +func (pt *PingTracker) BestAddrPort() (netip.AddrPort, error) { + pt.rw.RLock() + defer pt.rw.RUnlock() + + aps := pt.validAPs() + if 
len(aps) == 0 { + return netip.AddrPort{}, errors.New("no valid pings") + } + + slices.SortFunc(aps, gradeAPs) + slices.Reverse(aps) + + return aps[0], nil +} + +const ( + aBetter = 1 + bBetter = -1 + neither = 0 +) + +func gradeAPs(a, b netip.AddrPort) int { + if verCmp := gradeVer(a, b); verCmp != neither { + return verCmp + } + + if privCmp := gradePriv(a, b); privCmp != neither { + return privCmp + } + + return a.Compare(b) +} + +// IPv6 > IPv4 +func gradeVer(ap, bp netip.AddrPort) int { + a := ap.Addr() + b := bp.Addr() + + if a.Is4() && b.Is6() { + return bBetter + } else if a.Is6() && b.Is4() { + return aBetter + } + + return neither +} + +// Private/Unique Local > Non-Private/Unique Global +func gradePriv(ap, bp netip.AddrPort) int { + a := ap.Addr() + b := bp.Addr() + + if a.IsPrivate() && !b.IsPrivate() { + return aBetter + } else if !a.IsPrivate() && b.IsPrivate() { + return bBetter + } + + return neither +} diff --git a/toversok/actors/peerstate/pingtracker_test.go b/toversok/actors/peerstate/pingtracker_test.go new file mode 100644 index 0000000..79118fb --- /dev/null +++ b/toversok/actors/peerstate/pingtracker_test.go @@ -0,0 +1,76 @@ +package peerstate + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +var ( + pub4Addr = netip.MustParseAddrPort("8.0.0.1:1337") + pub6Addr = netip.MustParseAddrPort("[2000::1]:1337") + + priv4Addr = netip.MustParseAddrPort("10.0.0.1:1337") + priv6Addr = netip.MustParseAddrPort("[fd00::1]:1337") +) + +func TestPingTracker_FullSelection(t *testing.T) { + pt := NewPingTracker() + + for _, ip := range []netip.AddrPort{ + pub4Addr, + pub6Addr, + + priv4Addr, + priv6Addr, + } { + pt.GotPong(ip) + } + + bap, err := pt.BestAddrPort() + + assert.NoError(t, err) + assert.Equal(t, bap, priv6Addr) +} + +func TestPingTracker_NoPings(t *testing.T) { + pt := NewPingTracker() + + _, err := pt.BestAddrPort() + + assert.Error(t, err) +} + +func TestPingTracker_BestAddrPort(t *testing.T) { + pt := 
NewPingTracker() + + var bap netip.AddrPort + var err error + + // First add private ip4 + pt.GotPong(priv4Addr) + bap, err = pt.BestAddrPort() + assert.NoError(t, err) + assert.Equal(t, bap, priv4Addr) + + // Then add public ip4, + // this changes nothing + pt.GotPong(pub4Addr) + bap, err = pt.BestAddrPort() + assert.NoError(t, err) + assert.Equal(t, bap, priv4Addr) + + // Then add private ip6 + pt.GotPong(priv6Addr) + bap, err = pt.BestAddrPort() + assert.NoError(t, err) + assert.Equal(t, bap, priv6Addr) + + // Then add public ip6, + // this changes nothing + pt.GotPong(pub6Addr) + bap, err = pt.BestAddrPort() + assert.NoError(t, err) + assert.Equal(t, bap, priv6Addr) +} diff --git a/toversok/actors/peer_state/s_established.go b/toversok/actors/peerstate/s_established.go similarity index 52% rename from toversok/actors/peer_state/s_established.go rename to toversok/actors/peerstate/s_established.go index 33e526e..5e6ee49 100644 --- a/toversok/actors/peer_state/s_established.go +++ b/toversok/actors/peerstate/s_established.go @@ -1,13 +1,14 @@ -package peer_state +package peerstate import ( "context" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" "net/netip" "slices" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) const EstablishedPingInterval = time.Second * 2 @@ -15,6 +16,8 @@ const EstablishedPingInterval = time.Second * 2 type Established struct { *StateCommon + tracker *PingTracker + lastPingRecv time.Time lastPongRecv time.Time @@ -24,11 +27,6 @@ type Established struct { inactive bool inactiveSince time.Time - // TODO: this can flap, - // and basically picks the first best endpoint that the other client responds with, - // which may be non-ideal. - // Tailscale has logic to pick and switch between different endpoints, and sort them. - // We could possibly build this into the state logic. 
currentOutEndpoint netip.AddrPort knownInEndpoints map[netip.AddrPort]bool @@ -51,13 +49,11 @@ func (e *Established) OnTick() PeerState { if !e.inactive { e.inactive = true e.inactiveSince = time.Now() - } else { - if time.Now().After(e.inactiveSince.Add(ConnectionInactivityTimeout)) { - return LogTransition(e, &Teardown{ - StateCommon: e.StateCommon, - inactive: true, - }) - } + } else if time.Now().After(e.inactiveSince.Add(ConnectionInactivityTimeout)) { + return LogTransition(e, &Teardown{ + StateCommon: e.StateCommon, + inactive: true, + }) } } @@ -79,12 +75,14 @@ func (e *Established) OnTick() PeerState { return nil } -func (e *Established) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { - if s := cascadeDirect(e, ap, clear); s != nil { +func (e *Established) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeDirect(e, ap, clearMsg); s != nil { return s } - LogDirectMessage(e, ap, clear) + ap = types.NormaliseAddrPort(ap) + + LogDirectMessage(e, ap, clearMsg) // TODO check if endpoint is same as current used one // - switch? 
trusting it blindly is open to replay attacks @@ -95,60 +93,86 @@ func (e *Established) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) Peer return nil } - switch m := clear.Message.(type) { - case *msg2.Ping: - if !e.pingDirectValid(ap, clear.Session, m) { + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + if !e.pingDirectValid(ap, clearMsg.Session, m) { L(e).Warn("dropping invalid ping", "ap", ap.String()) return nil } e.lastPingRecv = time.Now() - e.replyWithPongDirect(ap, clear.Session, m) + e.replyWithPongDirect(ap, clearMsg.Session, m) + + if ap != e.currentOutEndpoint && !e.tracker.Has(ap) { + // We're not sending pings to this, yet we may want to, to prevent asymmetric glare + pi := e.getPeerInfo() + if pi == nil { + // Peer info unavailable + return nil + } + + L(e).Log(context.Background(), types.LevelTrace, "sending ping to ping to prevent assymetric glare", "ap", ap.String(), "current", e.currentOutEndpoint.String()) + + // Send ping with ID, so that it eventually blackholes + e.tm.SendPingDirectWithID(ap, e.peer, pi.Session, m.TxID) + } + return nil - case *msg2.Pong: - e.lastPongRecv = time.Now() - e.ackPongDirect(ap, clear.Session, m) + case *msgsess.Pong: + if err := e.pongDirectValid(ap, clearMsg.Session, m); err != nil { + L(e).Warn("dropping invalid pong", "ap", ap.String(), "err", err) + } else { + e.lastPongRecv = time.Now() + e.tracker.GotPong(ap) + e.clearPongDirect(ap, clearMsg.Session, m) + + e.checkChangedPreferredEndpoint() + } + return nil - //case *msg.Rendezvous: default: L(e).Debug("ignoring direct session message", "ap", ap, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } } -func (e *Established) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { - if s := cascadeRelay(e, relay, peer, clear); s != nil { +func (e *Established) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeRelay(e, relay, 
peer, clearMsg); s != nil { return s } - LogRelayMessage(e, relay, peer, clear) + LogRelayMessage(e, relay, peer, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - e.replyWithPongRelay(relay, peer, clear.Session, m) + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + e.replyWithPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Pong: - e.ackPongRelay(relay, peer, clear.Session, m) + case *msgsess.Pong: + e.ackPongRelay(relay, peer, clearMsg.Session, m) return nil - //case *msg.Rendezvous: // TODO maybe re-establishment logic? + // case *msg.Rendezvous: default: L(e).Debug("ignoring relay session message", "relay", relay, "peer", peer, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } } +func (e *Established) GetEndpoint() netip.AddrPort { + return e.currentOutEndpoint +} + // canTrustEndpoint returns true if the endpoint that has been given corresponds to the peer. // this will check the current knownInEndpoints, and if it does not exist, will check peerInfo to see if the peer // sent this endpoint in the past with rendezvous. If so, adds it to the knownInEndpoints, and sends a SetAKA. 
@@ -183,3 +207,33 @@ func (e *Established) canTrustEndpoint(ap netip.AddrPort) bool { return false } + +func (e *Established) checkChangedPreferredEndpoint() { + bap, err := e.tracker.BestAddrPort() + if err != nil { + // this should not happen, at this point we have at least one happy pair + panic(err) + } + + if bap != e.currentOutEndpoint { + L(e).Log(context.Background(), types.LevelTrace, "switching bestaddrport", "bap", bap.String(), "current", e.currentOutEndpoint.String()) + // not the best one, switch + e.switchToEndpoint(bap) + } +} + +func (e *Established) switchToEndpoint(ep netip.AddrPort) { + previous := e.currentOutEndpoint + + e.currentOutEndpoint = ep + + e.tm.OutConnUseAddrPort(e.peer, ep) + e.tm.DManSetAKA(e.peer, ep) + + L(e).Info( + "SWITCHED direct peer connection to better endpoint", + "peer", e.peer.Debug(), + "from", previous.String(), + "to", ep.String(), + ) +} diff --git a/toversok/actors/peer_state/s_inactive.go b/toversok/actors/peerstate/s_inactive.go similarity index 52% rename from toversok/actors/peer_state/s_inactive.go rename to toversok/actors/peerstate/s_inactive.go index 38402ad..21d5198 100644 --- a/toversok/actors/peer_state/s_inactive.go +++ b/toversok/actors/peerstate/s_inactive.go @@ -1,9 +1,10 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" "net/netip" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type Inactive struct { @@ -25,49 +26,54 @@ func (i *Inactive) OnTick() PeerState { return nil } -func (i *Inactive) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { - if s := cascadeDirect(i, ap, clear); s != nil { +func (i *Inactive) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeDirect(i, ap, clearMsg); s != nil { return s } - LogDirectMessage(i, ap, clear) + LogDirectMessage(i, ap, clearMsg) - switch m := clear.Message.(type) { - case 
*msg2.Ping: - if !i.pingDirectValid(ap, clear.Session, m) { + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + if !i.pingDirectValid(ap, clearMsg.Session, m) { L(i).Warn("dropping invalid ping", "ap", ap.String()) return nil } - i.replyWithPongDirect(ap, clear.Session, m) + i.replyWithPongDirect(ap, clearMsg.Session, m) return nil - case *msg2.Pong: - i.ackPongDirect(ap, clear.Session, m) + case *msgsess.Pong: + if err := i.pongDirectValid(ap, clearMsg.Session, m); err != nil { + L(i).Warn("dropping invalid pong", "ap", ap.String(), "err", err) + } else { + i.clearPongDirect(ap, clearMsg.Session, m) + } + return nil default: L(i).Warn("ignoring direct session message", "ap", ap, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } } -func (i *Inactive) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { - if s := cascadeRelay(i, relay, peer, clear); s != nil { +func (i *Inactive) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeRelay(i, relay, peer, clearMsg); s != nil { return s } - LogRelayMessage(i, relay, peer, clear) + LogRelayMessage(i, relay, peer, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - i.replyWithPongRelay(relay, peer, clear.Session, m) + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + i.replyWithPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Pong: - i.ackPongRelay(relay, peer, clear.Session, m) + case *msgsess.Pong: + i.ackPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Rendezvous: + case *msgsess.Rendezvous: i.tm.Poke() return LogTransition(i, &EstRendezGot{ EstablishingCommon: mkEstComm(i.StateCommon, 0), @@ -77,7 +83,7 @@ func (i *Inactive) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMe L(i).Warn("ignoring relay session message", "relay", relay, "peer", peer, - "session", clear.Session, + "session", clearMsg.Session, "msg", 
m.Debug()) return nil } diff --git a/toversok/actors/peer_state/s_t_booting.go b/toversok/actors/peerstate/s_t_booting.go similarity index 77% rename from toversok/actors/peer_state/s_t_booting.go rename to toversok/actors/peerstate/s_t_booting.go index 13beb84..e4e2134 100644 --- a/toversok/actors/peer_state/s_t_booting.go +++ b/toversok/actors/peerstate/s_t_booting.go @@ -1,16 +1,19 @@ -package peer_state +package peerstate import ( + "net/netip" + "time" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgsess" - "net/netip" - "time" ) type Booting struct { *StateCommon + tracker *PingTracker + ap netip.AddrPort } @@ -26,6 +29,7 @@ func (b *Booting) OnTick() PeerState { return LogTransition(b, &Established{ StateCommon: b.StateCommon, + tracker: b.tracker, lastPingRecv: time.Now(), lastPongRecv: time.Now(), nextPingDeadline: time.Now(), @@ -35,12 +39,12 @@ func (b *Booting) OnTick() PeerState { }) } -func (b *Booting) OnDirect(ap netip.AddrPort, clear *msgsess.ClearMessage) PeerState { +func (b *Booting) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeDirect(b, ap, clear) + return cascadeDirect(b, ap, clearMsg) } -func (b *Booting) OnRelay(relay int64, peer key.NodePublic, clear *msgsess.ClearMessage) PeerState { +func (b *Booting) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeRelay(b, relay, peer, clear) + return cascadeRelay(b, relay, peer, clearMsg) } diff --git a/toversok/actors/peer_state/s_t_teardown.go b/toversok/actors/peerstate/s_t_teardown.go similarity index 66% rename from toversok/actors/peer_state/s_t_teardown.go rename to toversok/actors/peerstate/s_t_teardown.go index 8f7d11f..90ddf72 100644 --- a/toversok/actors/peer_state/s_t_teardown.go 
+++ b/toversok/actors/peerstate/s_t_teardown.go @@ -1,10 +1,11 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/msgsess" "net/netip" "time" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type Teardown struct { @@ -28,22 +29,22 @@ func (t *Teardown) OnTick() PeerState { return LogTransition(t, &Inactive{ StateCommon: t.StateCommon, }) - } else { - L(t).Info("LOST direct peer connection", "peer", t.peer.Debug()) - - return LogTransition(t, &Trying{ - StateCommon: t.StateCommon, - tryAt: time.Now(), - }) } + + L(t).Info("LOST direct peer connection", "peer", t.peer.Debug()) + + return LogTransition(t, &Trying{ + StateCommon: t.StateCommon, + tryAt: time.Now(), + }) } -func (t *Teardown) OnDirect(ap netip.AddrPort, clear *msgsess.ClearMessage) PeerState { +func (t *Teardown) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeDirect(t, ap, clear) + return cascadeDirect(t, ap, clearMsg) } -func (t *Teardown) OnRelay(relay int64, peer key.NodePublic, clear *msgsess.ClearMessage) PeerState { +func (t *Teardown) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { // OnTick will transition into the next state regardless, so just pass it along - return cascadeRelay(t, relay, peer, clear) + return cascadeRelay(t, relay, peer, clearMsg) } diff --git a/toversok/actors/peer_state/s_trying.go b/toversok/actors/peerstate/s_trying.go similarity index 50% rename from toversok/actors/peer_state/s_trying.go rename to toversok/actors/peerstate/s_trying.go index 60bbd73..4cb65e6 100644 --- a/toversok/actors/peer_state/s_trying.go +++ b/toversok/actors/peerstate/s_trying.go @@ -1,10 +1,11 @@ -package peer_state +package peerstate import ( - "github.com/edup2p/common/types/key" - msg2 "github.com/edup2p/common/types/msgsess" 
"net/netip" "time" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type Trying struct { @@ -28,51 +29,55 @@ func (t *Trying) OnTick() PeerState { return nil } -func (t *Trying) OnDirect(ap netip.AddrPort, clear *msg2.ClearMessage) PeerState { - if s := cascadeDirect(t, ap, clear); s != nil { +func (t *Trying) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeDirect(t, ap, clearMsg); s != nil { return s } - LogDirectMessage(t, ap, clear) + LogDirectMessage(t, ap, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - if !t.pingDirectValid(ap, clear.Session, m) { + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + if !t.pingDirectValid(ap, clearMsg.Session, m) { L(t).Warn("dropping invalid ping", "ap", ap.String()) return nil } // TODO(jo): We could start establishing here, possibly. - t.replyWithPongDirect(ap, clear.Session, m) + t.replyWithPongDirect(ap, clearMsg.Session, m) return nil - case *msg2.Pong: - t.ackPongDirect(ap, clear.Session, m) + case *msgsess.Pong: + if err := t.pongDirectValid(ap, clearMsg.Session, m); err != nil { + L(t).Warn("dropping invalid pong", "ap", ap.String(), "err", err) + } else { + t.clearPongDirect(ap, clearMsg.Session, m) + } + return nil - //case *msg.Rendezvous: default: L(t).Warn("ignoring direct session message", "ap", ap, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } } -func (t *Trying) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMessage) PeerState { - if s := cascadeRelay(t, relay, peer, clear); s != nil { +func (t *Trying) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { + if s := cascadeRelay(t, relay, peer, clearMsg); s != nil { return s } - LogRelayMessage(t, relay, peer, clear) + LogRelayMessage(t, relay, peer, clearMsg) - switch m := clear.Message.(type) { - case *msg2.Ping: - t.replyWithPongRelay(relay, peer, clear.Session, 
m) + switch m := clearMsg.Message.(type) { + case *msgsess.Ping: + t.replyWithPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Pong: - t.ackPongRelay(relay, peer, clear.Session, m) + case *msgsess.Pong: + t.ackPongRelay(relay, peer, clearMsg.Session, m) return nil - case *msg2.Rendezvous: + case *msgsess.Rendezvous: return LogTransition(t, &EstRendezGot{ EstablishingCommon: mkEstComm(t.StateCommon, 0), m: m, @@ -81,7 +86,7 @@ func (t *Trying) OnRelay(relay int64, peer key.NodePublic, clear *msg2.ClearMess L(t).Warn("ignoring relay session message", "relay", relay, "peer", peer, - "session", clear.Session, + "session", clearMsg.Session, "msg", m.Debug()) return nil } diff --git a/toversok/actors/peer_state/s_waiting.go b/toversok/actors/peerstate/s_waiting.go similarity index 73% rename from toversok/actors/peer_state/s_waiting.go rename to toversok/actors/peerstate/s_waiting.go index 0556d9b..b4f3cbb 100644 --- a/toversok/actors/peer_state/s_waiting.go +++ b/toversok/actors/peerstate/s_waiting.go @@ -1,10 +1,11 @@ -package peer_state +package peerstate import ( + "net/netip" + "github.com/edup2p/common/types/ifaces" "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgsess" - "net/netip" ) type WaitingForInfo struct { @@ -23,23 +24,23 @@ func (w *WaitingForInfo) OnTick() PeerState { return nil } -func (w *WaitingForInfo) OnDirect(ap netip.AddrPort, clear *msgsess.ClearMessage) PeerState { - s := cascadeDirect(w, ap, clear) +func (w *WaitingForInfo) OnDirect(ap netip.AddrPort, clearMsg *msgsess.ClearMessage) PeerState { + s := cascadeDirect(w, ap, clearMsg) if s == nil { // The state did not cascade, so we log here. 
- LogDirectMessage(w, ap, clear) + LogDirectMessage(w, ap, clearMsg) } return s } -func (w *WaitingForInfo) OnRelay(relay int64, peer key.NodePublic, clear *msgsess.ClearMessage) PeerState { - s := cascadeRelay(w, relay, peer, clear) +func (w *WaitingForInfo) OnRelay(relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) PeerState { + s := cascadeRelay(w, relay, peer, clearMsg) if s == nil { // The state did not cascade, so we log here. - LogRelayMessage(w, relay, peer, clear) + LogRelayMessage(w, relay, peer, clearMsg) } return s diff --git a/toversok/actors/peer_state/util.go b/toversok/actors/peerstate/util.go similarity index 71% rename from toversok/actors/peer_state/util.go rename to toversok/actors/peerstate/util.go index 436aaa2..f479f36 100644 --- a/toversok/actors/peer_state/util.go +++ b/toversok/actors/peerstate/util.go @@ -1,20 +1,21 @@ -package peer_state +package peerstate import ( "context" + "log/slog" + "net/netip" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgsess" - "log/slog" - "net/netip" ) // cascadeDirect makes it so that first we call the default "tick" function of a peer's state, // and if that requests a state transition, call a PeerState.OnDirect with the original arguments, // and return the requested state change with that one if it returns one. 
-func cascadeDirect(so PeerState, ap netip.AddrPort, clear *msgsess.ClearMessage) (s PeerState) { +func cascadeDirect(so PeerState, ap netip.AddrPort, clearMsg *msgsess.ClearMessage) (s PeerState) { if s1 := so.OnTick(); s1 != nil { - if s2 := s1.OnDirect(ap, clear); s2 != nil { + if s2 := s1.OnDirect(ap, clearMsg); s2 != nil { s = s2 } else { s = s1 @@ -27,9 +28,9 @@ func cascadeDirect(so PeerState, ap netip.AddrPort, clear *msgsess.ClearMessage) // cascadeRelay makes it so that first we call the default "tick" function of a peer's state, // and if that requests a state transition, call a PeerState.OnRelay with the original arguments, // and return the requested state change with that one if it returns one. -func cascadeRelay(so PeerState, relay int64, peer key.NodePublic, clear *msgsess.ClearMessage) (s PeerState) { +func cascadeRelay(so PeerState, relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) (s PeerState) { if s1 := so.OnTick(); s1 != nil { - if s2 := s1.OnRelay(relay, peer, clear); s2 != nil { + if s2 := s1.OnRelay(relay, peer, clearMsg); s2 != nil { s = s2 } else { s = s1 @@ -44,27 +45,27 @@ func L(s PeerState) *slog.Logger { return slog.With("peer", s.Peer().Debug(), "state", s.Name()) } -func LogTransition(from PeerState, to PeerState) PeerState { +func LogTransition(from, to PeerState) PeerState { L(from).Log(context.Background(), types.LevelTrace, "transitioning state", "to-state", to.Name()) return to } -func LogDirectMessage(s PeerState, ap netip.AddrPort, clear *msgsess.ClearMessage) { +func LogDirectMessage(s PeerState, ap netip.AddrPort, clearMsg *msgsess.ClearMessage) { L(s).Log(context.Background(), types.LevelTrace, "received direct message", slog.Group("from", "addrport", ap, - "session", clear.Session.Debug()), - "msg", clear.Message.Debug(), + "session", clearMsg.Session.Debug()), + "msg", clearMsg.Message.Debug(), ) } -func LogRelayMessage(s PeerState, relay int64, peer key.NodePublic, clear *msgsess.ClearMessage) { +func 
LogRelayMessage(s PeerState, relay int64, peer key.NodePublic, clearMsg *msgsess.ClearMessage) { L(s).Log(context.Background(), types.LevelTrace, "received relay message", slog.Group("from", "relay", relay, "peer", peer.Debug(), - "session", clear.Session), - "msg", clear.Message.Debug(), + "session", clearMsg.Session), + "msg", clearMsg.Message.Debug(), ) } diff --git a/toversok/actors/rehearsal_test.go b/toversok/actors/rehearsal_test.go index 11dc439..5828849 100644 --- a/toversok/actors/rehearsal_test.go +++ b/toversok/actors/rehearsal_test.go @@ -10,6 +10,7 @@ import ( "github.com/edup2p/common/types/msgactor" ) +//nolint:unused type MockActor struct { ctx context.Context diff --git a/toversok/actors/stage.go b/toversok/actors/stage.go index 1169cfd..8f1a79a 100644 --- a/toversok/actors/stage.go +++ b/toversok/actors/stage.go @@ -3,42 +3,35 @@ package actors import ( "context" "errors" + "log/slog" + "net" + "net/netip" + "reflect" + "slices" + "sync" + "time" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/ifaces" "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgactor" "github.com/edup2p/common/types/msgcontrol" "github.com/edup2p/common/types/relay" + "github.com/edup2p/common/types/relay/relayhttp" "github.com/edup2p/common/types/stage" "golang.org/x/exp/maps" - "log/slog" - "net" - "net/netip" - "reflect" - "slices" - "sync" - "time" ) type OutConnActor interface { ifaces.Actor - - Ctx() context.Context } type InConnActor interface { ifaces.Actor - Ctx() context.Context - ForwardPacket(pkt []byte) } -//udp, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), localPort))) -//if err != nil { -// panic(fmt.Sprintf("could not create listenUDP: %s", err)) -//} - func MakeStage( pCtx context.Context, @@ -48,11 +41,20 @@ func MakeStage( bindExt func() types.UDPConn, bindLocal func(peer key.NodePublic) types.UDPConn, controlSession ifaces.ControlInterface, + + dialRelayFunc 
relayhttp.RelayDialFunc, + + wgIf *net.Interface, ) ifaces.Stage { - ctx := context.WithoutCancel(pCtx) + if dialRelayFunc == nil { + dialRelayFunc = relayhttp.Dial + } + + ctx, cancel := context.WithCancel(pCtx) s := &Stage{ - Ctx: ctx, + Ctx: ctx, + cancel: cancel, connMutex: sync.RWMutex{}, inConn: make(map[key.NodePublic]InConnActor), @@ -69,6 +71,10 @@ func MakeStage( ext: bindExt(), bindLocal: bindLocal, control: controlSession, + + wgIf: wgIf, + + dialRelayFunc: dialRelayFunc, } s.DMan = s.makeDM(s.ext) @@ -80,15 +86,38 @@ func MakeStage( s.TMan = s.makeTM() s.SMan = s.makeSM(sessPriv) s.EMan = s.makeEM() + s.MMan = s.makeMM() + + s.installAfterFunc() return s } +func (s *Stage) installAfterFunc() { + context.AfterFunc(s.Ctx, s.Close) + + // TODO: self-heal. Currently these just cancel the stage, which then propagates back upwards, but we should figure + // out if its possible to heal components. + + context.AfterFunc(s.DMan.Ctx(), s.cancel) + context.AfterFunc(s.DRouter.Ctx(), s.cancel) + + context.AfterFunc(s.RMan.Ctx(), s.cancel) + context.AfterFunc(s.RRouter.Ctx(), s.cancel) + + context.AfterFunc(s.TMan.Ctx(), s.cancel) + context.AfterFunc(s.SMan.Ctx(), s.cancel) + context.AfterFunc(s.EMan.Ctx(), s.cancel) + context.AfterFunc(s.MMan.Ctx(), s.cancel) +} + // Stage for the Actors type Stage struct { // The parent context of the stage that all actors must parent Ctx context.Context + cancel context.CancelFunc + // The DirectManager DMan ifaces.DirectManagerActor // The DirectRouter @@ -105,13 +134,17 @@ type Stage struct { SMan ifaces.SessionManagerActor // The EndpointManager EMan ifaces.EndpointManagerActor + // The MDNSManager + MMan ifaces.MDNSManagerActor connMutex sync.RWMutex inConn map[key.NodePublic]InConnActor outConn map[key.NodePublic]OutConnActor - getNodePriv func() *key.NodePrivate - getSessPriv func() *key.SessionPrivate + getNodePriv func() *key.NodePrivate + getSessPriv func() *key.SessionPrivate + + endpointMutex sync.RWMutex 
localEndpoints []netip.AddrPort stunEndpoints []netip.AddrPort @@ -122,15 +155,20 @@ type Stage struct { control ifaces.ControlInterface + wgIf *net.Interface + //// A repeatable function to an outside context to acquire a new UDPconn, //// once a peer conn has died for whatever reason. - //reviveOutConn func(peer key.NodePublic) *net.UDPConn + // TODO rework this? + // reviveOutConn func(peer key.NodePublic) *net.UDPConn // - //makeOutConn func(udp UDPConn, peer key.NodePublic, s *Stage) OutConnActor - //makeInConn func(udp UDPConn, peer key.NodePublic, s *Stage) InConnActor + // makeOutConn func(udp UDPConn, peer key.NodePublic, s *Stage) OutConnActor + // makeInConn func(udp UDPConn, peer key.NodePublic, s *Stage) InConnActor ext types.UDPConn bindLocal func(peer key.NodePublic) types.UDPConn + + dialRelayFunc relayhttp.RelayDialFunc } // Start kicks off goroutines for the stage and returns @@ -144,6 +182,7 @@ func (s *Stage) Start() { go s.TMan.Run() go s.SMan.Run() go s.EMan.Run() + go s.MMan.Run() go s.DMan.Run() go s.DRouter.Run() @@ -154,6 +193,12 @@ func (s *Stage) Start() { s.started = true } +func (s *Stage) Close() { + if err := s.ext.Close(); err != nil { + slog.Error("error closing ext for stage", "err", err) + } +} + // Watchdog will be run to constantly check for faults on the stage and repair them. func (s *Stage) Watchdog() { ticker := time.NewTicker(time.Second * 5) @@ -222,7 +267,7 @@ func (s *Stage) reapableConnsLocked() []key.NodePublic { if !ok { // outconn is gone for some reason, this is fine for now - // TODO log this? 
+ slog.Warn("missing outconn pair to inconn, this is fine, but odd", "peer", peer.Debug()) } else { out.Cancel() } @@ -256,7 +301,7 @@ func (s *Stage) reapableConnsLocked() []key.NodePublic { return peers } -func (s *Stage) syncableConnsLocked() (added []key.NodePublic, deleted []key.NodePublic) { +func (s *Stage) syncableConnsLocked() (added, deleted []key.NodePublic) { piPeers := maps.Keys(s.peerInfo) connPeers := types.SetUnion(maps.Keys(s.inConn), maps.Keys(s.outConn)) @@ -343,19 +388,12 @@ func (s *Stage) InConnFor(peer key.NodePublic) InConnActor { return s.inConn[peer] } -//// AddConn creates an InConn and OutConn for a specified connection. -//// Starting each Actor'S goroutines as well. It also starts a SockRecv given the -//// udp connection. -//func (s *Stage) AddConn(udp *net.UDPConn, peer key.NodePublic, info *PeerInfo) { -// s.UpdateSessionKey(peer, session) -// s.addConn(udp, peer, homeRelay) -//} - // addConnLocked assumes Stage.connMutex and Stage.peerInfoMutex is held by caller. 
func (s *Stage) addConnLocked(peer key.NodePublic, udp types.UDPConn) { pi := s.peerInfo[peer] if pi == nil { + // We run this with the assumption that peerinfo has been given to us panic("expecting to have peer information at this point") } @@ -370,15 +408,15 @@ func (s *Stage) addConnLocked(peer key.NodePublic, udp types.UDPConn) { } func (s *Stage) GetEndpoints() []netip.AddrPort { - s.connMutex.RLock() - defer s.connMutex.RUnlock() + s.endpointMutex.RLock() + defer s.endpointMutex.RUnlock() return slices.Concat(s.localEndpoints, s.stunEndpoints) } func (s *Stage) setSTUNEndpoints(endpoints []netip.AddrPort) { - s.connMutex.Lock() - defer s.connMutex.Unlock() + s.endpointMutex.Lock() + defer s.endpointMutex.Unlock() sortEndpointSlice(endpoints) @@ -393,8 +431,8 @@ func (s *Stage) setSTUNEndpoints(endpoints []netip.AddrPort) { } func (s *Stage) setLocalEndpoints(addrs []netip.Addr) { - s.connMutex.RLock() - defer s.connMutex.RUnlock() + s.endpointMutex.Lock() + defer s.endpointMutex.Unlock() localPort := s.getLocalPort() @@ -406,6 +444,7 @@ func (s *Stage) setLocalEndpoints(addrs []netip.Addr) { var endpoints []netip.AddrPort + // Filter own endpoint, and also append localport for _, addr := range addrs { if s.control.IPv4().Contains(addr) || s.control.IPv6().Contains(addr) { continue @@ -428,6 +467,15 @@ func (s *Stage) setLocalEndpoints(addrs []netip.Addr) { s.notifyEndpointChanged() } +func (s *Stage) getLocalEndpoints() []netip.Addr { + s.endpointMutex.RLock() + defer s.endpointMutex.RUnlock() + + return types.Map(s.localEndpoints, func(t netip.AddrPort) netip.Addr { + return t.Addr() + }) +} + func (s *Stage) getLocalPort() uint16 { type HasLocalAddr interface { LocalAddr() net.Addr @@ -461,7 +509,7 @@ func (s *Stage) notifyEndpointChanged() { } } -func (s *Stage) AddPeer(peer key.NodePublic, homeRelay int64, endpoints []netip.AddrPort, session key.SessionPublic, _ netip.Addr, _ netip.Addr, prop msgcontrol.Properties) error { +func (s *Stage) AddPeer(peer 
key.NodePublic, homeRelay int64, endpoints []netip.AddrPort, session key.SessionPublic, ip4, ip6 netip.Addr, prop msgcontrol.Properties) error { s.peerInfoMutex.Lock() defer func() { @@ -480,6 +528,9 @@ func (s *Stage) AddPeer(peer key.NodePublic, homeRelay int64, endpoints []netip. Endpoints: types.NormaliseAddrPortSlice(endpoints), RendezvousEndpoints: make([]netip.AddrPort, 0), Session: session, + IPv4: ip4, + IPv6: ip6, + MDNS: prop.MDNS, } return nil @@ -498,6 +549,9 @@ func (s *Stage) UpdatePeer(peer key.NodePublic, homeRelay *int64, endpoints []ne if session != nil { info.Session = *session } + if prop != nil { + info.MDNS = prop.MDNS + } }) } @@ -547,6 +601,19 @@ func (s *Stage) GetPeerInfo(peer key.NodePublic) *stage.PeerInfo { return s.peerInfo[peer] } +func (s *Stage) GetPeersWhere(f func(key.NodePublic, *stage.PeerInfo) bool) []key.NodePublic { + s.peerInfoMutex.RLock() + defer s.peerInfoMutex.RUnlock() + + var peers []key.NodePublic + for peer, info := range s.peerInfo { + if f(peer, info) { + peers = append(peers, peer) + } + } + return peers +} + func (s *Stage) RemovePeer(peer key.NodePublic) error { s.peerInfoMutex.Lock() delete(s.peerInfo, peer) @@ -571,40 +638,6 @@ func (s *Stage) ControlSTUN() []netip.AddrPort { return []netip.AddrPort{} } -//func (s *Stage) RemoveConn(peer key.NodePublic) { -// s.connMutex.Lock() -// defer s.connMutex.Unlock() -// -// in, inok := s.inConn[peer] -// out, outok := s.inConn[peer] -// -// if !inok && !outok { -// // both already removed, we're done here -// return -// } -// -// if inok != outok { -// // only one of them removed? -// // we could recover this, but this is a bug, panic. 
-// panic(fmt.Sprintf("InConn or OutConn presence on stage was disbalanced: in=%t, out=%t", inok, outok)) -// } -// -// // Now we know both exist -// -// delete(s.inConn, peer) -// delete(s.outConn, peer) -// -// in.Cancel() -// out.Cancel() -// -// // OutConn cancel: -// // this closes the outch in SockRecv, -// // sends "outconn goodbye" to traffic manager, -// // -// // InConn cancel: -// // sends "outconn goodbye" to traffic manager. -// -// // When TM has received both goodbyes: -// // removes from internal activity tracking, -// // and removes mapping from direct router. -//} +func (s *Stage) Context() context.Context { + return s.Ctx +} diff --git a/toversok/actors/util.go b/toversok/actors/util.go index 4da567d..10907da 100644 --- a/toversok/actors/util.go +++ b/toversok/actors/util.go @@ -3,12 +3,14 @@ package actors import ( "context" "fmt" - "github.com/edup2p/common/types/ifaces" - "github.com/edup2p/common/types/msgactor" "log/slog" "net/netip" "sort" "sync/atomic" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/ifaces" + "github.com/edup2p/common/types/msgactor" ) // RunCheck ensures that only one instance of the actor is running at all times. @@ -37,14 +39,16 @@ func L(a ifaces.Actor) *slog.Logger { } func bail(c context.Context, v any) { - maybeCcc := c.Value("ccc") + maybeCcc := c.Value(types.CCC) if maybeCcc == nil { + // We add the CCC early in the engine's lifecycle, so this shouldn't happen. 
panic(fmt.Errorf("could not bail, cannot find ccc: %s", v)) } probablyCcc, ok := maybeCcc.(context.CancelCauseFunc) if !ok { + // Ditto, if we add it, we make sure its added correctly panic(fmt.Errorf("could not bail, ccc is not CancelCauseFunc: %s", v)) } @@ -56,3 +60,9 @@ func sortEndpointSlice(endpoints []netip.AddrPort) { return endpoints[i].Addr().Less(endpoints[j].Addr()) && endpoints[i].Port() < endpoints[j].Port() }) } + +func assureClose[T ifaces.Actor](a T) T { + context.AfterFunc(a.Ctx(), a.Close) + + return a +} diff --git a/toversok/actors/util_test.go b/toversok/actors/util_test.go index 67de85a..473cf6f 100644 --- a/toversok/actors/util_test.go +++ b/toversok/actors/util_test.go @@ -8,17 +8,23 @@ import ( ) // Test constants -const assertEventuallyTick time.Duration = 1 * time.Millisecond -const assertEventuallyTimeout time.Duration = 10 * assertEventuallyTick +const ( + assertEventuallyTick = 1 * time.Millisecond + assertEventuallyTimeout = 10 * assertEventuallyTick +) // Test variables -var dummyAddr netip.Addr = netip.IPv4Unspecified() -var dummyAddrPort netip.AddrPort = netip.AddrPortFrom(dummyAddr, 0) -var dummyKey key.NodePublic = [32]byte{0} +var ( + dummyAddr = netip.IPv4Unspecified() + dummyAddrPort = netip.AddrPortFrom(dummyAddr, 0) + dummyKey key.NodePublic = [32]byte{0} +) // Test session -var testPriv key.SessionPrivate = key.NewSession() -var testPub key.SessionPublic = testPriv.Public() +var ( + testPriv = key.NewSession() + testPub = testPriv.Public() +) func getTestPriv() *key.SessionPrivate { return &testPriv diff --git a/toversok/control_conn.go b/toversok/control_conn.go index 87cfb30..d04d6ad 100644 --- a/toversok/control_conn.go +++ b/toversok/control_conn.go @@ -4,6 +4,11 @@ import ( "context" "errors" "fmt" + "log/slog" + "net/netip" + "sync" + "time" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/control" "github.com/edup2p/common/types/control/controlhttp" @@ -12,10 +17,6 @@ import ( 
"github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgcontrol" "golang.org/x/exp/maps" - "log/slog" - "net/netip" - "sync" - "time" ) type DefaultControlHost struct { @@ -49,6 +50,7 @@ type ResumableControlSession struct { // Airlifted out of Client, expected to stay the same as long as the session does ipv4 netip.Prefix ipv6 netip.Prefix + expiry time.Time controlKey key.ControlPublic session string @@ -66,21 +68,23 @@ type ResumableControlSession struct { // In to local msgInQueue []msgcontrol.ControlMessage - callbacks ifaces.ControlCallbacks + callbackLock sync.RWMutex + callbacks ifaces.ControlCallbacks } func CreateControlSession(ctx context.Context, opts dial.Opts, controlKey key.ControlPublic, getPriv func() *key.NodePrivate, getSess func() *key.SessionPrivate, logon types.LogonCallback) (*ResumableControlSession, error) { - // TODO authCallback func(url string) - rcsCtx, rcsCcc := context.WithCancelCause(ctx) - clientCtx := context.WithoutCancel(rcsCtx) - c, err := controlhttp.Dial(clientCtx, opts, getPriv, getSess, controlKey, nil, logon) + c, err := controlhttp.Dial(rcsCtx, opts, getPriv, getSess, controlKey, nil, logon) if err != nil { + rcsCcc(err) return nil, fmt.Errorf("could not create control session: %w", err) } - slog.Debug("created initial control connection") + slog.Debug( + "created initial control connection", + "ipv4", c.IPv4.String(), "ipv6", c.IPv6.String(), "expiry", c.Expiry, + ) rcs := &ResumableControlSession{ ctx: rcsCtx, @@ -88,6 +92,7 @@ func CreateControlSession(ctx context.Context, opts dial.Opts, controlKey key.Co ipv4: c.IPv4, ipv6: c.IPv6, + expiry: c.Expiry, controlKey: c.ControlKey, session: *c.SessionID, @@ -106,7 +111,6 @@ func CreateControlSession(ctx context.Context, opts dial.Opts, controlKey key.Co } func (rcs *ResumableControlSession) Run() { - go func() { <-rcs.ctx.Done() @@ -129,7 +133,6 @@ func (rcs *ResumableControlSession) Run() { } err := rcs.FlushOut() - if err != nil { slog.Warn("control 
connection errored while flushing out", "err", err) @@ -140,7 +143,6 @@ func (rcs *ResumableControlSession) Run() { if types.IsContextDone(rcs.ctx) { slog.Info("control session ended, closing client") - rcs.client.Close() return } @@ -175,7 +177,7 @@ func (rcs *ResumableControlSession) Run() { absenceStart := time.Now() - var session = &rcs.session + session := &rcs.session var err error var client *control.Client @@ -186,14 +188,12 @@ func (rcs *ResumableControlSession) Run() { return } - clientCtx := context.WithoutCancel(rcs.ctx) - client, err = controlhttp.Dial( - clientCtx, rcs.clientOpts, rcs.getPriv, rcs.getSess, rcs.controlKey, session, nil, + rcs.ctx, rcs.clientOpts, rcs.getPriv, rcs.getSess, rcs.controlKey, session, nil, ) - var r = msgcontrol.NoRetryStrategy - var retry = &r + r := msgcontrol.NoRetryStrategy + retry := &r if err != nil { if errors.As(err, retry) { @@ -211,7 +211,7 @@ func (rcs *ResumableControlSession) Run() { return } - if errors.Is(err, control.NeedsLogonError) { + if errors.Is(err, control.ErrNeedsLogon) { // TODO dead/retry logic, signal that session is dead and needs manual logon panic("not implemented") } @@ -220,6 +220,27 @@ func (rcs *ResumableControlSession) Run() { continue } + if rcs.ipv4 != client.IPv4 { + slog.Error("control-given IPv4 prefix is different than cached IPv4, bailing...", "cached", rcs.ipv4, "given", client.IPv4) + rcs.ccc(fmt.Errorf("IPv4 changed from %s to %s", rcs.ipv4, client.IPv4)) + + return + } + + if rcs.ipv6 != client.IPv6 { + slog.Error("control-given IPv6 prefix is different than cached IPv6, bailing...", "cached", rcs.ipv6, "given", client.IPv6) + rcs.ccc(fmt.Errorf("IPv6 changed from %s to %s", rcs.ipv6, client.IPv6)) + + return + } + + if rcs.expiry != client.Expiry { + slog.Error("control-given expiry is different than cached expiry, bailing...", "cached", rcs.expiry, "given", client.Expiry) + rcs.ccc(fmt.Errorf("expiry changed from %s to %s", rcs.expiry, client.Expiry)) + + return + } + 
slog.Debug("resumed control connection") break @@ -233,13 +254,15 @@ func (rcs *ResumableControlSession) Run() { } } +var ErrDisconnected = errors.New("control requested disconnect") + func (rcs *ResumableControlSession) Handle(msg msgcontrol.ControlMessage) error { slog.Debug("Handle", "msg", msg) switch m := msg.(type) { case *msgcontrol.PeerAddition: rcs.knownPeers[m.PubKey] = true - return rcs.callbacks.AddPeer( + return rcs.ExpectCallbacks().AddPeer( m.PubKey, m.HomeRelay, m.Endpoints, @@ -254,7 +277,7 @@ func (rcs *ResumableControlSession) Handle(msg msgcontrol.ControlMessage) error endpoints = m.Endpoints } - return rcs.callbacks.UpdatePeer( + return rcs.ExpectCallbacks().UpdatePeer( m.PubKey, m.HomeRelay, endpoints, @@ -263,16 +286,21 @@ func (rcs *ResumableControlSession) Handle(msg msgcontrol.ControlMessage) error ) case *msgcontrol.PeerRemove: delete(rcs.knownPeers, m.PubKey) - return rcs.callbacks.RemovePeer(m.PubKey) + return rcs.ExpectCallbacks().RemovePeer(m.PubKey) case *msgcontrol.RelayUpdate: - return rcs.callbacks.UpdateRelays(m.Relays) + return rcs.ExpectCallbacks().UpdateRelays(m.Relays) + case *msgcontrol.Disconnect: + rcs.client.Cancel(fmt.Errorf("received disconnect: %w, %w", ErrDisconnected, m.RetryStrategy)) + return nil default: return fmt.Errorf("got unexpected message from control: %v", msg) } - } func (rcs *ResumableControlSession) CallbacksReady() bool { + rcs.callbackLock.RLock() + defer rcs.callbackLock.RUnlock() + return rcs.callbacks != nil } @@ -334,7 +362,30 @@ func (rcs *ResumableControlSession) IPv6() netip.Prefix { return rcs.ipv6 } +func (rcs *ResumableControlSession) Expiry() time.Time { + return rcs.expiry +} + +func (rcs *ResumableControlSession) ExpectCallbacks() ifaces.ControlCallbacks { + rcs.callbackLock.RLock() + defer rcs.callbackLock.RUnlock() + + if rcs.callbacks == nil { + // Part of the function contract; if it doesnt exist, it'll blow up + panic("expected callbacks to be ready at this stage") + } + + return 
rcs.callbacks +} + +func (rcs *ResumableControlSession) Context() context.Context { + return rcs.ctx +} + func (rcs *ResumableControlSession) InstallCallbacks(callbacks ifaces.ControlCallbacks) { + rcs.callbackLock.Lock() + defer rcs.callbackLock.Unlock() + rcs.callbacks = callbacks } @@ -347,7 +398,7 @@ func (rcs *ResumableControlSession) send(msg msgcontrol.ControlMessage) error { return nil } - if !errors.Is(err, control.ClosedErr) { + if !errors.Is(err, control.ErrClosed) { return err } } diff --git a/toversok/engine.go b/toversok/engine.go index 8f7c5be..3e2026c 100644 --- a/toversok/engine.go +++ b/toversok/engine.go @@ -4,38 +4,16 @@ import ( "context" "errors" "fmt" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/ifaces" - "github.com/edup2p/common/types/key" "log/slog" "net" "net/netip" "sync" "time" -) -//type EngineOptions struct { -// //Ctx context.Context -// //Ccc context.CancelCauseFunc -// // -// //PrivKey key.NakedKey -// // -// //Control dial.Opts -// //ControlKey key.ControlPublic -// // -// //// Do not contact control -// //OverrideControl bool -// //OverrideIPv4 netip.Prefix -// //OverrideIPv6 netip.Prefix -// -// WG WireGuardHost -// FW FirewallHost -// Co ControlHost -// -// ExtBindPort uint16 -// -// PrivateKey key.NodePrivate -//} + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/control" + "github.com/edup2p/common/types/key" +) // Engine is the main and most high-level object for any client implementation. // @@ -45,6 +23,9 @@ type Engine struct { ctx context.Context ccc context.CancelCauseFunc + runningCtx context.Context + runningCancel context.CancelFunc + sess *Session extBind *types.UDPConnCloseCatcher @@ -56,9 +37,10 @@ type Engine struct { nodePriv key.NodePrivate - state stateObserver - doAutoRestart bool - dirty bool + state stateObserver + dirty bool + + deviceKey *string } // Start will fire up the Engine. @@ -69,8 +51,13 @@ type Engine struct { // - Reason for any other startup error. 
// // After the engine has successfully started once, it will automatically restart on any failure. -func (e *Engine) Start() error { - return e.start(true) +func (e *Engine) Start() (context.Context, error) { + err := e.start(true) + if err != nil { + return nil, err + } + + return e.runningCtx, nil } func (e *Engine) start(allowLogon bool) error { @@ -83,23 +70,19 @@ func (e *Engine) start(allowLogon bool) error { return errors.New("cannot start; already running") } + if e.runningCtx != nil { + e.runningCancel() + } + if e.sess != nil && e.sess.ctx.Err() == nil { // Session is still running, even though that shouldn't be the case, as we checked for NoSession above e.sess.ccc(errors.New("engine state desynced, shutting down")) } - if e.dirty { - if err := e.wg.Reset(); err != nil { - e.slog().Error("dirty start: could not reset wireguard", "err", err) - e.state.set(NoSession) - return err - } + e.runningCtx, e.runningCancel = context.WithCancel(e.ctx) - if err := e.fw.Reset(); err != nil { - e.slog().Error("dirty start: could not reset firewall", "err", err) - e.state.set(NoSession) - return err - } + if err := e.maybeClean(); err != nil { + return fmt.Errorf("engine state cleaning failed: %w", err) } e.dirty = true @@ -115,6 +98,34 @@ func (e *Engine) Context() context.Context { return e.ctx } +func (e *Engine) RunningContext() context.Context { + if e.runningCtx != nil && e.runningCtx.Err() != nil { + return nil + } + + return e.runningCtx +} + +func (e *Engine) maybeClean() error { + slog.Debug("maybeClean called", "dirty", e.dirty) + + if e.dirty { + if err := e.wg.Reset(); err != nil { + e.slog().Error("clean: could not reset wireguard", "err", err) + e.state.set(NoSession) + return err + } + + if err := e.fw.Reset(); err != nil { + e.slog().Error("clean: could not reset firewall", "err", err) + e.state.set(NoSession) + return err + } + } + + return nil +} + // StalledEngineRestartInterval represents how many seconds to wait before restarting an engine, // 
after it has stalled/failed. const StalledEngineRestartInterval = time.Second * 2 @@ -122,9 +133,20 @@ const StalledEngineRestartInterval = time.Second * 2 func (e *Engine) autoRestart() { if e.WillRestart() { if err := e.start(false); err != nil { + if errors.Is(err, control.ErrNeedsLogon) { + // Bail, we can't do anything here + e.runningCancel() + } + slog.Info("autoRestart: will retry in 10 seconds") time.AfterFunc(StalledEngineRestartInterval, e.autoRestart) } + } else { + slog.Debug("will not auto-restart") + + if err := e.maybeClean(); err != nil { + slog.Error("engine state cleaning failed", "err", err) + } } } @@ -135,11 +157,7 @@ func (e *Engine) Stop() { return } - e.doAutoRestart = false - - if e.sess.ctx.Err() != nil { - e.sess.ccc(errors.New("shutting down")) - } + e.runningCancel() var stillDirty bool @@ -166,24 +184,31 @@ func (e *Engine) installSession(allowLogon bool) error { var logon types.LogonCallback if allowLogon { - logon = func(url string, _ chan<- string) error { - // TODO register/use device key channel + logon = func(url string, devKeyCh chan<- string) error { + e.state.alter(func(o *stateObserver) { + o.loginURL = url + o.loginDeviceKeyCh = devKeyCh + }) - e.state.currentLoginUrl = url e.state.change(CreatingSession, NeedsLogin) return nil } } var err error - e.sess, err = SetupSession(e.ctx, e.wg, e.fw, e.co, e.getExtConn, e.getNodePriv, logon) + e.sess, err = SetupSession(e.runningCtx, e.wg, e.fw, e.co, e.getExtConn, e.getNodePriv, logon) if err != nil { return fmt.Errorf("failed to setup session: %w", err) } + e.state.alter(func(o *stateObserver) { + o.expiry = e.sess.cs.Expiry() + }) + if !(e.state.change(CreatingSession, Established) || e.state.change(NeedsLogin, Established)) { - e.ccc(errors.New("incorrect state transition")) - panic("incorrect state transition to established") + err = errors.New("incorrect state transition") + e.ccc(err) + return err } context.AfterFunc(e.sess.ctx, func() { @@ -198,7 +223,7 @@ func (e *Engine) 
installSession(allowLogon bool) error { // WillRestart says whether the engine strives to be in a running state. func (e *Engine) WillRestart() bool { - return e.doAutoRestart + return e.runningCtx != nil && e.runningCtx.Err() != nil } func (e *Engine) slog() *slog.Logger { @@ -210,10 +235,13 @@ func newStateObserver() stateObserver { } type stateObserver struct { - mu sync.Mutex - state EngineState - currentLoginUrl string - callbacks []func(state EngineState) + mu sync.Mutex + state EngineState + callbacks []func(state EngineState) + + loginURL string + loginDeviceKeyCh chan<- string + expiry time.Time } func (s *stateObserver) CurrentState() EngineState { @@ -227,22 +255,28 @@ func (s *stateObserver) RegisterStateChangeListener(f func(state EngineState)) { s.callbacks = append(s.callbacks, f) } -var WrongStateErr = errors.New("wrong state") +var ErrWrongState = errors.New("wrong state") -func (s *stateObserver) GetNeedsLoginState() (url string, err error) { +func (s *stateObserver) GetNeedsLoginState() (url string, devKeyCh chan<- string, err error) { s.mu.Lock() defer s.mu.Unlock() - if s.state == NeedsLogin { - return s.currentLoginUrl, nil - } else { - return "", WrongStateErr + if s.state != NeedsLogin { + return "", nil, ErrWrongState } + + return s.loginURL, s.loginDeviceKeyCh, nil } -func (s *stateObserver) GetEstablishedState() { - //TODO implement me - panic("implement me") +func (s *stateObserver) GetEstablishedState() (time.Time, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.state != Established { + return time.Time{}, ErrWrongState + } + + return s.expiry, nil } func (s *stateObserver) change(oldState, newState EngineState) bool { @@ -273,6 +307,13 @@ func (s *stateObserver) set(newState EngineState) { s.asyncFireCallbacks(newState) } +func (s *stateObserver) alter(f func(observer *stateObserver)) { + s.mu.Lock() + defer s.mu.Unlock() + + f(s) +} + func (s *stateObserver) asyncFireCallbacks(state EngineState) { for _, cb := range s.callbacks 
{ go cb(state) @@ -298,18 +339,19 @@ func NewEngine( parentCtx = context.Background() } - ctx, ccc := context.WithCancelCause(parentCtx) - - if wg == nil { + switch { + case wg == nil: return nil, errors.New("cannot initialise toversok engine with nil WireGuardHost") - } else if fw == nil { + case fw == nil: return nil, errors.New("cannot initialise toversok engine with nil FirewallHost") - } else if co == nil { + case co == nil: return nil, errors.New("cannot initialise toversok engine with nil ControlHost") - } else if privateKey.IsZero() { + case privateKey.IsZero(): return nil, errors.New("cannot initialise toversok engine with zero privateKey") } + ctx, ccc := context.WithCancelCause(parentCtx) + e := &Engine{ ctx: ctx, ccc: ccc, @@ -328,12 +370,32 @@ func NewEngine( e.Observer().RegisterStateChangeListener(func(state EngineState) { if state == NeedsLogin { - url, err := e.Observer().GetNeedsLoginState() + url, devKeyCh, err := e.Observer().GetNeedsLoginState() if err == nil { e.slog().Info("control wants logon", "url", url) } else { e.slog().Error("could not get login state when prompted for it", "err", err) } + + if e.deviceKey != nil { + devKeyCh <- *e.deviceKey + } + } else if state == Established { + expiry, err := e.Observer().GetEstablishedState() + if err != nil { + // We are literally in the established state, we can get the GetEstablishedState + // There is one tiny window where it has flopped back, and so just ignore that if that is the case + return + } + if expiry != (time.Time{}) { + slog.Info("established session with expiry", "expiry", expiry, "in", time.Until(expiry)) + } + } + }) + + context.AfterFunc(e.ctx, func() { + if err := e.maybeClean(); err != nil { + slog.Error("after-ctx: engine state cleaning failed", "err", err) } }) @@ -347,8 +409,8 @@ func (e *Engine) getNodePriv() *key.NodePrivate { func (e *Engine) getExtConn() types.UDPConn { if e.extBind == nil || e.extBind.Closed { conn, err := e.bindExt() - if err != nil { + // We expect 
the bindext to work, else we more or less just can't do anything panic(fmt.Sprintf("could not bind ext: %s", err)) } @@ -373,118 +435,10 @@ func (e *Engine) Observer() Observer { return &e.state } +// SupplyDeviceKey gives the device key that'll be used when logging on. +// This must be called BEFORE Start. func (e *Engine) SupplyDeviceKey(key string) error { - // TODO - panic("not implemented") -} - -// -//const WGKeepAlive = time.Second * 20 -// -//func (e *Engine) Handle(ev Event) error { -// switch ev := ev.(type) { -// case PeerAddition: -// return e.AddPeer(ev.Key, ev.HomeRelayId, ev.Endpoints, ev.SessionKey, ev.VIPs.IPv4, ev.VIPs.IPv6) -// case PeerUpdate: -// // FIXME the reason for the panic below is because this function is essentially deprecated, and it still uses -// // gonull, which is a pain -// panic("cannot handle PeerUpdate via handle") -// -// //if ev.Endpoints.Present { -// // if err := e.stage.SetEndpoints(ev.Key, ev.Endpoints.Val); err != nil { -// // return fmt.Errorf("failed to update endpoints: %w", err) -// // } -// //} -// // -// //if ev.SessionKey.Present { -// // if err := e.stage.UpdateSessionKey(ev.Key, ev.SessionKey.Val); err != nil { -// // return fmt.Errorf("failed to update session key: %w", err) -// // } -// //} -// // -// //if ev.HomeRelayId.Present { -// // if err := e.stage.UpdateHomeRelay(ev.Key, ev.HomeRelayId.Val); err != nil { -// // return fmt.Errorf("failed to update home relay: %w", err) -// // } -// //} -// case PeerRemoval: -// return e.RemovePeer(ev.Key) -// case RelayUpdate: -// return e.UpdateRelays(ev.Set) -// default: -// // TODO warn-log about unknown type instead of panic -// panic("Unknown type!") -// } -// -// return nil -//} -// -//func (e *Engine) AddPeer(peer key.NodePublic, homeRelay int64, endpoints []netip.AddrPort, session key.SessionPublic, ip4 netip.Addr, ip6 netip.Addr) error { -// m := e.bindLocal() -// e.localMapping[peer] = m -// -// if err := e.wg.UpdatePeer(peer, PeerCfg{ -// Set: true, -// 
VIPs: &VirtualIPs{ -// IPv4: ip4, -// IPv6: ip6, -// }, -// KeepAliveInterval: nil, -// LocalEndpointPort: &m.port, -// }); err != nil { -// return fmt.Errorf("failed to update wireguard: %w", err) -// } -// -// if err := e.stage.AddPeer(peer, homeRelay, endpoints, session, ip4, ip6); err != nil { -// return fmt.Errorf("failed to update stage: %w", err) -// } -// return nil -//} -// -//func (e *Engine) UpdatePeer(peer key.NodePublic, homeRelay *int64, endpoints []netip.AddrPort, session *key.SessionPublic) error { -// return e.stage.UpdatePeer(peer, homeRelay, endpoints, session) -//} -// -//func (e *Engine) RemovePeer(peer key.NodePublic) error { -// if err := e.stage.RemovePeer(peer); err != nil { -// return err -// } -// -// if err := e.wg.RemovePeer(peer); err != nil { -// return fmt.Errorf("failed to remove peer from wireguard: %w", err) -// } -// -// return nil -//} -// -//func (e *Engine) UpdateRelays(relay []relay.Information) error { -// return e.stage.UpdateRelays(relay) -//} - -type FakeControl struct { - controlKey key.ControlPublic - ipv4 netip.Prefix - ipv6 netip.Prefix -} - -func (f *FakeControl) ControlKey() key.ControlPublic { - return f.controlKey -} - -func (f *FakeControl) IPv4() netip.Prefix { - return f.ipv4 -} - -func (f *FakeControl) IPv6() netip.Prefix { - return f.ipv6 -} - -func (f *FakeControl) InstallCallbacks(callbacks ifaces.ControlCallbacks) error { - // NOP - return nil -} + e.deviceKey = &key -func (f *FakeControl) UpdateEndpoints(ports []netip.AddrPort) error { - // NOP return nil } diff --git a/toversok/events.go b/toversok/events.go deleted file mode 100644 index f4f72ed..0000000 --- a/toversok/events.go +++ /dev/null @@ -1,59 +0,0 @@ -package toversok - -import ( - "github.com/LukaGiorgadze/gonull" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/relay" - "net/netip" -) - -// TODO DEPRECATED, should be refactored into using fake control client and such - -type Event interface { - EventName() string -} - 
-type RelayUpdate struct { - // Updates relays referenced in this set. - // - // Note: Deliberately does not allow for unsetting relays. - Set []relay.Information -} - -func (r RelayUpdate) EventName() string { - return "RelayUpdate" -} - -type PeerAddition struct { - Key key.NodePublic - - HomeRelayID int64 - SessionKey key.SessionPublic - Endpoints []netip.AddrPort - - VIPs VirtualIPs -} - -func (p PeerAddition) EventName() string { - return "PeerAddition" -} - -type PeerUpdate struct { - Key key.NodePublic - - HomeRelayID gonull.Nullable[int64] - SessionKey gonull.Nullable[key.SessionPublic] - Endpoints gonull.Nullable[[]netip.AddrPort] -} - -func (p PeerUpdate) EventName() string { - return "PeerUpdate" -} - -type PeerRemoval struct { - Key key.NodePublic -} - -func (p PeerRemoval) EventName() string { - return "PeerRemoval" -} diff --git a/toversok/fakecontrol.go b/toversok/fakecontrol.go new file mode 100644 index 0000000..501bc23 --- /dev/null +++ b/toversok/fakecontrol.go @@ -0,0 +1,50 @@ +package toversok + +import ( + "context" + "net/netip" + "time" + + "github.com/edup2p/common/types/ifaces" + "github.com/edup2p/common/types/key" +) + +type FakeControl struct { + controlKey key.ControlPublic + ipv4 netip.Prefix + ipv6 netip.Prefix +} + +func (f *FakeControl) ControlKey() key.ControlPublic { + return f.controlKey +} + +func (f *FakeControl) IPv4() netip.Prefix { + return f.ipv4 +} + +func (f *FakeControl) IPv6() netip.Prefix { + return f.ipv6 +} + +func (f *FakeControl) Expiry() time.Time { + return time.Time{} +} + +func (f *FakeControl) Context() context.Context { + return context.Background() +} + +func (f *FakeControl) InstallCallbacks(ifaces.ControlCallbacks) { + // NOP +} + +func (f *FakeControl) UpdateEndpoints([]netip.AddrPort) error { + // NOP + return nil +} + +func (f *FakeControl) UpdateHomeRelay(int64) error { + // NOP + return nil +} diff --git a/toversok/interface.go b/toversok/interface.go index 3c0e260..df61aa0 100644 --- 
a/toversok/interface.go +++ b/toversok/interface.go @@ -2,11 +2,13 @@ package toversok import ( "context" + "net" + "net/netip" + "time" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/ifaces" "github.com/edup2p/common/types/key" - "net/netip" - "time" ) // PeerCfg isa a peer config update struct, all values are nullable through being pointers. @@ -72,6 +74,8 @@ type WireGuardController interface { // // Can possibly return nil, when the peer has been removed, or not yet known to the controller. ConnFor(node key.NodePublic) types.UDPConn + + GetInterface() *net.Interface } type FirewallHost interface { diff --git a/toversok/observer.go b/toversok/observer.go index 8a4c828..51529b1 100644 --- a/toversok/observer.go +++ b/toversok/observer.go @@ -1,13 +1,15 @@ package toversok +import "time" + // Observer functions as a state observer for the Engine, effectively allowing calling clients to peek into the engine state in an abstracted way. type Observer interface { RegisterStateChangeListener(func(state EngineState)) CurrentState() EngineState - GetNeedsLoginState() (url string, err error) - GetEstablishedState() // TODO + GetNeedsLoginState() (url string, deviceKeyCh chan<- string, err error) + GetEstablishedState() (expiry time.Time, err error) // TODO add ipv4,ipv6? 
} type EngineState byte diff --git a/toversok/session.go b/toversok/session.go index 62ac918..f3d7dad 100644 --- a/toversok/session.go +++ b/toversok/session.go @@ -2,16 +2,18 @@ package toversok import ( "context" + "errors" "fmt" + "log/slog" + "net/netip" + "sync" + "github.com/edup2p/common/toversok/actors" "github.com/edup2p/common/types" "github.com/edup2p/common/types/ifaces" "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgcontrol" "github.com/edup2p/common/types/relay" - "log/slog" - "net/netip" - "sync" ) // Session represents one single session; a session key is generated here, and used inside a Stage @@ -21,13 +23,12 @@ type Session struct { wg WireGuardController fw FirewallController + cs ifaces.ControlSession quarantineMu sync.Mutex quarantinedPeers map[key.NodePublic]bool peerAddrs map[key.NodePublic][]netip.Addr - //control ifaces.ControlSession - stage ifaces.Stage sessionKey key.SessionPrivate @@ -43,8 +44,7 @@ func SetupSession( logon types.LogonCallback, ) (*Session, error) { ctx, ccc := context.WithCancelCause(engineCtx) - - sCtx := context.WithValue(ctx, "ccc", ccc) + sCtx := context.WithValue(ctx, types.CCC, ccc) sess := &Session{ ctx: sCtx, @@ -57,13 +57,14 @@ func SetupSession( stage: nil, } - cc, err := co.CreateClient(sess.ctx, getNodePriv, sess.getPriv, logon) + var err error + sess.cs, err = co.CreateClient(sess.ctx, getNodePriv, sess.getPriv, logon) if err != nil { sess.ccc(err) return nil, fmt.Errorf("could not create control client: %w", err) } - if sess.wg, err = wg.Controller(*getNodePriv(), cc.IPv4(), cc.IPv6()); err != nil { + if sess.wg, err = wg.Controller(*getNodePriv(), sess.cs.IPv4(), sess.cs.IPv6()); err != nil { err = fmt.Errorf("could not init wireguard: %w", err) sess.ccc(err) return nil, err @@ -75,9 +76,24 @@ func SetupSession( return nil, err } - sess.stage = actors.MakeStage(sess.ctx, getNodePriv, sess.getPriv, getExtSock, sess.wg.ConnFor, cc) - - cc.InstallCallbacks(sess) + sess.stage = 
actors.MakeStage( + sess.ctx, + getNodePriv, + sess.getPriv, + getExtSock, + sess.wg.ConnFor, + sess.cs, + nil, + sess.wg.GetInterface(), + ) + + sess.cs.InstallCallbacks(sess) + context.AfterFunc(sess.cs.Context(), func() { + sess.ccc(errors.New("resumable control session exited")) + }) + context.AfterFunc(sess.stage.Context(), func() { + sess.ccc(errors.New("stage exited")) + }) return sess, nil } @@ -134,7 +150,7 @@ func (s *Session) triggerQuarantineUpdate() { // CONTROL CALLBACKS -func (s *Session) AddPeer(peer key.NodePublic, homeRelay int64, endpoints []netip.AddrPort, session key.SessionPublic, ip4 netip.Addr, ip6 netip.Addr, prop msgcontrol.Properties) error { +func (s *Session) AddPeer(peer key.NodePublic, homeRelay int64, endpoints []netip.AddrPort, session key.SessionPublic, ip4, ip6 netip.Addr, prop msgcontrol.Properties) error { s.registerPeerAddrs(peer, ip4, ip6) if prop.Quarantine { diff --git a/types/control/client.go b/types/control/client.go index b05711f..7e29a75 100644 --- a/types/control/client.go +++ b/types/control/client.go @@ -5,16 +5,18 @@ import ( "context" "errors" "fmt" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/msgcontrol" "log/slog" "net/netip" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgcontrol" ) type Client struct { ctx context.Context + ccc context.CancelCauseFunc cc *Conn @@ -28,14 +30,17 @@ type Client struct { IPv4 netip.Prefix IPv6 netip.Prefix - // TODO + Expiry time.Time } func EstablishClient(parentCtx context.Context, mc types.MetaConn, brw *bufio.ReadWriter, timeout time.Duration, getPriv func() *key.NodePrivate, getSess func() *key.SessionPrivate, controlKey key.ControlPublic, session *string, logon types.LogonCallback) (*Client, error) { + ctx, ccc := context.WithCancelCause(parentCtx) + c := &Client{ - ctx: parentCtx, + ctx: ctx, + ccc: ccc, - cc: NewConn(parentCtx, mc, brw), + 
cc: NewConn(ctx, mc, brw), getPriv: getPriv, getSess: getSess, @@ -46,21 +51,14 @@ func EstablishClient(parentCtx context.Context, mc types.MetaConn, brw *bufio.Re if err := c.Handshake(timeout, logon); err != nil { return nil, err - } else { - return c, nil } -} -func (c *Client) Handshake(timeout time.Duration, logon types.LogonCallback) error { + context.AfterFunc(c.ctx, c.Close) - // TODO - // 1. send ClientHello - // 2. expect ServerHello - // 3. send Logon - // 4. (optional) expect LogonAuthenticate - // - Allow sending LogonDeviceKey - // 4. expect LogonAccept|LogonReject + return c, nil +} +func (c *Client) Handshake(timeout time.Duration, logon types.LogonCallback) error { if timeout != 0 { if err := c.cc.mc.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("can't set deadline: %w", err) @@ -85,10 +83,8 @@ func (c *Client) Handshake(timeout time.Duration, logon types.LogonCallback) err if c.ControlKey.IsZero() { c.ControlKey = serverHello.ControlNodePub // TODO log TOFU? 
- } else { - if serverHello.ControlNodePub != c.ControlKey { - return fmt.Errorf("client-stated control key does not match server-given control key") - } + } else if serverHello.ControlNodePub != c.ControlKey { + return fmt.Errorf("client-stated control key does not match server-given control key") } clearData, ok := c.getPriv().OpenFromControl(c.ControlKey, serverHello.CheckData) @@ -133,81 +129,25 @@ func (c *Client) Handshake(timeout time.Duration, logon types.LogonCallback) err case *msgcontrol.LogonAccept: c.SessionID = &m.SessionID - //c.IPv4 = netip.PrefixFrom(netip.Addr(m.IPv4Addr), int(m.IPv4Mask)) - //c.IPv6 = netip.PrefixFrom(netip.Addr(m.IPv6Addr), int(m.IPv6Mask)) - c.IPv4 = m.IP4 c.IPv6 = m.IP6 + c.Expiry = m.AuthExpiry + slog.Debug("logon accepted", "as-peer", nodePubKey.Debug(), "as-sess", sessPubKey.Debug(), "with-sess-id", types.PtrOr(c.SessionID, ""), "with-ipv4", c.IPv4.String(), "with-ipv6", c.IPv6.String()) return nil default: return fmt.Errorf("received unknown message type after-logon: %d", m) } - - //switch typ { - //case msgcontrol.LogonAuthenticateType: - // // TODO - // panic("authenticate logic not implemented") - //case msgcontrol.LogonAcceptType: - // accept := new(msgcontrol.LogonAccept) - // if err := ReadMessage(c.reader, msgLen, accept); err != nil { - // return fmt.Errorf("error when reading after-logon reject message: %w", err) - // } - // - // c.SessionID = &accept.SessionID - // c.IPv4 = netip.PrefixFrom(netip.Addr(accept.IPv4Addr), int(accept.IPv4Mask)) - // c.IPv6 = netip.PrefixFrom(netip.Addr(accept.IPv6Addr), int(accept.IPv6Mask)) - // - // return nil - // - //case msgcontrol.LogonRejectType: - // reject := new(msgcontrol.LogonReject) - // if err := ReadMessage(c.reader, msgLen, reject); err != nil { - // return fmt.Errorf("error when reading after-logon reject message: %w", err) - // } - // - // return fmt.Errorf( - // "logon rejected after-logon: %s; retry strategy: %w", - // reject.Reason, - // 
types.PtrOr(reject.RetryStrategy, msgcontrol.NoRetryStrategy), - // ) - //default: - // return fmt.Errorf("received unknown message type after-logon: %d", typ) - //} - // - //typ, msgLen, err = ReadMessageHeader(c.reader) - //if err != nil { - // return fmt.Errorf("error when receiving after-authenticate message: %w", err) - //} - // - //switch typ { - //case msgcontrol.LogonAcceptType: - // // TODO - // panic("implement me") - //case msgcontrol.LogonRejectType: - // reject := new(msgcontrol.LogonReject) - // if err := ReadMessage(c.reader, msgLen, reject); err != nil { - // return fmt.Errorf("error when reading after-authenticate reject message: %w", err) - // } - // - // return fmt.Errorf( - // "logon rejected after-authenticate: %s; retry strategy: %w", - // reject.Reason, - // types.PtrOr(reject.RetryStrategy, msgcontrol.NoRetryStrategy), - // ) - //default: - // return fmt.Errorf("received unknown message type after-authenticate: %d", typ) - //} } -var NeedsLogonError = errors.New("needs logon callback") +var ErrNeedsLogon = errors.New("needs logon callback") func (c *Client) handleLogon(url string, logon types.LogonCallback) (msgcontrol.ControlMessage, error) { if logon == nil { // No way we can start or create a logon session, abort - return nil, fmt.Errorf("logonauthenticate requested when no interactive logon callback exists, aborting; %w", NeedsLogonError) + return nil, fmt.Errorf("logonauthenticate requested when no interactive logon callback exists, aborting; %w", ErrNeedsLogon) } deviceKeyChan := make(chan string) @@ -230,7 +170,6 @@ func (c *Client) handleLogon(url string, logon types.LogonCallback) (msgcontrol. } }() - // TODO also add context error / close here select { case deviceKey := <-deviceKeyChan: if err := c.cc.Write(&msgcontrol.LogonDeviceKey{ @@ -257,11 +196,11 @@ func (c *Client) handleLogon(url string, logon types.LogonCallback) (msgcontrol. 
} } -var ClosedErr = errors.New("client closed") +var ErrClosed = errors.New("client closed") func (c *Client) Send(msg msgcontrol.ControlMessage) error { if types.IsContextDone(c.ctx) { - return ClosedErr + return ErrClosed } return c.cc.Write(msg) @@ -270,7 +209,7 @@ func (c *Client) Send(msg msgcontrol.ControlMessage) error { // Recv blocks until it receives a package, it will return (nil, nil) if timeout func (c *Client) Recv(ttfbTimeout time.Duration) (msgcontrol.ControlMessage, error) { if types.IsContextDone(c.ctx) { - return nil, ClosedErr + return nil, ErrClosed } return c.cc.Read(ttfbTimeout) @@ -281,5 +220,11 @@ func (c *Client) Recv(ttfbTimeout time.Duration) (msgcontrol.ControlMessage, err // } func (c *Client) Close() { - c.cc.mc.Close() + if err := c.cc.mc.Close(); err != nil { + slog.Error("error when closing control client", "err", err) + } +} + +func (c *Client) Cancel(err error) { + c.ccc(err) } diff --git a/types/control/conn.go b/types/control/conn.go index 0f3b036..b77b7e6 100644 --- a/types/control/conn.go +++ b/types/control/conn.go @@ -6,13 +6,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/msgcontrol" "io" "log/slog" "os" "sync" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/msgcontrol" ) type Conn struct { @@ -37,7 +38,6 @@ func NewConn(ctx context.Context, mc types.MetaConn, brw *bufio.ReadWriter) *Con } func (c *Conn) UnmarshalInto(data []byte, to msgcontrol.ControlMessage) error { - if err := json.Unmarshal(data, to); err != nil { return fmt.Errorf("failed to unmarshal data: %w", err) } @@ -50,7 +50,6 @@ func (c *Conn) Expect(to msgcontrol.ControlMessage, ttfbTimeout time.Duration) e defer c.readMutex.Unlock() msgTyp, data, err := c.readRawMessageLocked(ttfbTimeout) - if err != nil { return fmt.Errorf("failed reading message: %w", err) } @@ -95,8 +94,6 @@ func (c *Conn) Read(ttfbTimeout time.Duration) (msgcontrol.ControlMessage, error to = 
new(msgcontrol.LogonAccept) case msgcontrol.LogonRejectType: to = new(msgcontrol.LogonReject) - case msgcontrol.LogoutType: - to = new(msgcontrol.Logout) case msgcontrol.PingType: to = new(msgcontrol.Ping) case msgcontrol.PongType: @@ -114,8 +111,13 @@ func (c *Conn) Read(ttfbTimeout time.Duration) (msgcontrol.ControlMessage, error to = new(msgcontrol.PeerRemove) case msgcontrol.RelayUpdateType: to = new(msgcontrol.RelayUpdate) + case msgcontrol.LogoutType: + to = new(msgcontrol.Logout) + case msgcontrol.DisconnectType: + to = new(msgcontrol.Disconnect) + default: - panic(fmt.Sprintf("Unknown type %v", typ)) + return nil, fmt.Errorf("unknown type %v", typ) } if err = c.UnmarshalInto(data, to); err != nil { @@ -181,10 +183,6 @@ func (c *Conn) Write(obj msgcontrol.ControlMessage) error { c.writeMutex.Lock() defer c.writeMutex.Unlock() - //// FIXME: bson is extremely fucky and will write empty values if it cannot decode something, so be careful with that - //// or replace this with a registry thingie. 
- //data, err := bson.Marshal(obj) - data, err := json.Marshal(obj) if err != nil { return fmt.Errorf("could not marshal data: %w", err) diff --git a/types/control/controlhttp/http_client.go b/types/control/controlhttp/http_client.go index 7f6a3c2..b1e36c9 100644 --- a/types/control/controlhttp/http_client.go +++ b/types/control/controlhttp/http_client.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "github.com/edup2p/common/types" "github.com/edup2p/common/types/control" "github.com/edup2p/common/types/dial" @@ -26,6 +27,6 @@ func Dial(ctx context.Context, opts dial.Opts, getPriv func() *key.NodePrivate, opts.SetDefaults() return dial.HTTP(ctx, opts, makeControlURL(opts), control.UpgradeProtocol, func(parentCtx context.Context, mc types.MetaConn, brw *bufio.ReadWriter, opts dial.Opts) (*control.Client, error) { - return control.EstablishClient(ctx, mc, brw, opts.EstablishTimeout, getPriv, getSess, controlKey, session, logon) + return control.EstablishClient(parentCtx, mc, brw, opts.EstablishTimeout, getPriv, getSess, controlKey, session, logon) }) } diff --git a/types/control/controlhttp/http_server.go b/types/control/controlhttp/http_server.go index 833f12c..9930bc3 100644 --- a/types/control/controlhttp/http_server.go +++ b/types/control/controlhttp/http_server.go @@ -1,9 +1,10 @@ package controlhttp import ( + "net/http" + "github.com/edup2p/common/types/control" "github.com/edup2p/common/types/dial" - "net/http" ) func ServerHandler(s *control.Server) http.Handler { diff --git a/types/control/graph.go b/types/control/graph.go index 404e266..821378a 100644 --- a/types/control/graph.go +++ b/types/control/graph.go @@ -2,9 +2,10 @@ package control import ( "errors" + "sync" + "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgcontrol" - "sync" ) type EdgeGraph struct { @@ -133,23 +134,25 @@ func (g *EdgeGraph) GetEdges(node ClientID) map[ClientID]VisibilityPair { return targetMap } -func (g *EdgeGraph) GetEdge(from, to ClientID) 
*VisibilityPair { +func (g *EdgeGraph) GetEdge(from, to ClientID) (retPair *VisibilityPair) { g.mu.RLock() defer g.mu.RUnlock() targetMap := g.graph[from] if targetMap == nil { - return nil + return } pair := targetMap[to] if pair != nil { - pair = &(*pair) + // Copy into fresh storage; retPair is a nil named result until it is assigned. + retPair = new(VisibilityPair) + *retPair = *pair } - return pair + return } type VisibilityPair struct { diff --git a/types/control/iface.go b/types/control/iface.go index edbfc9c..85c52b4 100644 --- a/types/control/iface.go +++ b/types/control/iface.go @@ -2,16 +2,23 @@ package control import ( "errors" - "github.com/edup2p/common/types/key" "net/netip" -) + "time" -type ClientID key.NodePublic -type SessID string + "github.com/edup2p/common/types/key" +) -var SessionDoesNotExistError = errors.New("session does not exist") +type ( + ClientID key.NodePublic + SessID string +) -var SessionIsNotAuthenticating = errors.New("session is not authenticating") +var ( + ErrSessionDoesNotExist = errors.New("session does not exist") + ErrSessionIsNotAuthenticating = errors.New("session is not authenticating") + ErrNeedsDisconnect = errors.New("session needs disconnect") + ErrClientNotConnected = errors.New("client is not connected") +) // ServerLogic denotes exposed functions that a control server must provide for any business logic to interface with it. type ServerLogic interface { @@ -33,6 +40,7 @@ type ServerLogic interface { SendAuthURL(id SessID, url string) error // AcceptAuthentication will accept the pending authentication of the indicated session ID. // Must be called, or RejectAuthentication must be called. + // The expiry of the accepted authentication is supplied by ServerCallbacks.OnSessionFinalize. // Will error if the session is not pending authentication. AcceptAuthentication(SessID) error // RejectAuthentication will reject the pending authentication of the indicated session ID. 
@@ -40,6 +48,13 @@ type ServerLogic interface { // Will error if the session is not pending authentication. 
RejectAuthentication(id SessID, reason string) error + // DisconnectSession will disconnect a running client session (and invalidate its ID), if it exists. + // Will error if session does not exist. + DisconnectSession(id SessID) error + // DisconnectClient will disconnect a running session per client (and invalidate its ID), if it is connected. + // Will error if client is not connected. + DisconnectClient(id ClientID) error + /// The following functions pertain to client-client networking visibility. // GetVisibilityPairs gets all pairs of a particular ClientID. @@ -72,8 +87,9 @@ type ServerCallbacks interface { OnDeviceKey(id SessID, key string) // OnSessionFinalize is called right after ServerLogic.AcceptAuthentication, but before that message is sent to the client. - // The client needs to known which virtual IPs it can use, and this function will provide it to the control server. - OnSessionFinalize(SessID, ClientID) (netip.Prefix, netip.Prefix) + // The client needs to know which virtual IPs it can use, and the expiry time of the authentication, + // and this function will provide it to the control server. + OnSessionFinalize(SessID, ClientID) (netip.Prefix, netip.Prefix, time.Time) // OnSessionDestroy is called after the client has been disconnected. 
OnSessionDestroy(SessID, ClientID) diff --git a/types/control/logic.go b/types/control/logic.go index b588ec0..5de4219 100644 --- a/types/control/logic.go +++ b/types/control/logic.go @@ -2,6 +2,7 @@ package control import ( "errors" + "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgcontrol" ) @@ -21,11 +22,11 @@ func (s *Server) whenSessAuthenticating(id SessID, f func(*ServerSession) error) sess, ok := s.sessByID[sid] if !ok { - return SessionDoesNotExistError + return ErrSessionDoesNotExist } if sess.state != Authenticate { - return SessionIsNotAuthenticating + return ErrSessionIsNotAuthenticating } return f(sess) @@ -33,7 +34,7 @@ func (s *Server) whenSessAuthenticating(id SessID, f func(*ServerSession) error) func (s *Server) SendAuthURL(id SessID, url string) error { return s.whenSessAuthenticating(id, func(sess *ServerSession) error { - sess.authChan <- AuthUrl{url: url} + sess.authChan <- AuthURL{url: url} return nil }) @@ -68,7 +69,7 @@ func (s *Server) GetClientID(id SessID) (ClientID, error) { sess, ok := s.sessByID[sid] if !ok { - return nilClientID, SessionDoesNotExistError + return nilClientID, ErrSessionDoesNotExist } return ClientID(sess.Peer), nil @@ -85,11 +86,9 @@ func (s *Server) GetConnectedClients() (map[SessID]ClientID, error) { } return retMap, nil - - // todo what do we use the error field for here? 
} -func (s *Server) UpsertVisibilityPair(id ClientID, id2 ClientID, pair VisibilityPair) error { +func (s *Server) UpsertVisibilityPair(id, id2 ClientID, pair VisibilityPair) error { s.sessLock.RLock() defer s.sessLock.RUnlock() @@ -170,7 +169,7 @@ func (s *Server) UpsertMultiVisibilityPair(id ClientID, m map[ClientID]Visibilit return nil } -func (s *Server) RemoveVisibilityPair(from ClientID, to ClientID) error { +func (s *Server) RemoveVisibilityPair(from, to ClientID) error { s.sessLock.RLock() defer s.sessLock.RUnlock() @@ -218,6 +217,37 @@ func (s *Server) GetVisibilityPairs(id ClientID) (map[ClientID]VisibilityPair, e return pairs, nil } +func (s *Server) DisconnectSession(id SessID) error { + s.sessLock.RLock() + defer s.sessLock.RUnlock() + + sess, ok := s.sessByID[string(id)] + + if !ok { + return ErrSessionDoesNotExist + } + + sess.Ccc(ErrNeedsDisconnect) + + return nil +} + +func (s *Server) DisconnectClient(id ClientID) error { + s.sessLock.RLock() + defer s.sessLock.RUnlock() + + sess, ok := s.sessByNode[key.NodePublic(id)] + + if !ok { + return ErrClientNotConnected + } + + sess.Ccc(ErrNeedsDisconnect) + + return nil +} + +//nolint:unused func (s *Server) atomicDoVisibilityPairs(id key.NodePublic, f func(map[ClientID]VisibilityPair) error) error { s.sessLock.RLock() defer s.sessLock.RUnlock() diff --git a/types/control/server.go b/types/control/server.go index 17abd1b..0440b6f 100644 --- a/types/control/server.go +++ b/types/control/server.go @@ -6,17 +6,18 @@ import ( "crypto/rand" "errors" "fmt" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/msgcontrol" - "github.com/edup2p/common/types/relay" - stunserver "github.com/edup2p/common/types/stun" "log/slog" "net" "net/netip" "slices" "sync" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgcontrol" + "github.com/edup2p/common/types/relay" + stunserver 
"github.com/edup2p/common/types/stun" ) type Server struct { @@ -139,21 +140,19 @@ func (s *Server) Logger() *slog.Logger { return slog.With("control", s.privKey.Public().Debug()) } -func (s *Server) Accept(ctx context.Context, mc types.MetaConn, brw *bufio.ReadWriter, remoteAddrPort netip.AddrPort) error { +func (s *Server) Accept(ctx context.Context, mc types.MetaConn, brw *bufio.ReadWriter, _ netip.AddrPort) error { cc := NewConn(ctx, mc, brw) // TODO this logon segment can be in a different function { // TODO set deadline on read - err, clientHello, logon := s.handleLogon(cc) - + clientHello, logon, err := s.handleLogon(cc) if err != nil { return fmt.Errorf("handle logon: %w", err) } sess, resumed, err := s.ReEstablishOrMakeSession(cc, clientHello.ClientNodePub, logon.SessKey, logon.ResumeSessionID) - if err != nil { return s.doReject(cc, sess, err) } @@ -168,7 +167,6 @@ func (s *Server) Accept(ctx context.Context, mc types.MetaConn, brw *bufio.ReadW if err = sess.AuthAndStart(); err != nil { return err } - //go sess.Run() } // Wait until connection dead @@ -178,14 +176,6 @@ func (s *Server) Accept(ctx context.Context, mc types.MetaConn, brw *bufio.ReadW return sess.Ctx.Err() - //// for now, send a reject - //if err := cc.Write(&msgcontrol.LogonReject{ - // Reason: "dev: reject unambiguously", - // RetryStrategy: 0, - //}); err != nil { - // return fmt.Errorf("error when sending reject: %w", err) - //} - // TODO send authenticate (then wait, or expect devicekey), accept, or reject // TODO resume @@ -204,17 +194,14 @@ func (s *Server) Accept(ctx context.Context, mc types.MetaConn, brw *bufio.ReadW // TODO (here) mark session as latent } - - //TODO implement me - panic("implement me") } -func (s *Server) handleLogon(cc *Conn) (error, *msgcontrol.ClientHello, *msgcontrol.Logon) { +func (s *Server) handleLogon(cc *Conn) (*msgcontrol.ClientHello, *msgcontrol.Logon, error) { // TODO set deadline on read - var clientHello = new(msgcontrol.ClientHello) + clientHello := 
new(msgcontrol.ClientHello) if err := cc.Expect(clientHello, HandshakeReceiveTimeout); err != nil { - return fmt.Errorf("error when receiving clienthello: %w", err), nil, nil + return nil, nil, fmt.Errorf("error when receiving clienthello: %w", err) } data := randData() @@ -223,12 +210,12 @@ func (s *Server) handleLogon(cc *Conn) (error, *msgcontrol.ClientHello, *msgcont ControlNodePub: s.privKey.Public(), CheckData: s.privKey.SealToNode(clientHello.ClientNodePub, data), }); err != nil { - return fmt.Errorf("error when sending serverhello: %w", err), nil, nil + return nil, nil, fmt.Errorf("error when sending serverhello: %w", err) } logon := new(msgcontrol.Logon) if err := cc.Expect(logon, HandshakeReceiveTimeout); err != nil { - return fmt.Errorf("error when receiving logon: %w", err), nil, nil + return nil, nil, fmt.Errorf("error when receiving logon: %w", err) } // Verify logon @@ -237,46 +224,40 @@ func (s *Server) handleLogon(cc *Conn) (error, *msgcontrol.ClientHello, *msgcont var ok bool if nodeData, ok = s.privKey.OpenFromNode(clientHello.ClientNodePub, logon.NodeKeyAttestation); !ok { - return fmt.Errorf("could not open node attestation"), nil, nil + return nil, nil, fmt.Errorf("could not open node attestation") } if sessData, ok = s.privKey.OpenFromSession(logon.SessKey, logon.SessKeyAttestation); !ok { - return fmt.Errorf("could not open session attestation"), nil, nil + return nil, nil, fmt.Errorf("could not open session attestation") } // FIXME: we should probably make the below something like constant time, to prevent timing attacks. // It is not now, for development purposes. 
if !slices.Equal(data, nodeData) { - return fmt.Errorf("node data not equal"), nil, nil + return nil, nil, fmt.Errorf("node data not equal") } if !slices.Equal(data, sessData) { - return fmt.Errorf("sess data not equal"), nil, nil + return nil, nil, fmt.Errorf("sess data not equal") } } - return nil, clientHello, logon + return clientHello, logon, nil } func (s *Server) doReject(cc *Conn, sess *ServerSession, err error) error { - reject := &msgcontrol.LogonReject{} - if errors.Is(err, stillEstablished) { - if errors.Is(err, stillEstablished) { - // TODO we need to replace this with knocking-and-acquiring - - reject.RetryStrategy = msgcontrol.RegenerateSessionKey - reject.RetryAfter = time.Second * 15 - reject.Reason = "other client session still active, please retry" - } else { - reject.Reason = "cannot log in at the moment, please retry in the future" - } - } else if errors.Is(err, sessionIdMismatch) { + switch { + case errors.Is(err, errStillEstablished): + reject.RetryStrategy = msgcontrol.RegenerateSessionKey + reject.RetryAfter = time.Second * 15 + reject.Reason = "other client session still active, please retry" + case errors.Is(err, errSessionIDMismatch): reject.RetryStrategy = msgcontrol.RecreateSession reject.Reason = "session ID mismatch, please try without" - } else { + default: reject.Reason = "could not acquire session" slog.Warn("rejected session with unknown error", "err", err) } @@ -292,15 +273,16 @@ func (s *Server) doReject(cc *Conn, sess *ServerSession, err error) error { if err := cc.Write(reject); err != nil { return fmt.Errorf("error when sending reject: %w", err) - } else { - return nil } + + return nil } func randData() []byte { b := make([]byte, 32) _, err := rand.Read(b) if err != nil { + // We expect the system random to at least be accessible now panic(fmt.Errorf("could not read rand: %w", err)) } return b @@ -318,7 +300,7 @@ func NewServer(privKey key.ControlPrivate, relays []relay.Information) *Server { sessLock: sync.RWMutex{}, 
sessByNode: make(map[key.NodePublic]*ServerSession), sessByID: make(map[string]*ServerSession), - //getIPs: getIPs, + // getIPs: getIPs, relays: relays, vGraph: NewEdgeGraph(), pendingLock: sync.Mutex{}, @@ -353,8 +335,16 @@ func (s *Server) RunAdditionalSTUN(publicIPs []netip.Addr, listenHost string, lo s.stun.lowServer = stunserver.NewServer(s.ctx) s.stun.highServer = stunserver.NewServer(s.ctx) - go s.stun.lowServer.ListenAndServe(lowAp) - go s.stun.highServer.ListenAndServe(highAp) + go func() { + if err := s.stun.lowServer.ListenAndServe(lowAp); err != nil { + slog.Error("low stun server ListenAndServe error", "err", err) + } + }() + go func() { + if err := s.stun.highServer.ListenAndServe(highAp); err != nil { + slog.Error("high stun server ListenAndServe error", "err", err) + } + }() t := true @@ -398,23 +388,23 @@ func (s *Server) relayExists(id int64) bool { } var ( - incorrectState = errors.New("incorrect state, want nil or Dangling") - stillEstablished = errors.New("session is still established or reestablished") - sessionIdMismatch = errors.New("session ID did not match") + errIncorrectState = errors.New("incorrect state, want nil or Dangling") + errStillEstablished = errors.New("session is still established or reestablished") + errSessionIDMismatch = errors.New("session ID did not match") ) -func (s *Server) ReEstablishOrMakeSession(cc *Conn, nodeKey key.NodePublic, sessKey key.SessionPublic, sessId *string) (retSess *ServerSession, resumed bool, err error) { +func (s *Server) ReEstablishOrMakeSession(cc *Conn, nodeKey key.NodePublic, sessKey key.SessionPublic, sessID *string) (retSess *ServerSession, resumed bool, err error) { s.sessLock.Lock() defer s.sessLock.Unlock() sess, ok := s.sessByNode[nodeKey] if !ok { - if sessId != nil { + if sessID != nil { // There's no session ID to match if its empty. 
// The client requested resume, so we need to tell it to try again without the session ID, // kicking internal logic to regenerate session keys and clearing state. - err = sessionIdMismatch + err = errSessionIDMismatch return } @@ -433,13 +423,13 @@ func (s *Server) ReEstablishOrMakeSession(cc *Conn, nodeKey key.NodePublic, sess // less simple path: we have a session in state for this nodekey if sess.state != Dangling { // We only accept resuming dangling sessions, everything else is incorrect. - err = incorrectState + err = errIncorrectState if sess.state == Established || sess.state == ReEstablishing { // The server may lag behind for a second, so if we wrap this error and return the session, // the caller could knock that session to force it to Dangling. - err = fmt.Errorf("established state (%w): %w", err, stillEstablished) + err = fmt.Errorf("established state (%w): %w", err, errStillEstablished) retSess = sess } @@ -447,10 +437,10 @@ func (s *Server) ReEstablishOrMakeSession(cc *Conn, nodeKey key.NodePublic, sess } // Session is dangling, we can grab it - if sessId != nil && sess.ID != *sessId { + if sessID != nil && sess.ID != *sessID { // Cant resume, the client expects a different session ID - err = sessionIdMismatch + err = errSessionIDMismatch return } @@ -502,14 +492,9 @@ func (s *Server) RemoveSession(sess *ServerSession) { return nil }) - if err != nil { slog.Error("failed to remove sessions", "err", err) } - - //s.ForVisibleLocked(sess, func(session *ServerSession) { - // session.Bye(sess.Peer) - //}) } slog.Info("REMOVE session", "peer", sess.Peer.Debug()) @@ -518,7 +503,7 @@ func (s *Server) RemoveSession(sess *ServerSession) { delete(s.sessByID, sess.ID) } -//func (s *Server) RegisterSession(sess *ServerSession) { +// func (s *Server) RegisterSession(sess *ServerSession) { // // TODO resume support // // s.sessLock.Lock() @@ -531,7 +516,7 @@ func (s *Server) RemoveSession(sess *ServerSession) { // // s.sessByNode[sess.Peer] = sess // 
s.sessByID[sess.ID] = sess -//} +// } // ForVisible is called by fromSess' Run goroutine, to inform all other sessions it can see of a change (and the likes) func (s *Server) ForVisible(fromSess *ServerSession, f func(session *ServerSession)) { diff --git a/types/control/server_session.go b/types/control/server_session.go index 9924174..9d51f66 100644 --- a/types/control/server_session.go +++ b/types/control/server_session.go @@ -4,14 +4,15 @@ import ( "context" "errors" "fmt" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/msgcontrol" "log/slog" "net/netip" "os" "sync" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgcontrol" ) type ServerSession struct { @@ -44,6 +45,8 @@ type ServerSession struct { server *Server + Expiry time.Time + // TODO // all synced state, known changes, queued changes, etc. } @@ -92,7 +95,6 @@ func (s *ServerSession) doAuthenticate(resumed bool) error { for ctx.Err() == nil { msg := msgcontrol.LogonDeviceKey{} err := s.conn.Expect(&msg, time.Millisecond*100) - if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { continue @@ -102,18 +104,16 @@ func (s *ServerSession) doAuthenticate(resumed bool) error { errChan <- err - return - } else { - msgChan <- msg - return } + + msgChan <- msg } }() wg.Add(1) deviceKeySeen := false - authUrlSent := false + authURLSent := false // TODO build timeout in here somewhere @@ -134,20 +134,19 @@ func (s *ServerSession) doAuthenticate(resumed bool) error { switch msg := authMsg.(type) { case RejectAuth: err := s.conn.Write(msg.LogonReject) - if err != nil { - return fmt.Errorf("error while writing logon reject: %w, %w", err, LogonRejectedError) + return fmt.Errorf("error while writing logon reject: %w, %w", err, ErrLogonRejected) } - return LogonRejectedError + return ErrLogonRejected case AcceptAuth: return nil - case AuthUrl: - if authUrlSent { + case AuthURL: + if 
authURLSent { // auth url already sent, this is a business logic error, we should error out return fmt.Errorf("business logic sent auth url twice") } - authUrlSent = true + authURLSent = true err := s.conn.Write(&msgcontrol.LogonAuthenticate{ AuthenticateURL: msg.url, @@ -179,11 +178,11 @@ type RejectAuth struct { type AcceptAuth struct{} -type AuthUrl struct { +type AuthURL struct { url string } -var LogonRejectedError = errors.New("authentication resulted in logon rejected") +var ErrLogonRejected = errors.New("authentication resulted in logon rejected") // Knock asks the session goroutine/connection to "knock" (send ping, await pong) the session, // to make sure it is still alive. @@ -198,7 +197,7 @@ func (s *ServerSession) Knock() (dangling bool) { func (s *ServerSession) Greet(otherSess *ServerSession, prop msgcontrol.Properties) { s.Slog().Debug("Greet", "from", otherSess.Peer.Debug()) - s.conn.Write(&msgcontrol.PeerAddition{ + if err := s.conn.Write(&msgcontrol.PeerAddition{ PubKey: otherSess.Peer, SessKey: otherSess.Sess, IPv4: otherSess.IPv4.Addr(), @@ -206,7 +205,9 @@ func (s *ServerSession) Greet(otherSess *ServerSession, prop msgcontrol.Properti Endpoints: otherSess.CurrentEndpoints, HomeRelay: otherSess.HomeRelay, Properties: prop, - }) + }); err != nil { + slog.Error("error writing peer addition", "err", err) + } s.greetedMu.Lock() defer s.greetedMu.Unlock() @@ -226,10 +227,12 @@ func (s *ServerSession) UpdateEndpoints(peer key.NodePublic, endpoints []netip.A s.Slog().Debug("UpdateEndpoints", "from", peer.Debug(), "endpoints", endpoints) - s.conn.Write(&msgcontrol.PeerUpdate{ + if err := s.conn.Write(&msgcontrol.PeerUpdate{ PubKey: peer, Endpoints: endpoints, - }) + }); err != nil { + slog.Error("error writing endpoints peer update", "err", err) + } } func (s *ServerSession) UpdateSessKey(peer key.NodePublic, sessKey key.SessionPublic) { @@ -237,10 +240,12 @@ func (s *ServerSession) UpdateSessKey(peer key.NodePublic, sessKey key.SessionPu 
s.Slog().Debug("UpdateSessKey", "from", peer.Debug(), "sess-key", sessKey) - s.conn.Write(&msgcontrol.PeerUpdate{ + if err := s.conn.Write(&msgcontrol.PeerUpdate{ PubKey: peer, SessKey: &sessKey, - }) + }); err != nil { + slog.Error("error writing sess key peer update", "err", err) + } } func (s *ServerSession) UpdateHomeRelay(peer key.NodePublic, homeRelay int64) { @@ -248,28 +253,34 @@ func (s *ServerSession) UpdateHomeRelay(peer key.NodePublic, homeRelay int64) { s.Slog().Debug("UpdateHomeRelay", "from", peer.Debug(), "home-relay", homeRelay) - s.conn.Write(&msgcontrol.PeerUpdate{ + if err := s.conn.Write(&msgcontrol.PeerUpdate{ PubKey: peer, HomeRelay: &homeRelay, - }) + }); err != nil { + slog.Error("error writing home relay peer update", "err", err) + } } func (s *ServerSession) UpdateProperties(peer key.NodePublic, prop msgcontrol.Properties) { s.Slog().Debug("UpdateProperties", "from", peer.Debug(), "prop", prop) - s.conn.Write(&msgcontrol.PeerUpdate{ + if err := s.conn.Write(&msgcontrol.PeerUpdate{ PubKey: peer, Properties: &prop, - }) + }); err != nil { + slog.Error("error writing properties peer update", "err", err) + } } // Bye to another session, send PeerRemove func (s *ServerSession) Bye(peer key.NodePublic) { s.Slog().Debug("Bye", "from", peer.Debug()) - s.conn.Write(&msgcontrol.PeerRemove{ + if err := s.conn.Write(&msgcontrol.PeerRemove{ PubKey: peer, - }) + }); err != nil { + slog.Error("error writing peer remove message", "err", err) + } } // SendRelays sends all relay information to the client. This is not ran on Resume. @@ -279,7 +290,7 @@ func (s *ServerSession) SendRelays() error { return s.conn.Write(&msgcontrol.RelayUpdate{Relays: s.server.relays}) } -func (s *ServerSession) Resume(cc *Conn, sessKey key.SessionPublic) { +func (s *ServerSession) Resume(_ *Conn, _ key.SessionPublic) { // TODO: check sessKey == s.key, else send sesskeyupdate // TODO we send nothing to the client except queued messages, which are backed up. 
@@ -294,9 +305,10 @@ func (s *ServerSession) AuthenticateAccept() (err error) { s.Slog().Debug("AuthenticateAccept") if err = s.conn.Write(&msgcontrol.LogonAccept{ - IP4: s.IPv4, - IP6: s.IPv6, - SessionID: s.ID, + IP4: s.IPv4, + IP6: s.IPv6, + AuthExpiry: s.Expiry, + SessionID: s.ID, }); err != nil { err = fmt.Errorf("error when sending accept: %w", err) return @@ -306,10 +318,9 @@ func (s *ServerSession) AuthenticateAccept() (err error) { } func (s *ServerSession) AuthAndStart() error { - s.IPv4, s.IPv6 = s.server.callbacks.OnSessionFinalize(SessID(s.ID), ClientID(s.Peer)) + s.IPv4, s.IPv6, s.Expiry = s.server.callbacks.OnSessionFinalize(SessID(s.ID), ClientID(s.Peer)) err := s.AuthenticateAccept() - if err != nil { return fmt.Errorf("error while writing logon accept: %w", err) } @@ -327,12 +338,22 @@ func (s *ServerSession) Run() { go func() { <-s.Ctx.Done() + if s.conn != nil && errors.Is(context.Cause(s.Ctx), ErrNeedsDisconnect) { + if err := s.conn.Write(&msgcontrol.Disconnect{ + Reason: "control requested disconnect", + }); err != nil { + slog.Error("error writing disconnect message", "err", err) + } + } + s.Slog().Info("session exiting", "err", context.Cause(s.Ctx), "peer", s.Peer.Debug()) s.server.RemoveSession(s) if s.conn != nil { - s.conn.mc.Close() + if err := s.conn.mc.Close(); err != nil { + slog.Error("failed to close metaconn", "err", err) + } } }() @@ -374,21 +395,22 @@ func (s *ServerSession) Run() { return nil }) - if err != nil { err = fmt.Errorf("could not send greets: %w", err) return } - //s.server.ForVisible(s, func(session *ServerSession) { - // // TODO this currently blocks and holds the lock, we should make Greet async as well - // - // // TODO there is no bubbling of errors, ignore? log? 
- // - // session.Greet(s) - // - // s.Greet(session) - //}) + if !s.Expiry.IsZero() { + go func() { + select { + case <-s.Ctx.Done(): + // FIXME on suspend/delay/wallclock change, this won't work properly, + // find a time-until api that deals with wall-clock differences + case <-time.After(time.Until(s.Expiry)): + s.Ccc(ErrNeedsDisconnect) + } + }() + } s.Slog().Info("established session") @@ -396,7 +418,6 @@ func (s *ServerSession) Run() { var m msgcontrol.ControlMessage m, err = s.conn.Read(0) - if err != nil { // TODO this currently removes the session on connection break; no resuming @@ -427,35 +448,15 @@ func (s *ServerSession) Run() { session.UpdateHomeRelay(s.Peer, msg.HomeRelay) }) case *msgcontrol.Pong: + s.Slog().Debug("received pong") // TODO + case *msgcontrol.LogonDeviceKey: + s.Slog().Debug("received after-logon logon device key, ignoring...") default: err = fmt.Errorf("received unknown type of message: %#v", msg) return } } - - time.Sleep(30 * time.Second) - - // TODO make other peers aware - - // for now, send a reject - //if err = s.conn.Write(&msgcontrol.LogonReject{ - // Reason: "dev: reject unambiguously", - // RetryStrategy: 0, - //}); err != nil { - // err = fmt.Errorf("error when sending reject: %w", err) - // return - //} - - return - - // TODO after Accept, we send the client peer and relay definitions, - // but we need to wait for the client to send their home relay and endpoints, - // before we'd (ideally) send a complete peer info to other clients. - // We will wait 10 seconds for this, before timing out and sending incomplete information. 
- - // TODO - panic("implement me") } func (s *ServerSession) Slog() *slog.Logger { diff --git a/types/dial/http.go b/types/dial/http.go index b94d611..b1193db 100644 --- a/types/dial/http.go +++ b/types/dial/http.go @@ -4,10 +4,12 @@ import ( "bufio" "context" "fmt" - "github.com/edup2p/common/types" "io" + "log/slog" "net/http" "time" + + "github.com/edup2p/common/types" ) // getPriv func() *key.NodePrivate, getSess func() *key.SessionPrivate, controlKey key.NodePublic @@ -27,17 +29,23 @@ func HTTP[T any](ctx context.Context, opts Opts, url, protocol string, makeClien req.Header.Set("Upgrade", protocol) req.Header.Set("Connection", "Upgrade") + closeNetConn := func() { + if err := netConn.Close(); err != nil { + slog.Error("error when closing netconn", "err", err) + } + } + if err := req.Write(brw); err != nil { - netConn.Close() + closeNetConn() return nil, fmt.Errorf("could not write http request: %w", err) } if err := brw.Flush(); err != nil { - netConn.Close() + closeNetConn() return nil, fmt.Errorf("could not flush http request: %w", err) } if err := netConn.SetReadDeadline(time.Now().Add(time.Second * 5)); err != nil { - netConn.Close() + closeNetConn() return nil, fmt.Errorf("could not set read deadline: %w", err) } resp, err := http.ReadResponse(brw.Reader, req) @@ -55,9 +63,8 @@ func HTTP[T any](ctx context.Context, opts Opts, url, protocol string, makeClien // At this point, we're speaking the protocol with the server. 
c, err := makeClient(ctx, netConn, brw, opts) - if err != nil { - netConn.Close() + closeNetConn() return nil, fmt.Errorf("failed to establish client: %w", err) } diff --git a/types/dial/http_server.go b/types/dial/http_server.go index 5c07e6f..cba1edf 100644 --- a/types/dial/http_server.go +++ b/types/dial/http_server.go @@ -4,11 +4,14 @@ import ( "bufio" "context" "fmt" - "github.com/edup2p/common/types" "log/slog" + "net" "net/http" "net/netip" "strings" + "time" + + "github.com/edup2p/common/types" ) type ProtocolServer interface { @@ -41,18 +44,38 @@ func HTTPHandler(s ProtocolServer, proto string) http.Handler { return } - defer netConn.Close() + if tcpConn, ok := netConn.(*net.TCPConn); ok { + if err := tcpConn.SetKeepAlive(true); err != nil { + s.Logger().Warn("set keep alive failed", "error", err, "peer", r.RemoteAddr) + } + + if err := tcpConn.SetKeepAlivePeriod(11 * time.Second); err != nil { + s.Logger().Warn("set keep alive period failed", "error", err, "peer", r.RemoteAddr) + } + } else { + s.Logger().Warn("could not get *net.TCPConn, to set keepalive", "peer", r.RemoteAddr) + } + + defer func() { + if err := netConn.Close(); err != nil { + slog.Error("error when closing netconn", "err", err) + } + }() // TODO re-add publickey frontloading? 
// pubKey := s.PublicKey() // "Relay-Public-Key: %s\r\n\r\n",pubKey.HexString() - fmt.Fprintf(brw, "HTTP/1.1 101 Switching Protocols\r\n"+ + if _, err := fmt.Fprintf(brw, "HTTP/1.1 101 Switching Protocols\r\n"+ "Upgrade: %s\r\n"+ "Connection: Upgrade\r\n\r\n", - up) + up); err != nil { + slog.Error("error when writing 101 response", "err", err) + } - brw.Flush() + if err := brw.Flush(); err != nil { + slog.Error("error when flushing 101 response", "err", err) + } remoteIPPort, _ := netip.ParseAddrPort(netConn.RemoteAddr().String()) diff --git a/types/dial/tcp.go b/types/dial/tcp.go index ea52654..e0efaed 100644 --- a/types/dial/tcp.go +++ b/types/dial/tcp.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "log/slog" "net" "net/netip" "time" @@ -13,7 +14,6 @@ import ( // WithTLS does a "full" dial, including TLS wrapping and CN checking func WithTLS(ctx context.Context, opts Opts) (net.Conn, error) { netConn, err := TCP(ctx, opts) - if err != nil { return nil, fmt.Errorf("tcp dial failed: %w", err) } @@ -34,6 +34,7 @@ func TLS(conn net.Conn, opts Opts) *tls.Conn { case opts.Domain != "": cfg.ServerName = opts.Domain default: + // We assume this is sane, else some upstream provider of the opt isn't proper with what it gives panic("TLS defined, but no domain provided") } @@ -45,7 +46,7 @@ func TCP(ctx context.Context, opts Opts) (net.Conn, error) { var err error - if opts.Addrs == nil || len(opts.Addrs) == 0 { + if len(opts.Addrs) == 0 { opts.Addrs, err = net.DefaultResolver.LookupNetIP(ctx, "ip", opts.Domain) if err != nil { return nil, fmt.Errorf("failed to lookup %s: %w", opts.Domain, err) @@ -73,13 +74,15 @@ func TCP(ctx context.Context, opts Opts) (net.Conn, error) { for _, addr := range opts.Addrs { ap := netip.AddrPortFrom(addr, opts.Port) go func() { - c, e := dialOneTCP(dialCtx, ap) + conn, err := dialOneTCP(dialCtx, ap) select { - case results <- dialResult{c: c, e: e}: + case results <- dialResult{c: conn, e: err}: case <-returned: - if c != nil { - 
c.Close() + if conn != nil { + if err := conn.Close(); err != nil { + slog.Error("failed to close tcp connection while multi-dialing", "err", err) + } } } }() @@ -117,6 +120,7 @@ func dialOneTCP(ctx context.Context, ap netip.AddrPort) (net.Conn, error) { var d net.Dialer d.LocalAddr = nil + d.KeepAlive = time.Second * 10 return d.DialContext(ctx, "tcp", ap.String()) } diff --git a/types/ifaces/actor.go b/types/ifaces/actor.go index 9736341..1e8b7a5 100644 --- a/types/ifaces/actor.go +++ b/types/ifaces/actor.go @@ -1,12 +1,14 @@ package ifaces import ( + "context" + "net/netip" + "time" + "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgactor" "github.com/edup2p/common/types/msgsess" "github.com/edup2p/common/types/stage" - "net/netip" - "time" ) type Actor interface { @@ -14,10 +16,12 @@ type Actor interface { Inbox() chan<- msgactor.ActorMessage + Ctx() context.Context + // Cancel this actor's context. Cancel() - // Close is called by the actor's Run loop when cancelled. 
+ // Close is called by AfterFunc to clean up Close() } @@ -76,6 +80,7 @@ type TrafficManagerActor interface { SendMsgToDirect(ap netip.AddrPort, sess key.SessionPublic, m msgsess.SessionMessage) SendMsgToRelay(relay int64, node key.NodePublic, sess key.SessionPublic, m msgsess.SessionMessage) SendPingDirect(ap netip.AddrPort, peer key.NodePublic, session key.SessionPublic) + SendPingDirectWithID(ap netip.AddrPort, peer key.NodePublic, session key.SessionPublic, txid msgsess.TxID) OutConnUseAddrPort(peer key.NodePublic, ap netip.AddrPort) OutConnTrackHome(peer key.NodePublic) @@ -103,3 +108,9 @@ type SessionManagerActor interface { type EndpointManagerActor interface { Actor } + +// === + +type MDNSManagerActor interface { + Actor +} diff --git a/types/ifaces/control.go b/types/ifaces/control.go index 81be825..0aa5093 100644 --- a/types/ifaces/control.go +++ b/types/ifaces/control.go @@ -1,10 +1,13 @@ package ifaces import ( + "context" + "net/netip" + "time" + "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgcontrol" "github.com/edup2p/common/types/relay" - "net/netip" ) // ControlCallbacks are the possible updates that the control server wishes to inform the client about. @@ -13,7 +16,7 @@ type ControlCallbacks interface { AddPeer( peer key.NodePublic, homeRelay int64, endpoints []netip.AddrPort, session key.SessionPublic, - ip4 netip.Addr, ip6 netip.Addr, + ip4, ip6 netip.Addr, prop msgcontrol.Properties, ) error @@ -43,6 +46,9 @@ type ControlInterface interface { // // As it is a netip.Prefix, it also includes the expected ipv6 range that all peers will be on. IPv6() netip.Prefix + // Expiry of the current control session, defaults to zero-value if there is no expiry, + // or session is not connected. + Expiry() time.Time // UpdateEndpoints informs the server of any changes in STUN-resolved endpoints. This is a set-replace operation. 
UpdateEndpoints([]netip.AddrPort) error @@ -54,6 +60,8 @@ type ControlInterface interface { type ControlSession interface { ControlInterface + Context() context.Context + // InstallCallbacks installs the current session's callbacks to another interface. // // This interface will be informed of updates from the control server. diff --git a/types/ifaces/injectable.go b/types/ifaces/injectable.go new file mode 100644 index 0000000..b12f4ec --- /dev/null +++ b/types/ifaces/injectable.go @@ -0,0 +1,9 @@ +package ifaces + +import "net/netip" + +type Injectable interface { + Available() bool + + InjectPacket(from, to netip.AddrPort, pkt []byte) error +} diff --git a/types/ifaces/stage.go b/types/ifaces/stage.go index 49c6119..cccf593 100644 --- a/types/ifaces/stage.go +++ b/types/ifaces/stage.go @@ -1,9 +1,11 @@ package ifaces import ( + "context" + "net/netip" + "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/stage" - "net/netip" ) // Stage documents/iterates the functions a Stage should expose @@ -18,4 +20,6 @@ type Stage interface { GetPeerInfo(peer key.NodePublic) *stage.PeerInfo GetEndpoints() []netip.AddrPort + + Context() context.Context } diff --git a/types/key/bson.go b/types/key/bson.go deleted file mode 100644 index 1e204dd..0000000 --- a/types/key/bson.go +++ /dev/null @@ -1,50 +0,0 @@ -package key - -import ( - "encoding" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsontype" -) - -func textMarshalBson(val encoding.TextMarshaler) (bsontype.Type, []byte, error) { - textBytes, err := val.MarshalText() - if err != nil { - return 0, nil, err - } - - return bson.MarshalValue(string(textBytes)) -} - -func textUnmarshalBson(val encoding.TextUnmarshaler, b bsontype.Type, bytes []byte) error { - var s = new(string) - - if err := bson.UnmarshalValue(b, bytes, s); err != nil { - return err - } - - return val.UnmarshalText([]byte(*s)) -} - -func (n *NodePublic) MarshalBSONValue() (bsontype.Type, []byte, error) { - return 
textMarshalBson(n) -} - -func (n *NodePublic) UnmarshalBSONValue(b bsontype.Type, bytes []byte) error { - return textUnmarshalBson(n, b, bytes) -} - -func (s *SessionPublic) MarshalBSONValue() (bsontype.Type, []byte, error) { - return textMarshalBson(s) -} - -func (s *SessionPublic) UnmarshalBSONValue(b bsontype.Type, bytes []byte) error { - return textUnmarshalBson(s, b, bytes) -} - -func (c *ControlPublic) MarshalBSONValue() (bsontype.Type, []byte, error) { - return textMarshalBson(c) -} - -func (c *ControlPublic) UnmarshalBSONValue(b bsontype.Type, bytes []byte) error { - return textUnmarshalBson(c, b, bytes) -} diff --git a/types/key/control_priv.go b/types/key/control_priv.go index c14cfa1..3dcc5f4 100644 --- a/types/key/control_priv.go +++ b/types/key/control_priv.go @@ -2,6 +2,7 @@ package key import ( "crypto/subtle" + "github.com/edup2p/common/types" "go4.org/mem" "golang.org/x/crypto/curve25519" diff --git a/types/key/control_pub.go b/types/key/control_pub.go index a83a726..19785e4 100644 --- a/types/key/control_pub.go +++ b/types/key/control_pub.go @@ -4,8 +4,9 @@ import ( "encoding/hex" "encoding/json" "fmt" - "go4.org/mem" "strings" + + "go4.org/mem" ) type ControlPublic NakedKey diff --git a/types/key/iface.go b/types/key/iface.go index eb0deb1..ffff6fb 100644 --- a/types/key/iface.go +++ b/types/key/iface.go @@ -9,7 +9,7 @@ type key interface { } type canTextMarshal interface { - // We need text encoding for JSON and BSON (currently) + // We need text encoding for JSON encoding.TextMarshaler encoding.TextUnmarshaler @@ -19,30 +19,18 @@ type canTextMarshal interface { // encoding.BinaryUnmarshaler } -//type canBsonMarshal interface { -// bson.ValueMarshaler -// bson.ValueUnmarshaler -// -// // TODO maybe also allow/support binary marshalling -// // encoding.BinaryMarshaler -// // encoding.BinaryUnmarshaler -//} - type publicKey interface { key IsZero() bool Debug() string HexString() string - // TODO } type privateKey[Pub key] interface { key Public() 
Pub - - // TODO } type canSealTo[To publicKey] interface { diff --git a/types/key/node_priv.go b/types/key/node_priv.go index 6d10ee4..f958882 100644 --- a/types/key/node_priv.go +++ b/types/key/node_priv.go @@ -4,10 +4,11 @@ import ( "crypto/subtle" "encoding/json" "fmt" + "strings" + "github.com/edup2p/common/types" "go4.org/mem" "golang.org/x/crypto/curve25519" - "strings" ) type NodePrivate struct { diff --git a/types/key/node_pub.go b/types/key/node_pub.go index 9a81a2e..cb9f252 100644 --- a/types/key/node_pub.go +++ b/types/key/node_pub.go @@ -4,8 +4,9 @@ import ( "encoding/hex" "encoding/json" "fmt" - "go4.org/mem" "strings" + + "go4.org/mem" ) type NodePublic NakedKey diff --git a/types/key/session_priv.go b/types/key/session_priv.go index 81d2281..64b1443 100644 --- a/types/key/session_priv.go +++ b/types/key/session_priv.go @@ -2,6 +2,7 @@ package key import ( "crypto/subtle" + "github.com/edup2p/common/types" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/nacl/box" diff --git a/types/key/session_pub.go b/types/key/session_pub.go index 693ecdc..9d18437 100644 --- a/types/key/session_pub.go +++ b/types/key/session_pub.go @@ -3,6 +3,7 @@ package key import ( "encoding/hex" "fmt" + "go4.org/mem" ) diff --git a/types/key/session_shared.go b/types/key/session_shared.go index 6d9d7c1..fae4bee 100644 --- a/types/key/session_shared.go +++ b/types/key/session_shared.go @@ -2,6 +2,7 @@ package key import ( "crypto/subtle" + "github.com/edup2p/common/types" "golang.org/x/crypto/nacl/box" ) diff --git a/types/key/util.go b/types/key/util.go index aa9c724..d101fec 100644 --- a/types/key/util.go +++ b/types/key/util.go @@ -5,10 +5,11 @@ import ( "encoding/hex" "errors" "fmt" - "go4.org/mem" - "golang.org/x/crypto/nacl/box" "io" "slices" + + "go4.org/mem" + "golang.org/x/crypto/nacl/box" ) // rand fills b with cryptographically strong random bytes. 
Panics if diff --git a/types/misc.go b/types/misc.go index 2e0565d..3e569a0 100644 --- a/types/misc.go +++ b/types/misc.go @@ -6,10 +6,11 @@ import ( "context" "crypto/rand" "encoding/hex" - "golang.org/x/exp/maps" "log/slog" "net/netip" "strings" + + "golang.org/x/exp/maps" ) // Incomparable is a zero-width incomparable type. If added as the @@ -72,12 +73,12 @@ func SliceOrEmpty[T any](v []T) []T { } func SliceOrNil[T any](v []T) []T { - if (v != nil && len(v) > 0) || (v == nil) { + if len(v) > 0 { return v - } else { - // len(v) == 0 - return nil } + + // len(v) == 0 + return nil } // IsContextDone does a quick check on a context to see if its dead. @@ -90,6 +91,7 @@ func RandStringBytesMaskImprSrc(n int) string { b := make([]byte, (n+1)/2) // can be simplified to n/2 if n is always even if _, err := rand.Read(b); err != nil { + // We expect the randomizer to be available here panic(err) } @@ -135,3 +137,7 @@ func Map[T, U any](ts []T, f func(T) U) []U { } type LogonCallback func(url string, deviceKey chan<- string) error + +type CCCKEY int + +const CCC CCCKEY = 112 diff --git a/types/msgactor/msg.go b/types/msgactor/msg.go index 7d47cc3..de0b3c4 100644 --- a/types/msgactor/msg.go +++ b/types/msgactor/msg.go @@ -1,11 +1,12 @@ package msgactor import ( + "net/netip" + "time" + "github.com/edup2p/common/types/key" "github.com/edup2p/common/types/msgsess" "github.com/edup2p/common/types/relay" - "net/netip" - "time" ) // Messages @@ -44,6 +45,11 @@ type TManSessionMessageFromDirect struct { Msg *msgsess.ClearMessage } +type TManSpreadMDNSPacket struct { + Pkt []byte + IP6 bool +} + // ====================================================================================================== // SessionManager msgs @@ -106,6 +112,17 @@ type RManRelayLatencyResults struct { RelayLatency map[int64]time.Duration } +// ====================================================================================================== +// MDNSManager msgs + +type MManReceivedPacket struct { 
+ From key.NodePublic + + Data []byte + + IP6 bool +} + // ====================================================================================================== // DirectRouter msgs diff --git a/types/msgactor/msg_iface.go b/types/msgactor/msg_iface.go index d8e5d4b..1e2afd6 100644 --- a/types/msgactor/msg_iface.go +++ b/types/msgactor/msg_iface.go @@ -10,6 +10,7 @@ func (o *TManConnActivity) amsg() {} func (o *TManConnGoodBye) amsg() {} func (o *TManSessionMessageFromRelay) amsg() {} func (o *TManSessionMessageFromDirect) amsg() {} +func (o *TManSpreadMDNSPacket) amsg() {} func (o *SManSessionFrameFromRelay) amsg() {} func (o *SManSessionFrameFromAddrPort) amsg() {} @@ -19,6 +20,8 @@ func (o *OutConnUse) amsg() {} func (o *RManRelayLatencyResults) amsg() {} +func (o *MManReceivedPacket) amsg() {} + func (o *DManSetMTU) amsg() {} func (o *DRouterPeerClearKnownAs) amsg() {} func (o *DRouterPeerAddKnownAs) amsg() {} diff --git a/types/msgactor/notif.go b/types/msgactor/notif.go index e41d2b7..b646fdd 100644 --- a/types/msgactor/notif.go +++ b/types/msgactor/notif.go @@ -1,8 +1,9 @@ package msgactor import ( - "github.com/edup2p/common/types/key" "net/netip" + + "github.com/edup2p/common/types/key" ) type PeerState byte @@ -13,16 +14,19 @@ const ( PeerStateDirect ) +//nolint:unused type PeerConnStateChangeNotification struct { peer key.NodePublic state PeerState } +//nolint:unused type LocalEndpointsChangeNotification struct { endpoints []netip.AddrPort } +//nolint:unused type HomeRelayChangeNotification struct { homeRelay int64 } diff --git a/types/msgcontrol/msg.go b/types/msgcontrol/msg.go index 5e0292b..47c8ce9 100644 --- a/types/msgcontrol/msg.go +++ b/types/msgcontrol/msg.go @@ -1,10 +1,12 @@ package msgcontrol import ( - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/relay" + "fmt" "net/netip" "time" + + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/relay" ) type ControlMessageType byte @@ -17,7 +19,6 @@ const ( 
LogonDeviceKeyType LogonAcceptType LogonRejectType - LogoutType PingType PongType ) @@ -29,6 +30,8 @@ const ( PeerUpdateType PeerRemoveType RelayUpdateType + LogoutType + DisconnectType ) // === handshake phase @@ -72,6 +75,8 @@ type LogonAccept struct { IP4 netip.Prefix IP6 netip.Prefix + AuthExpiry time.Time + SessionID string } @@ -96,7 +101,7 @@ func (r RetryStrategyType) Error() string { case RecreateSession: return "retry by recreating session" default: - panic("unknown retry strategy type") + return fmt.Sprintf("!!!unknown retry strategy type %d!!!", r) } } @@ -108,8 +113,6 @@ type LogonReject struct { RetryAfter time.Duration `json:",omitempty"` } -type Logout struct{} - type Ping struct { // random data encrypted with shared key (control priv x client pub) // to be signed with shared key of nodekey and sesskey @@ -125,6 +128,18 @@ type Pong struct { // === during session +// -> control +type Logout struct{} + +// -> client +type Disconnect struct { + Reason string + + RetryStrategy RetryStrategyType `json:",omitempty"` + + RetryAfter time.Duration `json:",omitempty"` +} + // -> control type EndpointUpdate struct { Endpoints []netip.AddrPort diff --git a/types/msgcontrol/msg_iface.go b/types/msgcontrol/msg_iface.go index b429257..7fc4841 100644 --- a/types/msgcontrol/msg_iface.go +++ b/types/msgcontrol/msg_iface.go @@ -7,30 +7,35 @@ type ControlMessage interface { func (c *ClientHello) CMsgType() ControlMessageType { return ClientHelloType } + func (c *ServerHello) CMsgType() ControlMessageType { return ServerHelloType } + func (c *Logon) CMsgType() ControlMessageType { return LogonType } + func (c *LogonAuthenticate) CMsgType() ControlMessageType { return LogonAuthenticateType } + func (c *LogonDeviceKey) CMsgType() ControlMessageType { return LogonDeviceKeyType } + func (c *LogonAccept) CMsgType() ControlMessageType { return LogonAcceptType } + func (c *LogonReject) CMsgType() ControlMessageType { return LogonRejectType } -func (c *Logout) CMsgType() 
ControlMessageType { - return LogoutType -} + func (c *Ping) CMsgType() ControlMessageType { return PingType } + func (c *Pong) CMsgType() ControlMessageType { return PongType } @@ -38,18 +43,31 @@ func (c *Pong) CMsgType() ControlMessageType { func (c *EndpointUpdate) CMsgType() ControlMessageType { return EndpointUpdateType } + func (c *HomeRelayUpdate) CMsgType() ControlMessageType { return HomeRelayUpdateType } + func (c *PeerAddition) CMsgType() ControlMessageType { return PeerAdditionType } + func (c *PeerUpdate) CMsgType() ControlMessageType { return PeerUpdateType } + func (c *PeerRemove) CMsgType() ControlMessageType { return PeerRemoveType } + func (c *RelayUpdate) CMsgType() ControlMessageType { return RelayUpdateType } + +func (c *Logout) CMsgType() ControlMessageType { + return LogoutType +} + +func (c *Disconnect) CMsgType() ControlMessageType { + return DisconnectType +} diff --git a/types/msgsess/consts.go b/types/msgsess/consts.go index 04b3a6b..a137956 100644 --- a/types/msgsess/consts.go +++ b/types/msgsess/consts.go @@ -15,8 +15,9 @@ const v1 = VersionMarker(0x1) type MessageType byte const ( - PingMessage = MessageType(0x00) - PongMessage = MessageType(0x01) + PingMessage = MessageType(iota) + PongMessage + SideBandDataMessage RendezvousMessage = MessageType(0xFF) ) diff --git a/types/msgsess/msgsess.go b/types/msgsess/msgsess.go index 60c07c4..61067f0 100644 --- a/types/msgsess/msgsess.go +++ b/types/msgsess/msgsess.go @@ -6,10 +6,12 @@ package msgsess import "github.com/edup2p/common/types/key" type SessionMessage interface { - MarshalSessionMessage() []byte + Marshal() []byte // todo maybe convert to slog.Group? 
Debug() string + + Parse([]byte) error } // ClearMessage represents a full session wire message in decrypted view diff --git a/types/msgsess/parsing.go b/types/msgsess/parsing.go index 01daa6f..ebebee0 100644 --- a/types/msgsess/parsing.go +++ b/types/msgsess/parsing.go @@ -3,9 +3,8 @@ package msgsess import ( "errors" "fmt" - "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" - "net/netip" ) // Session Wire header: @@ -35,64 +34,26 @@ func ParseSessionMessage(usrMsg []byte) (SessionMessage, error) { return nil, fmt.Errorf("invalid version: %x", version) } + var msg SessionMessage + switch MessageType(msgType) { case PingMessage: - return parsePing(specificMsg) + msg = new(Ping) case PongMessage: - return parsePong(specificMsg) + msg = new(Pong) case RendezvousMessage: - return parseRendezvous(specificMsg) + msg = new(Rendezvous) + case SideBandDataMessage: + msg = new(SideBandData) default: return nil, fmt.Errorf("invalid message type: %x", msgType) } -} - -var errTooSmall = errors.New("session message too small") - -func parsePing(b []byte) (*Ping, error) { - if len(b) < key.Len+12 { - return nil, errTooSmall - } - - txid := [12]byte(b[:12]) - b = b[12:] - nKey := key.NodePublic(b[:key.Len]) - - return &Ping{ - TxID: txid, - NodeKey: nKey, - }, nil -} -func parsePong(b []byte) (*Pong, error) { - if len(b) < 12+16+2 { - return nil, errTooSmall + if err := msg.Parse(specificMsg); err != nil { + return nil, fmt.Errorf("failed to parse message of type %d: %w", msgType, err) } - txid := [12]byte(b[:12]) - b = b[12:] - - ap := types.ParseAddrPort([18]byte(b[:18])) - - return &Pong{TxID: txid, Src: ap}, nil + return msg, nil } -func parseRendezvous(b []byte) (*Rendezvous, error) { - if len(b)%18 != 0 { - return nil, errors.New("malformed rendezvous addresses") - } - - aps := make([]netip.AddrPort, 0) - - for { - ap := types.ParseAddrPort([18]byte(b[:18])) - aps = append(aps, ap) - b = b[18:] - - if len(b) == 0 { - break - } - } - - return 
&Rendezvous{MyAddresses: aps}, nil -} +var errTooSmall = errors.New("session message too small") diff --git a/types/msgsess/ping.go b/types/msgsess/ping.go index d28de62..33f99e8 100644 --- a/types/msgsess/ping.go +++ b/types/msgsess/ping.go @@ -3,8 +3,9 @@ package msgsess import ( crand "crypto/rand" "fmt" - "github.com/edup2p/common/types/key" "slices" + + "github.com/edup2p/common/types/key" ) type TxID [12]byte @@ -12,6 +13,7 @@ type TxID [12]byte func NewTxID() TxID { var tx TxID if _, err := crand.Read(tx[:]); err != nil { + // We expect the randomiser to be available here panic(err) } return tx @@ -27,10 +29,25 @@ type Ping struct { Padding int } -func (p *Ping) MarshalSessionMessage() []byte { +func (p *Ping) Marshal() []byte { + // TODO add padding return slices.Concat([]byte{byte(v1), byte(PingMessage)}, p.TxID[:], p.NodeKey[:]) } +func (p *Ping) Parse(b []byte) error { + if len(b) < key.Len+12 { + return errTooSmall + } + + p.TxID = [12]byte(b[:12]) + b = b[12:] + p.NodeKey = key.NodePublic(b[:key.Len]) + + // TODO count remaining bytes as padding + + return nil +} + func (p *Ping) Debug() string { return fmt.Sprintf("ping tx=%x nodekey=%s padding=%v", p.TxID, p.NodeKey.Debug(), p.Padding) } diff --git a/types/msgsess/pong.go b/types/msgsess/pong.go index b011c33..210d90b 100644 --- a/types/msgsess/pong.go +++ b/types/msgsess/pong.go @@ -2,9 +2,10 @@ package msgsess import ( "fmt" - "github.com/edup2p/common/types" "net/netip" "slices" + + "github.com/edup2p/common/types" ) type Pong struct { @@ -13,10 +14,23 @@ type Pong struct { Src netip.AddrPort // 18 bytes (16+2) on the wire; v4-mapped ipv6 for IPv4 } -func (p *Pong) MarshalSessionMessage() []byte { +func (p *Pong) Marshal() []byte { return slices.Concat([]byte{byte(v1), byte(PongMessage)}, p.TxID[:], types.PutAddrPort(p.Src)) } +func (p *Pong) Parse(b []byte) error { + if len(b) < 12+16+2 { + return errTooSmall + } + + p.TxID = [12]byte(b[:12]) + b = b[12:] + + p.Src = 
types.ParseAddrPort([18]byte(b[:18])) + + return nil +} + func (p *Pong) Debug() string { return fmt.Sprintf("pong tx=%x src=%s", p.TxID, p.Src.String()) } diff --git a/types/msgsess/rendezvous.go b/types/msgsess/rendezvous.go index b4d7d56..b7f63cc 100644 --- a/types/msgsess/rendezvous.go +++ b/types/msgsess/rendezvous.go @@ -1,17 +1,19 @@ package msgsess import ( + "errors" "fmt" - "github.com/edup2p/common/types" "net/netip" "slices" + + "github.com/edup2p/common/types" ) type Rendezvous struct { MyAddresses []netip.AddrPort } -func (r *Rendezvous) MarshalSessionMessage() []byte { +func (r *Rendezvous) Marshal() []byte { b := make([]byte, 0) for _, ap := range r.MyAddresses { @@ -21,6 +23,28 @@ func (r *Rendezvous) MarshalSessionMessage() []byte { return slices.Concat([]byte{byte(v1), byte(RendezvousMessage)}, b) } +func (r *Rendezvous) Parse(b []byte) error { + if len(b)%18 != 0 { + return errors.New("malformed rendezvous addresses") + } + + aps := make([]netip.AddrPort, 0) + + for { + ap := types.ParseAddrPort([18]byte(b[:18])) + aps = append(aps, ap) + b = b[18:] + + if len(b) == 0 { + break + } + } + + r.MyAddresses = aps + + return nil +} + func (r *Rendezvous) Debug() string { return fmt.Sprintf("rendezvous addresses=%s", types.PrettyAddrPortSlice(r.MyAddresses)) } diff --git a/types/msgsess/sidebanddata.go b/types/msgsess/sidebanddata.go new file mode 100644 index 0000000..8adc063 --- /dev/null +++ b/types/msgsess/sidebanddata.go @@ -0,0 +1,42 @@ +package msgsess + +import ( + "fmt" + "slices" +) + +type SideBandDataType byte + +const ( + MDNSv4Type SideBandDataType = iota + MDNSv6Type SideBandDataType = iota +) + +type SideBandData struct { + Type SideBandDataType + Data []byte +} + +func (s *SideBandData) Marshal() []byte { + b := make([]byte, 0) + + b = append(b, byte(s.Type)) + b = append(b, s.Data...) 
+ + return slices.Concat([]byte{byte(v1), byte(SideBandDataMessage)}, b) +} + +func (s *SideBandData) Parse(b []byte) error { + if len(b) < 1 { + return errTooSmall + } + + s.Type = SideBandDataType(b[0]) + s.Data = b[1:] + + return nil +} + +func (s *SideBandData) Debug() string { + return fmt.Sprintf("sidebanddata type=%d data=%x", s.Type, s.Data) +} diff --git a/types/relay/client.go b/types/relay/client.go index 9399aed..0e83bed 100644 --- a/types/relay/client.go +++ b/types/relay/client.go @@ -6,14 +6,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/msgsess" "io" "log/slog" "slices" "sync" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) const ( @@ -28,8 +29,21 @@ var ( errKeepAliveNonZeroLen = errors.New("keepalive frame has non-zero length") ) -// Client is a Relay client that lives as long as its conn does -type Client struct { +type Client interface { + Run() + RelayKey() key.NodePublic + + Send() chan<- SendPacket + Recv() <-chan RecvPacket + Done() <-chan struct{} + Err() error + + Close() + Cancel(error) +} + +// HTTPClient is a Relay client that lives as long as its conn does +type HTTPClient struct { ctx context.Context ccc context.CancelCauseFunc @@ -51,19 +65,19 @@ type Client struct { closed bool } -func (c *Client) Send() chan<- SendPacket { +func (c *HTTPClient) Send() chan<- SendPacket { return c.sendCh } -func (c *Client) Recv() <-chan RecvPacket { +func (c *HTTPClient) Recv() <-chan RecvPacket { return c.recvCh } -func (c *Client) Done() <-chan struct{} { +func (c *HTTPClient) Done() <-chan struct{} { return c.ctx.Done() } -func (c *Client) Err() error { +func (c *HTTPClient) Err() error { return c.ctx.Err() } @@ -81,14 +95,14 @@ type RecvPacket struct { Data []byte } -// EstablishClient creates a new relay.Client on a given MetaConn with associated 
bufio.ReadWriter. +// EstablishClient creates a new relay.HTTPClient on a given MetaConn with associated bufio.ReadWriter. // -// It logs in and authenticates the server before returning a Client object. +// It logs in and authenticates the server before returning a HTTPClient object. // If any error occurs, or no client can be established before timeout, it returns. -func EstablishClient(parentCtx context.Context, mc types.MetaConn, brw *bufio.ReadWriter, timeout time.Duration, getPriv func() *key.NodePrivate) (*Client, error) { +func EstablishClient(parentCtx context.Context, mc types.MetaConn, brw *bufio.ReadWriter, timeout time.Duration, getPriv func() *key.NodePrivate) (*HTTPClient, error) { ctx, ccc := context.WithCancelCause(parentCtx) - c := &Client{ + c := &HTTPClient{ ctx: ctx, ccc: ccc, @@ -139,36 +153,33 @@ func EstablishClient(parentCtx context.Context, mc types.MetaConn, brw *bufio.Re return nil, fmt.Errorf("could not reset deadline: %w", err) } - go func() { - <-c.ctx.Done() - c.Close() - }() + context.AfterFunc(c.ctx, c.Close) return c, nil } -func (c *Client) privateKey() *key.NodePrivate { +func (c *HTTPClient) privateKey() *key.NodePrivate { return c.getPriv() } -func (c *Client) publicKey() key.NodePublic { +func (c *HTTPClient) publicKey() key.NodePublic { return c.privateKey().Public() } // RelayKey returns the key of the relay we're connected to. 
-func (c *Client) RelayKey() key.NodePublic { +func (c *HTTPClient) RelayKey() key.NodePublic { return c.relayServerKey } // recvVersion assumes the caller has ownership, or lock -func (c *Client) recvVersion() (ProtocolVersion, error) { +func (c *HTTPClient) recvVersion() (ProtocolVersion, error) { b, err := c.reader.ReadByte() return ProtocolVersion(b), err } // recvServerKey assumes the caller has ownership, or lock -func (c *Client) recvServerKey() error { +func (c *HTTPClient) recvServerKey() error { frTyp, frLen, err := readFrameHeader(c.reader) if err != nil { return err @@ -187,7 +198,6 @@ func (c *Client) recvServerKey() error { var buf [32]byte _, err = io.ReadFull(c.reader, buf[:]) - if err != nil { return err } @@ -198,7 +208,7 @@ func (c *Client) recvServerKey() error { } // sendClientInfo assumes the caller has ownership, or lock -func (c *Client) sendClientInfo() error { +func (c *HTTPClient) sendClientInfo() error { m, err := json.Marshal(ClientInfo{SendKeepalive: true}) if err != nil { return err @@ -222,7 +232,7 @@ func (c *Client) sendClientInfo() error { } // recvServerInfo assumes the caller has ownership, or lock -func (c *Client) recvServerInfo() (*ServerInfo, error) { +func (c *HTTPClient) recvServerInfo() (*ServerInfo, error) { frTyp, frLen, err := readFrameHeader(c.reader) if err != nil { return nil, err @@ -238,7 +248,7 @@ func (c *Client) recvServerInfo() (*ServerInfo, error) { return nil, errPacketTooLarge } - var msgbox = make([]byte, frLen) + msgbox := make([]byte, frLen) if _, err = io.ReadFull(c.reader, msgbox); err != nil { return nil, err @@ -259,30 +269,37 @@ func (c *Client) recvServerInfo() (*ServerInfo, error) { return info, nil } -func (c *Client) Cancel(err error) { +func (c *HTTPClient) Cancel(err error) { c.ccc(err) if err := c.mc.SetDeadline(time.Now().Add(10 * time.Millisecond)); err != nil { slog.Error("could not set deadline in Cancel", "err", err) } } -func (c *Client) Close() { - if c.closed || context.Cause(c.ctx) 
!= nil { +func (c *HTTPClient) Close() { + if c.closed { return } - c.mc.Close() + if err := c.mc.Close(); err != nil { + slog.Error("error when closing metaconn", "err", err) + } close(c.sendCh) close(c.recvCh) c.closed = true } -func (c *Client) Closed() bool { +func (c *HTTPClient) Closed() bool { return c.closed } -func (c *Client) RunReceive() { +func (c *HTTPClient) Run() { + go c.RunReceive() + go c.RunSend() +} + +func (c *HTTPClient) RunReceive() { if !c.recvMutex.TryLock() { slog.Error("could not lock recvMutex, is RunReceive already running?") return @@ -291,7 +308,7 @@ func (c *Client) RunReceive() { defer func() { if v := recover(); v != nil { - c.ccc(fmt.Errorf("reader panicked: %s", v)) + c.Cancel(fmt.Errorf("reader panicked: %s", v)) } }() @@ -304,11 +321,8 @@ func (c *Client) RunReceive() { for { frTyp, frLen, err = readFrameHeader(c.reader) - select { - case <-c.ctx.Done(): + if c.ctx.Err() != nil { return - default: - // fallthrough } if err != nil { @@ -357,7 +371,7 @@ func (c *Client) RunReceive() { } } -func (c *Client) RunSend() { +func (c *HTTPClient) RunSend() { if !c.sendMutex.TryLock() { slog.Error("could not lock sendMutex, is RunSend already running?") return diff --git a/types/relay/frame.go b/types/relay/frame.go index 6410772..43e9510 100644 --- a/types/relay/frame.go +++ b/types/relay/frame.go @@ -2,6 +2,7 @@ package relay import ( "bufio" + "github.com/edup2p/common/types" ) diff --git a/types/relay/info.go b/types/relay/info.go index 3a72e9c..0f29767 100644 --- a/types/relay/info.go +++ b/types/relay/info.go @@ -1,8 +1,9 @@ package relay import ( - "github.com/edup2p/common/types/key" "net/netip" + + "github.com/edup2p/common/types/key" ) type Information struct { diff --git a/types/relay/relayhttp/http_client.go b/types/relay/relayhttp/http_client.go index c2d72e0..4ac11a1 100644 --- a/types/relay/relayhttp/http_client.go +++ b/types/relay/relayhttp/http_client.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + 
"github.com/edup2p/common/types" "github.com/edup2p/common/types/dial" "github.com/edup2p/common/types/key" @@ -22,10 +23,12 @@ func makeRelayURL(opts dial.Opts) string { return fmt.Sprintf("%s://%s/relay", proto, domain) } -func Dial(ctx context.Context, opts dial.Opts, getPriv func() *key.NodePrivate, expectKey key.NodePublic) (*relay.Client, error) { +type RelayDialFunc func(ctx context.Context, opts dial.Opts, getPriv func() *key.NodePrivate, expectKey key.NodePublic) (relay.Client, error) + +func Dial(ctx context.Context, opts dial.Opts, getPriv func() *key.NodePrivate, expectKey key.NodePublic) (relay.Client, error) { opts.SetDefaults() - c, err := dial.HTTP(ctx, opts, makeRelayURL(opts), relay.UpgradeProtocol, func(parentCtx context.Context, mc types.MetaConn, brw *bufio.ReadWriter, opts dial.Opts) (*relay.Client, error) { + c, err := dial.HTTP(ctx, opts, makeRelayURL(opts), relay.UpgradeProtocol, func(parentCtx context.Context, mc types.MetaConn, brw *bufio.ReadWriter, opts dial.Opts) (*relay.HTTPClient, error) { return relay.EstablishClient(parentCtx, mc, brw, opts.EstablishTimeout, getPriv) }) if err != nil { @@ -33,9 +36,10 @@ func Dial(ctx context.Context, opts dial.Opts, getPriv func() *key.NodePrivate, } if !expectKey.IsZero() && c.RelayKey() != expectKey { - c.Close() + err = fmt.Errorf("relay key did not match expected key") + c.Cancel(err) - return nil, fmt.Errorf("relay key did not match expected key") + return nil, err } return c, nil diff --git a/types/relay/relayhttp/http_server.go b/types/relay/relayhttp/http_server.go index d557f47..4389d0e 100644 --- a/types/relay/relayhttp/http_server.go +++ b/types/relay/relayhttp/http_server.go @@ -1,9 +1,10 @@ package relayhttp import ( + "net/http" + "github.com/edup2p/common/types/dial" "github.com/edup2p/common/types/relay" - "net/http" ) func ServerHandler(s *relay.Server) http.Handler { diff --git a/types/relay/server.go b/types/relay/server.go index ef5ddbc..d2d533e 100644 --- 
a/types/relay/server.go +++ b/types/relay/server.go @@ -6,14 +6,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/msgsess" "io" "log/slog" "net/netip" "sync" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) type Server struct { @@ -61,7 +62,6 @@ func (s *Server) sendServerKey(writer *bufio.Writer) (err error) { pKey := s.PublicKey() _, err = writer.Write(pKey[:]) - if err != nil { return } @@ -212,7 +212,6 @@ func (s *Server) getClient(peer key.NodePublic) *ServerClient { } func (s *Server) registerClient(client *ServerClient) { - // Check if there's a client active on this key already. if sc := s.getClient(client.nodeKey); sc != nil { // Just cancel the old connected client. diff --git a/types/relay/serverclient.go b/types/relay/serverclient.go index 206f8b0..d9d08c8 100644 --- a/types/relay/serverclient.go +++ b/types/relay/serverclient.go @@ -5,14 +5,15 @@ import ( "context" "errors" "fmt" - "github.com/edup2p/common/types" - "github.com/edup2p/common/types/key" - "github.com/edup2p/common/types/msgsess" "io" "log/slog" "math/rand" "net/netip" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "github.com/edup2p/common/types/msgsess" ) // ServerPacket is a transient packet type handled by the server @@ -113,7 +114,6 @@ func (sc *ServerClient) RunReceiver() { for { frType, frLen, err := readFrameHeader(sc.buffReader) - if err != nil { if errors.Is(err, io.EOF) { sc.ccc(fmt.Errorf("reader: read EOF")) @@ -147,7 +147,6 @@ func (sc *ServerClient) RunReceiver() { } func (sc *ServerClient) handleSend(frLen uint32) error { - dstKey, contents, err := sc.readSend(frLen) if err != nil { return err @@ -292,7 +291,9 @@ func (sc *ServerClient) RunSender() { } func (sc *ServerClient) setWriteDeadline() { - 
sc.netConn.SetWriteDeadline(time.Now().Add(ServerClientWriteTimeout)) + if err := sc.netConn.SetWriteDeadline(time.Now().Add(ServerClientWriteTimeout)); err != nil { + slog.Error("setWriteDeadline error", "err", err) + } } // sendKeepAlive sends a keep-alive frame, without flushing. diff --git a/types/stage/stage.go b/types/stage/stage.go index b32779e..2787325 100644 --- a/types/stage/stage.go +++ b/types/stage/stage.go @@ -2,9 +2,10 @@ package stage import ( - "github.com/edup2p/common/types/key" "net/netip" "time" + + "github.com/edup2p/common/types/key" ) type SentPing struct { @@ -20,4 +21,6 @@ type PeerInfo struct { Endpoints []netip.AddrPort RendezvousEndpoints []netip.AddrPort Session key.SessionPublic + IPv4, IPv6 netip.Addr + MDNS bool } diff --git a/types/stun/response.go b/types/stun/response.go index 3aad864..5a761ad 100644 --- a/types/stun/response.go +++ b/types/stun/response.go @@ -90,7 +90,6 @@ func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { } } return nil - }); err != nil { return TxID{}, netip.AddrPort{}, err } diff --git a/types/stun/server.go b/types/stun/server.go index 44c7e4e..769dcd2 100644 --- a/types/stun/server.go +++ b/types/stun/server.go @@ -12,8 +12,8 @@ import ( ) type Server struct { - ctx context.Context // ctx signals service shutdown - pc *net.UDPConn // pc is the UDP listener + ctx context.Context // ctx signals service shutdown + bind *net.UDPConn // bind is the UDP listener } func NewServer(ctx context.Context) *Server { @@ -24,7 +24,7 @@ func (s *Server) Listen(addrPort netip.AddrPort) error { ua := net.UDPAddrFromAddrPort(addrPort) var err error - s.pc, err = net.ListenUDP("udp", ua) + s.bind, err = net.ListenUDP("udp", ua) if err != nil { return err } @@ -32,14 +32,16 @@ func (s *Server) Listen(addrPort netip.AddrPort) error { // close the listener on shutdown in order to break out of the read loop go func() { <-s.ctx.Done() - s.pc.Close() + if err := s.bind.Close(); err != nil { + 
slog.Error("failed to close bind", "err", err) + } }() return nil } // LocalAddr returns the local address of the STUN server. It must not be called before ListenAndServe. func (s *Server) LocalAddr() net.Addr { - return s.pc.LocalAddr() + return s.bind.LocalAddr() } func (s *Server) Serve() error { @@ -50,8 +52,7 @@ func (s *Server) Serve() error { err error ) for { - n, ua, err = s.pc.ReadFromUDP(buf[:]) - + n, ua, err = s.bind.ReadFromUDP(buf[:]) if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { return nil @@ -76,7 +77,7 @@ func (s *Server) Serve() error { addr, _ := netip.AddrFromSlice(ua.IP) res := Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port))) - if _, err = s.pc.WriteTo(res, ua); err != nil { + if _, err = s.bind.WriteTo(res, ua); err != nil { slog.Info("writing back STUN response failed", "error", err) } } diff --git a/types/stun/txid.go b/types/stun/txid.go index 02d0ac3..2bf59ba 100644 --- a/types/stun/txid.go +++ b/types/stun/txid.go @@ -9,6 +9,7 @@ type TxID [12]byte func NewTxID() TxID { var tx TxID if _, err := crand.Read(tx[:]); err != nil { + // We expect the randomizer to be available here panic(err) } return tx diff --git a/usrwg/bind.go b/usrwg/bind.go index 144b8cb..654929b 100644 --- a/usrwg/bind.go +++ b/usrwg/bind.go @@ -1,15 +1,20 @@ package usrwg import ( + "context" "errors" "fmt" - "github.com/edup2p/common/types/key" - "golang.org/x/exp/maps" - "golang.zx2c4.com/wireguard/conn" + "log/slog" + "net" "reflect" "runtime" "sync" "time" + + "github.com/edup2p/common/types" + "github.com/edup2p/common/types/key" + "golang.org/x/exp/maps" + "golang.zx2c4.com/wireguard/conn" ) type ToverSokBind struct { @@ -17,6 +22,8 @@ type ToverSokBind struct { conns map[key.NodePublic]*ChannelConn connChange chan bool + permClosed bool + endpointMu sync.RWMutex endpoints map[key.NodePublic]*endpoint } @@ -44,18 +51,36 @@ func (b *ToverSokBind) Close() error { maps.Clear(b.endpoints) + var errs []error + for _, cc := range 
b.conns { - // TODO log error - cc.Close() + if err := cc.Close(); err != nil { + errs = append(errs, err) + } } maps.Clear(b.conns) + if len(errs) > 0 { + return fmt.Errorf("errors when closing connections: %w", errors.Join(errs...)) + } + + b.notifyConnChange() + return nil } +func (b *ToverSokBind) Cancel() error { + b.permClosed = true + return b.Close() +} + // ReadFromConns implements conn.ReceiveFunc func (b *ToverSokBind) ReadFromConns(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + if b.isPermanentlyClosed() { + return 0, net.ErrClosed + } + // We get a keys slice that could potentially get immediately outdated, // but we use it to fill buffers from existing conns first. b.connMu.RLock() @@ -94,15 +119,9 @@ fill: sizes[i] = len(p) copy(packets[i], p) - n += 1 + n++ } - //defer func() { - // for i := 0; i < n; i++ { - // slog.Debug("received packet", "hex", hex.EncodeToString(packets[i][:sizes[i]])) - // } - //}() - if n != 0 { // Buffer filled, return early return @@ -131,6 +150,10 @@ fill: return } +func (b *ToverSokBind) isPermanentlyClosed() bool { + return b.permClosed +} + func (b *ToverSokBind) waitForValueFromConns() ([]byte, *endpoint) { caseMap := b.buildConnsSelectCaseMap() connChangeCase := b.createConnChangeSelectCase() @@ -144,9 +167,18 @@ func (b *ToverSokBind) waitForValueFromConns() ([]byte, *endpoint) { cases = append(cases, connChangeCase) - choice, recv, _ := reflect.Select(cases) + choice, recv, recvOk := reflect.Select(cases) - //slog.Debug("waitForValueFromConns reflect.Select", "choice", choice, "len", len(cases), "recv", recv, "recvOk", recvOk, "cases", cases) + slog.Log( + context.Background(), + types.LevelTrace, + "waitForValueFromConns reflect.Select", + "choice", choice, + "len", len(cases), + "recv", recv, + "recvOk", recvOk, + "cases", cases, + ) // choice == last index if choice == len(cases)-1 { @@ -158,7 +190,7 @@ func (b *ToverSokBind) waitForValueFromConns() ([]byte, *endpoint) { } func (b 
*ToverSokBind) buildConnsSelectCaseMap() map[key.NodePublic]reflect.SelectCase { - var cases = make(map[key.NodePublic]reflect.SelectCase) + cases := make(map[key.NodePublic]reflect.SelectCase) b.connMu.RLock() defer b.connMu.RUnlock() @@ -188,7 +220,7 @@ func (b *ToverSokBind) createConnChangeSelectCase() reflect.SelectCase { // SetMark is used by wireguard-go to avoid routing loops. // TODO: double-check -func (b *ToverSokBind) SetMark(mark uint32) error { +func (b *ToverSokBind) SetMark(uint32) error { return nil } @@ -214,7 +246,6 @@ func (b *ToverSokBind) Send(bufs [][]byte, ep conn.Endpoint) error { func (b *ToverSokBind) ParseEndpoint(s string) (conn.Endpoint, error) { np, err := key.UnmarshalPublic(s) - if err != nil { return nil, fmt.Errorf("failed to unmarshal nodepublic: %w", err) } @@ -266,8 +297,9 @@ func (b *ToverSokBind) CloseConn(peer key.NodePublic) { cc, ok := b.conns[peer] if ok { - // TODO log error - cc.Close() + if err := cc.Close(); err != nil { + slog.Error("failed to close channel", "peer", peer, "err", err) + } } delete(b.conns, peer) diff --git a/usrwg/channel_conn.go b/usrwg/channel_conn.go index dbf14b8..89581df 100644 --- a/usrwg/channel_conn.go +++ b/usrwg/channel_conn.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "sync" "time" ) @@ -17,6 +18,8 @@ type ChannelConn struct { // Packets written by the frontend outgoing chan []byte + doClose sync.Once + currentReadDeadline time.Time } @@ -24,9 +27,8 @@ const ChannelConnBufferSize = 16 func makeChannelConn() *ChannelConn { return &ChannelConn{ - incoming: make(chan []byte, ChannelConnBufferSize), - outgoing: make(chan []byte, ChannelConnBufferSize), - currentReadDeadline: time.Time{}, + incoming: make(chan []byte, ChannelConnBufferSize), + outgoing: make(chan []byte, ChannelConnBufferSize), } } @@ -64,15 +66,15 @@ func (cc *ChannelConn) Write(b []byte) (int, error) { return len(b), nil } -func (cc *ChannelConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { 
+func (cc *ChannelConn) WriteToUDPAddrPort(_ []byte, _ netip.AddrPort) (int, error) { return 0, net.ErrWriteToConnected } func (cc *ChannelConn) Close() error { - // TODO boolean to check if is already closed? - - close(cc.outgoing) - close(cc.incoming) + cc.doClose.Do(func() { + close(cc.outgoing) + close(cc.incoming) + }) return nil } @@ -92,7 +94,7 @@ func (cc *ChannelConn) tryGetOut() (pkt []byte) { // Reads a packet from the outgoing channel, and waits. // // Returns nil on timeout. -func (cc *ChannelConn) getOut(d time.Duration) (pkt []byte) { +func (cc *ChannelConn) getOut(d time.Duration) (pkt []byte) { // nolint:unused select { case pkt = <-cc.outgoing: case <-time.After(d): @@ -105,7 +107,6 @@ func (cc *ChannelConn) getOut(d time.Duration) (pkt []byte) { // // Will return false on timeout. func (cc *ChannelConn) putIn(pkt []byte, d time.Duration) (ok bool) { - select { case cc.incoming <- pkt: ok = true diff --git a/usrwg/endpoint.go b/usrwg/endpoint.go index bb90fcf..dce21d2 100644 --- a/usrwg/endpoint.go +++ b/usrwg/endpoint.go @@ -1,9 +1,10 @@ package usrwg import ( - "github.com/edup2p/common/types/key" "net/netip" "slices" + + "github.com/edup2p/common/types/key" ) type endpoint struct { diff --git a/usrwg/router/router_bsd.go b/usrwg/router/router_bsd.go index a1d520f..ba486b0 100644 --- a/usrwg/router/router_bsd.go +++ b/usrwg/router/router_bsd.go @@ -4,16 +4,16 @@ package router import ( "fmt" - "go4.org/netipx" - "golang.zx2c4.com/wireguard/tun" "log/slog" "net/netip" "runtime" + + "go4.org/netipx" + "golang.zx2c4.com/wireguard/tun" ) func NewRouter(device tun.Device) (Router, error) { name, err := device.Name() - if err != nil { return nil, err } @@ -105,13 +105,16 @@ func (r *bsdRouter) removeAddr(prefix netip.Prefix) error { func (r *bsdRouter) addRoute(prefix netip.Prefix) error { net := netipx.PrefixIPNet(prefix) - // TODO replace with .Masked()? + // TODO replace with (Prefix).Masked()? 
+ // need to figure out what the exact outputs are, and if .Masked does that nip := net.IP.Mask(net.Mask) nstr := fmt.Sprintf("%v/%d", nip, prefix.Bits()) - args := []string{"route", "-q", "-n", + args := []string{ + "route", "-q", "-n", "add", "-" + inet(prefix), nstr, - "-iface", r.tunName} + "-iface", r.tunName, + } if out, err := cmd(args...).CombinedOutput(); err != nil { return fmt.Errorf("route add failed: %v => %w\n%s", args, err, out) @@ -122,16 +125,19 @@ func (r *bsdRouter) addRoute(prefix netip.Prefix) error { func (r *bsdRouter) removeRoute(prefix netip.Prefix) error { net := netipx.PrefixIPNet(prefix) - // TODO replace with .Masked()? + // TODO replace with (Prefix).Masked()? + // need to figure out what the exact outputs are, and if .Masked does that nip := net.IP.Mask(net.Mask) nstr := fmt.Sprintf("%v/%d", nip, prefix.Bits()) del := "del" if runtime.GOOS == "darwin" { del = "delete" } - routedel := []string{"route", "-q", "-n", + routedel := []string{ + "route", "-q", "-n", del, "-" + inet(prefix), nstr, - "-iface", r.tunName} + "-iface", r.tunName, + } if out, err := cmd(routedel...).CombinedOutput(); err != nil { return fmt.Errorf("route del failed: %v: %w\n%s", routedel, err, out) diff --git a/usrwg/router/router_linux.go b/usrwg/router/router_linux.go index 263f8b8..71a2126 100644 --- a/usrwg/router/router_linux.go +++ b/usrwg/router/router_linux.go @@ -2,14 +2,14 @@ package router import ( "fmt" - "golang.zx2c4.com/wireguard/tun" "log/slog" "net/netip" + + "golang.zx2c4.com/wireguard/tun" ) func NewRouter(device tun.Device) (Router, error) { name, err := device.Name() - if err != nil { return nil, err } diff --git a/usrwg/router/router_windows.go b/usrwg/router/router_windows.go index 2f9da0a..cd411e8 100644 --- a/usrwg/router/router_windows.go +++ b/usrwg/router/router_windows.go @@ -6,26 +6,26 @@ package router import ( "errors" "fmt" - "golang.org/x/sys/windows/svc" + "log" "log/slog" + "net/netip" "os" "os/exec" "path/filepath" + "slices" + 
"sort" "sync" "syscall" + "time" "github.com/dblohm7/wingoes/com" "github.com/edup2p/common/usrwg/router/winnet" - ole "github.com/go-ole/go-ole" + "github.com/go-ole/go-ole" "go4.org/netipx" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "log" - "net/netip" - "slices" - "sort" - "time" ) func init() { @@ -42,7 +42,8 @@ func init() { func isWindowsService() bool { v, err := svc.IsWindowsService() if err != nil { - log.Fatalf("svc.IsWindowsService failed: %v", err) + // Expect that we can at least poke the local windows service, else we're in trouble. + panic(fmt.Errorf("svc.IsWindowsService failed: %w", err)) } return v } @@ -55,7 +56,6 @@ func NewRouter(device tun.Device) (Router, error) { luid := winipcfg.LUID(nativeTun.LUID()) guid, err := luid.GUID() - if err != nil { return nil, fmt.Errorf("failed to get tun GUID: %w", err) } @@ -133,10 +133,10 @@ func (r *windowsRouter) Set(cfg *Config) (retErr error) { for i := 0; i < tries; i++ { found, err := setPrivateNetwork(r.luid) if err != nil { - //networkCategoryWarning.Set(fmt.Errorf("set-network-category: %w", err)) + // networkCategoryWarning.Set(fmt.Errorf("set-network-category: %w", err)) log.Printf("setPrivateNetwork(try=%d): %v", i, err) } else { - //networkCategoryWarning.Set(nil) + // networkCategoryWarning.Set(nil) if found { if i > 0 { log.Printf("setPrivateNetwork(try=%d): success", i) @@ -329,7 +329,7 @@ func (r *windowsRouter) Set(cfg *Config) (retErr error) { ipif6.UseAutomaticMetric = false ipif6.Metric = 0 } - //if mtu > 0 { + // if mtu > 0 { ipif6.NLMTU = uint32(r.mtu) //} ipif6.DadTransmits = 0 @@ -571,6 +571,7 @@ func deltaNets(a, b []netip.Prefix) (add, del []netip.Prefix) { add = append(add, b[j]) j++ default: + // Literally unexpected, since we control the return of the function panic("unexpected compare result") } } @@ -705,6 +706,7 @@ func deltaRouteData(a, b []*routeData) (add, del 
[]*routeData) { add = append(add, b[j]) j++ default: + // Literally unexpected, since we control the return of the function panic("unexpected compare result") } } @@ -878,7 +880,7 @@ func (ft *firewallTweaker) doAsyncSet() { ft.mu.Lock() ft.lastLocal = val - ft.known = (err == nil) + ft.known = err == nil } } diff --git a/usrwg/router/util.go b/usrwg/router/util.go index 6adfc7a..50ebeff 100644 --- a/usrwg/router/util.go +++ b/usrwg/router/util.go @@ -1,15 +1,15 @@ package router import ( - "log" + "fmt" "net/netip" "os/exec" ) -func prefixesToAdd(new, curr []netip.Prefix) (add []netip.Prefix) { - for _, cur := range new { +func prefixesToAdd(newP, currP []netip.Prefix) (add []netip.Prefix) { + for _, cur := range newP { found := false - for _, v := range curr { + for _, v := range currP { found = v == cur if found { break @@ -22,10 +22,10 @@ func prefixesToAdd(new, curr []netip.Prefix) (add []netip.Prefix) { return } -func prefixesToRemove(new, curr []netip.Prefix) (remove []netip.Prefix) { - for _, cur := range curr { +func prefixesToRemove(newP, currP []netip.Prefix) (remove []netip.Prefix) { + for _, cur := range currP { found := false - for _, v := range new { + for _, v := range newP { found = v == cur if found { break @@ -38,6 +38,8 @@ func prefixesToRemove(new, curr []netip.Prefix) (remove []netip.Prefix) { return } +// nolint:unused +// used in router_bsd, golangci-lint on linux trips over it func inet(p netip.Prefix) string { if p.Addr().Is6() { return "inet6" @@ -47,11 +49,14 @@ func inet(p netip.Prefix) string { func cmd(args ...string) *exec.Cmd { if len(args) == 0 { - log.Fatalf("exec.Cmd(%#v) invalid; need argv[0]", args) + // We control this input, and without argv[0] we can't do anything anyways. + panic(fmt.Errorf("exec.Cmd(%#v) invalid; need at least 1 argument", args)) } return exec.Command(args[0], args[1:]...) 
} +// nolint:unused +// used in router_bsd, golangci-lint on linux trips over it func prefixToSingle(prefix netip.Prefix) netip.Prefix { return netip.PrefixFrom(prefix.Addr(), prefix.Addr().BitLen()) } diff --git a/usrwg/router/winnet/winnet.go b/usrwg/router/winnet/winnet.go index 87ab967..fdfc2a2 100644 --- a/usrwg/router/winnet/winnet.go +++ b/usrwg/router/winnet/winnet.go @@ -16,8 +16,10 @@ import ( const CLSID_NetworkListManager = "{DCB00C01-570F-4A9B-8D69-199FDBA5723B}" -var IID_INetwork = ole.NewGUID("{8A40A45D-055C-4B62-ABD7-6D613E2CEAEC}") -var IID_INetworkConnection = ole.NewGUID("{DCB00005-570F-4A9B-8D69-199FDBA5723B}") +var ( + IID_INetwork = ole.NewGUID("{8A40A45D-055C-4B62-ABD7-6D613E2CEAEC}") + IID_INetworkConnection = ole.NewGUID("{DCB00005-570F-4A9B-8D69-199FDBA5723B}") +) type NetworkListManager struct { d *ole.Dispatch @@ -123,7 +125,6 @@ func (m *NetworkListManager) GetNetworkConnections() (ConnectionList, error) { cl = append(cl, nco) return nil }) - if err != nil { cl.Release() return nil, err diff --git a/usrwg/tun_windows.go b/usrwg/tun_windows.go index d1c37ad..1d09998 100644 --- a/usrwg/tun_windows.go +++ b/usrwg/tun_windows.go @@ -13,6 +13,7 @@ func init() { tun.WintunTunnelType = "ToverSok" guid, err := windows.GUIDFromString("{37217669-42da-4657-a55b-13375d328250}") if err != nil { + // We can create a GUID from a static string without error panic(err) } tun.WintunStaticRequestedGUID = &guid diff --git a/usrwg/wgusp.go b/usrwg/wgusp.go index 9f97916..d062e35 100644 --- a/usrwg/wgusp.go +++ b/usrwg/wgusp.go @@ -2,15 +2,20 @@ package usrwg import ( "fmt" + "log/slog" + "net" + "net/netip" + "slices" + "syscall" + "github.com/edup2p/common/toversok" "github.com/edup2p/common/types" "github.com/edup2p/common/types/key" "github.com/edup2p/common/usrwg/router" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "golang.zx2c4.com/wireguard/device" - "log/slog" - "net/netip" - "runtime" - "strings" + 
"golang.zx2c4.com/wireguard/tun" ) func init() { @@ -33,8 +38,7 @@ func (u *UserSpaceWireGuardHost) Reset() error { return nil } -const WGGOIPCDevSetup = `private_key=%s -` +const WGGOIPCDevSetup = "private_key=%s\n" func (u *UserSpaceWireGuardHost) Controller(privateKey key.NodePrivate, addr4, addr6 netip.Prefix) (toversok.WireGuardController, error) { if u.running != nil { @@ -46,13 +50,11 @@ func (u *UserSpaceWireGuardHost) Controller(privateKey key.NodePrivate, addr4, a // TODO set this to 1392 per https://docs.eduvpn.org/server/v3/wireguard.html // and make adjustable by environment variable tunDev, err := createTUN(1280) - if err != nil { return nil, fmt.Errorf("failed to create TUN device: %w", err) } r, err := router.NewRouter(tunDev) - if err != nil { return nil, fmt.Errorf("failed to create router: %w", err) } @@ -78,7 +80,9 @@ func (u *UserSpaceWireGuardHost) Controller(privateKey key.NodePrivate, addr4, a nKey := key.UnveilPrivate(privateKey) - wgDev.IpcSet(fmt.Sprintf(WGGOIPCDevSetup, nKey.HexString())) + if err := wgDev.IpcSet(fmt.Sprintf(WGGOIPCDevSetup, nKey.HexString())); err != nil { + return nil, fmt.Errorf("failed to set private key on wireguard device: %w", err) + } if err := wgDev.Up(); err != nil { return nil, fmt.Errorf("failed to bring up wireguard device: %w", err) @@ -98,6 +102,7 @@ func (u *UserSpaceWireGuardHost) Controller(privateKey key.NodePrivate, addr4, a usrwgc := &UserSpaceWireGuardController{ wgDev: wgDev, bind: bind, + tunDev: tunDev, router: r, } @@ -106,51 +111,56 @@ func (u *UserSpaceWireGuardHost) Controller(privateKey key.NodePrivate, addr4, a return usrwgc, nil } -func (u *UserSpaceWireGuardHost) tempPrintInstructions(addr4, addr6 netip.Prefix, name string) { - - const sep = "; " - - switch runtime.GOOS { - case "darwin": - const ( - ifconfig4 = "sudo ifconfig %s inet %s/32 %s" - ifconfig6 = "sudo ifconfig %s inet6 %s %s prefixlen 128" - - route4 = "sudo route add -inet %s -iface %s" - route6 = "sudo route add -inet6 %s 
-iface %s" - ) - - slog.Warn("Please run these lines in a separate terminal:") - slog.Warn( - strings.Join([]string{ - fmt.Sprintf(ifconfig4, name, addr4.Addr().String(), addr4.Addr().String()), - fmt.Sprintf(ifconfig6, name, addr6.Addr().String(), addr6.Addr().String()), - fmt.Sprintf(route4, addr4.String(), name), - fmt.Sprintf(route6, addr6.String(), name), - }, sep), - ) - case "linux": - const ( - ip = "sudo ip address add %s dev %s" - ) - - slog.Warn("Please run these lines in a separate terminal:") - slog.Warn( - strings.Join([]string{ - fmt.Sprintf(ip, addr4.String(), name), - fmt.Sprintf(ip, addr6.String(), name), - }, sep), - ) - } - -} - type UserSpaceWireGuardController struct { wgDev *device.Device bind *ToverSokBind + tunDev tun.Device router router.Router } +func (u *UserSpaceWireGuardController) Available() bool { + return true +} + +func (u *UserSpaceWireGuardController) InjectPacket(from, to netip.AddrPort, pkt []byte) error { + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + ipv4 := &layers.IPv4{ + Version: 0x4, + TTL: 255, + Protocol: syscall.IPPROTO_UDP, + DstIP: to.Addr().AsSlice(), + SrcIP: from.Addr().AsSlice(), + } + udp := &layers.UDP{ + DstPort: layers.UDPPort(to.Port()), + SrcPort: layers.UDPPort(from.Port()), + } + if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil { + return fmt.Errorf("failed to set udp checksum: %w", err) + } + + err := gopacket.SerializeLayers(buf, opts, + ipv4, + udp, + gopacket.Payload(pkt), + ) + if err != nil { + return fmt.Errorf("failed to serialize packet: %w", err) + } + + packetData := slices.Concat(make([]byte, 16), buf.Bytes()) + + if _, err = u.tunDev.Write([][]byte{packetData}, 16); err != nil { + return fmt.Errorf("failed to inject packet: %w", err) + } + + return nil +} + const WGGOIPCAddPeer = `public_key=%s replace_allowed_ips=true allowed_ip=%s/32 @@ -165,7 +175,6 @@ func (u *UserSpaceWireGuardController) 
UpdatePeer(publicKey key.NodePublic, cfg publicKey.HexString(), cfg.VIPs.IPv4.String(), cfg.VIPs.IPv6.String(), publicKey.Marshal(), ), ) - if err != nil { err = fmt.Errorf("failed to do IPC set: %w", err) } @@ -181,9 +190,8 @@ func (u *UserSpaceWireGuardController) RemovePeer(publicKey key.NodePublic) erro return nil } -func (u *UserSpaceWireGuardController) GetStats(publicKey key.NodePublic) (*toversok.WGStats, error) { - //TODO implement me - //panic("implement me") +func (u *UserSpaceWireGuardController) GetStats(_ key.NodePublic) (*toversok.WGStats, error) { + // TODO implement me return nil, nil } @@ -192,11 +200,26 @@ func (u *UserSpaceWireGuardController) ConnFor(node key.NodePublic) types.UDPCon return u.bind.GetConn(node) } +func (u *UserSpaceWireGuardController) GetInterface() *net.Interface { + name, err := u.tunDev.Name() + if err != nil { + slog.Warn("failed to get tun device name", "err", err) + return nil + } + i, err := net.InterfaceByName(name) + if err != nil { + slog.Warn("failed to get interface", "name", name, "err", err) + return nil + } + return i +} + func (u *UserSpaceWireGuardController) Close() { + if err := u.bind.Cancel(); err != nil { + slog.Error("Failed to close wireguard bind", "err", err) + } + if err := u.router.Close(); err != nil { + slog.Error("Failed to close router", "err", err) + } u.wgDev.Close() - // TODO return or log error - u.bind.Close() - u.router.Close() } - -//const _ toversok.WireGuardHost = UserspaceWireguardHost{}