From 419aa07b516e7ea716a2b73e384292e9db6fc38e Mon Sep 17 00:00:00 2001 From: Dragan Milic Date: Fri, 13 Sep 2024 09:54:29 +0200 Subject: [PATCH] gvisor test --- .vscode/settings.json | 3 + connect/.vscode/settings.json | 6 + go.mod | 25 +- go.sum | 58 +- ip.go | 164 ++- netstack/egress/conn_id.go | 17 + netstack/egress/egress.go | 165 +++ netstack/tun.go | 1127 +++++++++++++++++++++ packet_transformer.go | 158 +++ packet_transformer_test.go | 55 + pathsource/path_to_source_address.go | 38 + pathsource/path_to_source_address_test.go | 25 + souce_mapper.go | 111 ++ 13 files changed, 1884 insertions(+), 68 deletions(-) create mode 100644 connect/.vscode/settings.json create mode 100644 netstack/egress/conn_id.go create mode 100644 netstack/egress/egress.go create mode 100644 netstack/tun.go create mode 100644 packet_transformer.go create mode 100644 packet_transformer_test.go create mode 100644 pathsource/path_to_source_address.go create mode 100644 pathsource/path_to_source_address_test.go create mode 100644 souce_mapper.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 5e337ed..c2fba27 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,4 +7,7 @@ "-count=1", "-test.short" ], + "cSpell.words": [ + "xxhash" + ] } \ No newline at end of file diff --git a/connect/.vscode/settings.json b/connect/.vscode/settings.json new file mode 100644 index 0000000..2fac987 --- /dev/null +++ b/connect/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "cSpell.words": [ + "pathsource", + "ttlcache" + ] +} diff --git a/go.mod b/go.mod index 9ddbdc5..3f951c8 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/urnetwork/connect -go 1.23.0 +go 1.23.1 require ( github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 @@ -16,15 +16,20 @@ require ( github.com/urnetwork/protocol v0.0.0 github.com/urnetwork/userwireguard v0.0.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 - golang.org/x/term v0.20.0 + golang.org/x/term v0.25.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 google.golang.org/protobuf v1.34.2 src.agwa.name/tlshacks v0.0.0-20231008131857-90d701ba3225 ) require ( - golang.org/x/crypto v0.23.0 - golang.org/x/net v0.25.0 + github.com/cespare/xxhash/v2 v2.2.0 + github.com/jellydator/ttlcache/v3 v3.3.0 + github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.28.0 + golang.org/x/net v0.30.0 + golang.org/x/sync v0.8.0 + gvisor.dev/gvisor v0.0.0-20241127223613-65a7fdf8cf17 ) require ( @@ -32,8 +37,8 @@ require ( github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect go.uber.org/mock v0.4.0 // indirect - golang.org/x/mod v0.17.0 // indirect - golang.org/x/tools v0.21.0 // indirect + golang.org/x/mod v0.21.0 // indirect + golang.org/x/tools v0.26.0 // indirect ) require ( @@ -41,12 +46,14 @@ require ( github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/google/btree v1.1.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect @@ -54,11 +61,13 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/sys v0.20.0 // indirect - golang.org/x/text v0.15.0 // indirect + golang.org/x/sys v0.26.0 // indirect + golang.org/x/text v0.19.0 // indirect + golang.org/x/time v0.7.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bc6cda3..6ec6ea8 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -20,8 +22,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= -github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= -github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= +github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -38,8 +40,10 @@ github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1 github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4= github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -50,6 +54,8 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLe github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jellydator/ttlcache/v3 v3.3.0 h1:BdoC9cE81qXfrxeb9eoJi9dWrdhSuwXMAnHTbnBm4Wc= +github.com/jellydator/ttlcache/v3 v3.3.0/go.mod h1:bj2/e0l4jRnQdrnSTaGTsh4GSXvMjQcy41i7th0GVGw= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -97,10 +103,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -github.com/urnetwork/protocol v0.1.0 h1:XgBEVJ+8K24jVJH37CwkBi/szzh54PWBQaODElto5KQ= -github.com/urnetwork/protocol v0.1.0/go.mod h1:+JJm4mqeK95mTRMAGMeddzdUpPEnZmp2U8RYPjXb/JQ= -github.com/urnetwork/userwireguard v0.0.1 h1:6tG7Oas3Ca1UPjenVmcuYmX87uRIyzc/3hz640wFwsM= -github.com/urnetwork/userwireguard v0.0.1/go.mod h1:4o4/Mpn+ipHHx998wXkRt4yRXZh1bMMl2Ybivh71gLo= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= @@ -108,38 +112,38 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= -golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= +golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= @@ -150,6 +154,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20241127223613-65a7fdf8cf17 h1:fxrDpTFRhEIuohNVQw6ZbYl794xzsXfPEsKCDYrZq1k= +gvisor.dev/gvisor v0.0.0-20241127223613-65a7fdf8cf17/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= src.agwa.name/tlshacks v0.0.0-20231008131857-90d701ba3225 h1:KvJgNzDBgG6IawXLCenHhjvU7RXQ5UD1a18Nm2ZMyGg= diff --git a/ip.go b/ip.go index ee07f1f..a123cd9 100644 --- a/ip.go +++ b/ip.go @@ -9,6 +9,7 @@ import ( "math" mathrand "math/rand" "net" + "net/netip" "slices" "strconv" "strings" @@ -24,6 +25,8 @@ import ( "github.com/golang/glog" + "github.com/urnetwork/connect/netstack" + "github.com/urnetwork/connect/netstack/egress" "github.com/urnetwork/protocol" ) @@ -206,13 +209,101 @@ func (self *LocalUserNat) receive(source TransferPath, ipProtocol IpProtocol, pa } } +// comparable +type SourceID struct { + ip netip.Addr + port uint16 +} + +type EndpointAddress struct { + realSource netip.AddrPort + transferPath TransferPath +} + func (self *LocalUserNat) Run() { defer self.cancel() udp4Buffer := NewUdp4Buffer(self.ctx, self.receive, self.settings.UdpBufferSettings) udp6Buffer := NewUdp6Buffer(self.ctx, self.receive, self.settings.UdpBufferSettings) - tcp4Buffer := NewTcp4Buffer(self.ctx, self.receive, self.settings.TcpBufferSettings) - tcp6Buffer := NewTcp6Buffer(self.ctx, self.receive, self.settings.TcpBufferSettings) + // tcp4Buffer := NewTcp4Buffer(self.ctx, self.receive, self.settings.TcpBufferSettings) + // tcp6Buffer := NewTcp6Buffer(self.ctx, self.receive, self.settings.TcpBufferSettings) + + sourceMap := map[SourceID]EndpointAddress{} + sourceMapLock := sync.Mutex{} + + packetTransformer := NewPacketTransformer(self.ctx) + + dev, tnet, err := netstack.CreateNetTUN(nil, nil, 1500) + if err != nil { + glog.Infof("[lnr]error = %s\n", err) + return + } + + eg := egress.NewEgress(dev, tnet) + + go func() { + buffer := make([]byte, 2000) + for self.ctx.Err() == nil { + n, err := eg.Read(buffer) + if err != nil { + glog.Infof("[lnr]read error = %s\n", err) + return + } + + if n == 0 { + return + } + + outPacket := buffer[0:n] + + ipVersion := uint8(buffer[0]) >> 4 + + tp := TransferPath{} + + switch ipVersion { + case 4: + ipv4 := layers.IPv4{} + ipv4.DecodeFromBytes(buffer[0:n], gopacket.NilDecodeFeedback) + switch ipv4.Protocol { + case layers.IPProtocolTCP: + + pkt, pth, err := packetTransformer.RewritePacketToVPN(buffer[0:n]) + if err != nil { + glog.Infof("[lnr]rewrite error = %s\n", err) + continue + } + + tp = *pth + outPacket = pkt + + } + + case 6: + ipv6 := layers.IPv6{} + ipv6.DecodeFromBytes(buffer[0:n], gopacket.NilDecodeFeedback) + switch ipv6.NextHeader { + case layers.IPProtocolTCP: + tcp := layers.TCP{} + tcp.DecodeFromBytes(ipv6.Payload, gopacket.NilDecodeFeedback) + + sourceId := SourceID{ + ip: netip.AddrFrom16([16]byte(ipv6.DstIP)), + port: uint16(tcp.DstPort), + } + + sourceMapLock.Lock() + epa := sourceMap[sourceId] + sourceMapLock.Unlock() + + tp = epa.transferPath + + } + + } + + self.receive(tp, IpProtocolTcp, outPacket) + } + }() for { select { @@ -249,26 +340,33 @@ func (self *LocalUserNat) Run() { c() } case layers.IPProtocolTCP: + tcp := layers.TCP{} tcp.DecodeFromBytes(ipv4.Payload, gopacket.NilDecodeFeedback) - c := func() bool { - success, err := tcp4Buffer.send( - sendPacket.source, - sendPacket.provideMode, - &ipv4, - &tcp, - self.settings.BufferTimeout, - ) - return success && err == nil + sourceId := SourceID{ + ip: netip.AddrFrom4([4]byte(ipv4.SrcIP)), + port: uint16(tcp.SrcPort), } - if glog.V(2) { - TraceWithReturn( - fmt.Sprintf("[lnr]send tcp4 %s<-%s s(%s)", self.clientTag, sendPacket.source.SourceId, sendPacket.source.StreamId), - c, - ) - } else { - c() + + sourceMapLock.Lock() + sourceMap[sourceId] = EndpointAddress{ + realSource: netip.AddrPortFrom(netip.AddrFrom4([4]byte(ipv4.SrcIP)), uint16(tcp.SrcPort)), + transferPath: sendPacket.source, + } + sourceMapLock.Unlock() + + //TODO: rewrite the source address + + rewritten, err := packetTransformer.RewritePacketFromVPN(ipPacket, sendPacket.source) + if err != nil { + glog.Infof("[lnr]rewrite error = %s\n", err) + continue + } + + _, err = eg.Write(rewritten) + if err != nil { + glog.Infof("[lnr]write error = %s\n", err) } default: // no support for this protocol, drop @@ -303,23 +401,21 @@ func (self *LocalUserNat) Run() { tcp := layers.TCP{} tcp.DecodeFromBytes(ipv6.Payload, gopacket.NilDecodeFeedback) - c := func() bool { - success, err := tcp6Buffer.send( - sendPacket.source, - sendPacket.provideMode, - &ipv6, - &tcp, - self.settings.BufferTimeout, - ) - return success && err == nil + sourceId := SourceID{ + ip: netip.AddrFrom16([16]byte(ipv6.SrcIP)), + port: uint16(tcp.SrcPort), } - if glog.V(2) { - TraceWithReturn( - fmt.Sprintf("[lnr]send tcp6 %s<-%s s(%s)", self.clientTag, sendPacket.source.SourceId, sendPacket.source.StreamId), - c, - ) - } else { - c() + + sourceMapLock.Lock() + sourceMap[sourceId] = EndpointAddress{ + realSource: netip.AddrPortFrom(netip.AddrFrom16([16]byte(ipv6.SrcIP)), uint16(tcp.SrcPort)), + transferPath: sendPacket.source, + } + sourceMapLock.Unlock() + + _, err = dev.Write(ipPacket) + if err != nil { + glog.Infof("[lnr]write error = %s\n", err) } default: // no support for this protocol, drop diff --git a/netstack/egress/conn_id.go b/netstack/egress/conn_id.go new file mode 100644 index 0000000..6676112 --- /dev/null +++ b/netstack/egress/conn_id.go @@ -0,0 +1,17 @@ +package egress + +import ( + "fmt" + "net/netip" +) + +type ConnID struct { + SourceIP netip.Addr + SourcePort uint16 + DestIP netip.Addr + DestPort uint16 +} + +func (c ConnID) String() string { + return fmt.Sprintf("%s:%d -> %s:%d", c.SourceIP, c.SourcePort, c.DestIP, c.DestPort) +} diff --git a/netstack/egress/egress.go b/netstack/egress/egress.go new file mode 100644 index 0000000..33c1c06 --- /dev/null +++ b/netstack/egress/egress.go @@ -0,0 +1,165 @@ +package egress + +import ( + "fmt" + "io" + "net" + "net/netip" + "sync" + + "github.com/golang/glog" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/urnetwork/connect/netstack" + "golang.org/x/sync/errgroup" +) + +type Egress struct { + dev netstack.Device + net *netstack.Net + runningListeners map[uint16]func() + mu sync.Mutex +} + +func NewEgress(dev netstack.Device, net *netstack.Net) *Egress { + e := &Egress{ + dev: dev, + net: net, + runningListeners: make(map[uint16]func()), + } + return e +} + +func (e *Egress) Write(pkt []byte) (int, error) { + + cid, ok := syncPacketData(pkt) + if ok { + fmt.Println("SYN packet detected", cid) + } + + if ok { + e.mu.Lock() + + registerFunc, ok := e.runningListeners[cid.DestPort] + if !ok { + registerFunc = sync.OnceFunc(func() { + fmt.Println("registering listener", cid.DestPort) + defer func() { + fmt.Println("listener registered", cid.DestPort) + }() + + list, err := e.net.ListenTCP(&net.TCPAddr{ + IP: net.IPv4zero, + Port: int(cid.DestPort), + }) + if err != nil { + glog.Error("failed to listen", err) + } + + go func() { + for { + c, err := list.Accept() + if err != nil { + glog.Error("failed to accept", err) + return + } + go func() { + defer c.Close() + local := c.LocalAddr() + + addr := local.(*net.TCPAddr) + oc, err := net.DialTCP("tcp", nil, addr) + if err != nil { + glog.Error("failed to dial", err) + return + } + defer oc.Close() + + eg := errgroup.Group{} + eg.Go(func() error { + defer c.Close() + defer oc.Close() + _, err := io.Copy(oc, c) + return err + }) + + eg.Go(func() error { + defer c.Close() + defer oc.Close() + _, err := io.Copy(c, oc) + return err + }) + + err = eg.Wait() + if err != nil { + glog.Error("failed to copy", err) + } + + }() + } + }() + }) + e.runningListeners[cid.DestPort] = registerFunc + } + + e.mu.Unlock() + + registerFunc() + } + + return e.dev.Write(pkt) + +} + +func (e *Egress) Read(pkt []byte) (int, error) { + return e.dev.Read(pkt) +} + +func syncPacketData(packet []byte) (ConnID, bool) { + switch packet[0] >> 4 { + case 4: + pk := gopacket.NewPacket(packet, layers.LayerTypeIPv4, gopacket.NoCopy) + v4Layer, _ := pk.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if v4Layer == nil { + return ConnID{}, false + } + tcpLayer := pk.Layer(layers.LayerTypeTCP).(*layers.TCP) + if tcpLayer == nil { + return ConnID{}, false + } + + if tcpLayer.SYN && !tcpLayer.ACK { + return ConnID{ + SourceIP: netip.AddrFrom4([4]byte(v4Layer.SrcIP)), + SourcePort: uint16(tcpLayer.SrcPort), + DestIP: netip.AddrFrom4([4]byte(v4Layer.DstIP)), + DestPort: uint16(tcpLayer.DstPort), + }, true + + } + + case 6: + pk := gopacket.NewPacket(packet, layers.LayerTypeIPv6, gopacket.NoCopy) + v6Layer, _ := pk.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + if v6Layer == nil { + return ConnID{}, false + } + tcpLayer := pk.Layer(layers.LayerTypeTCP).(*layers.TCP) + if tcpLayer == nil { + return ConnID{}, false + } + + if tcpLayer.SYN && !tcpLayer.ACK { + return ConnID{ + SourceIP: netip.AddrFrom16([16]byte(v6Layer.SrcIP)), + SourcePort: uint16(tcpLayer.SrcPort), + DestIP: netip.AddrFrom16([16]byte(v6Layer.DstIP)), + DestPort: uint16(tcpLayer.DstPort), + }, true + } + + } + + return ConnID{}, false + +} diff --git a/netstack/tun.go b/netstack/tun.go new file mode 100644 index 0000000..29ce202 --- /dev/null +++ b/netstack/tun.go @@ -0,0 +1,1127 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package netstack + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "os" + "regexp" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "github.com/golang/glog" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + + "golang.org/x/net/dns/dnsmessage" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +type netTun struct { + ep *channel.Endpoint + stack *stack.Stack + incomingPacket chan *buffer.View + mtu int + dnsServers []netip.Addr + hasV4, hasV6 bool + mu sync.Mutex + registeredAddresses map[netip.Addr]bool +} + +type Net netTun + +type Device interface { + Read(buf []byte) (n int, err error) + + Write(buf []byte) (int, error) + + Close() error +} + +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (Device, *Net, error) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: true, + } + dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(opts), + incomingPacket: make(chan *buffer.View), + dnsServers: dnsServers, + mtu: mtu, + registeredAddresses: make(map[netip.Addr]bool), + } + dev.ep.AddNotify(dev) + tcpipErr := dev.stack.CreateNIC(1, dev.ep) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + + // Set the TCP receive and send buffer sizes to 4MB. + err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 4 << 20, + Max: 4 << 20, + Default: 4 << 20, + }) + + if err != nil { + return nil, nil, fmt.Errorf("TCPReceiveBufferSizeRangeOption failed: %v", err) + } + + err = dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{ + Min: 4 << 20, + Max: 4 << 20, + Default: 4 << 20, + }) + + if err != nil { + return nil, nil, fmt.Errorf("TCPSendBufferSizeRangeOption failed: %v", err) + } + + delayEnabled := tcpip.TCPDelayEnabled(true) + + err = dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &delayEnabled) + if err != nil { + return nil, nil, fmt.Errorf("TCPDelayEnabled failed: %v", err) + } + + for _, ip := range localAddresses { + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + dev.registeredAddresses[ip] = true + } + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) + + return dev, (*Net)(dev), nil +} + +func (tun *netTun) Read(buf []byte) (int, error) { + view, ok := <-tun.incomingPacket + if !ok { + return 0, os.ErrClosed + } + + n, err := view.Read(buf) + if err != nil { + return 0, err + } + return n, nil +} + +func (tun *netTun) Write(buf []byte) (int, error) { + packet := buf + if len(packet) == 0 { + return 0, nil + } + + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + + switch packet[0] >> 4 { + case 4: + packet := gopacket.NewPacket(packet, layers.LayerTypeIPv4, gopacket.NoCopy) + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + if ipLayer != nil { + v4, _ := ipLayer.(*layers.IPv4) + + addr := netip.AddrFrom4([4]byte(v4.DstIP)) + + tun.mu.Lock() + registered := tun.registeredAddresses[addr] + tun.mu.Unlock() + + if !registered { + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddrFromSlice(v4.DstIP).WithPrefix(), + } + + tcpipErr := tun.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return 0, fmt.Errorf("AddProtocolAddress(%v): %v", v4.DstIP, tcpipErr) + } + + tun.mu.Lock() + tun.registeredAddresses[addr] = true + tun.mu.Unlock() + + glog.Info("Added protocol address: ", addr) + } + + } + + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + packet := gopacket.NewPacket(packet, layers.LayerTypeIPv6, gopacket.NoCopy) + + ipLayer := packet.Layer(layers.LayerTypeIPv6) + if ipLayer != nil { + v6, _ := ipLayer.(*layers.IPv6) + + addr := netip.AddrFrom16([16]byte(v6.DstIP)) + + tun.mu.Lock() + registered := tun.registeredAddresses[addr] + tun.mu.Unlock() + + if !registered { + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddrFromSlice(v6.DstIP).WithPrefix(), + } + + tcpipErr := tun.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return 0, fmt.Errorf("AddProtocolAddress(%v): %v", v6.DstIP, tcpipErr) + } + + tun.mu.Lock() + tun.registeredAddresses[addr] = true + tun.mu.Unlock() + } + + } + + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } + return len(buf), nil +} + +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view +} + +func (tun *netTun) Close() error { + tun.stack.RemoveNIC(1) + + tun.ep.Close() + + tun.stack.Close() + + if tun.incomingPacket != nil { + close(tun.incomingPacket) + } + + return nil +} + +func (tun *netTun) MTU() (int, error) { + return tun.mtu, nil +} + +func (tun *netTun) BatchSize() int { + return 1 +} + +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) +} + +func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) +} + +func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) +} + +func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { + if addr == nil { + return net.ListenTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { + var lfa, rfa *tcpip.FullAddress + var pn tcpip.NetworkProtocolNumber + if laddr.IsValid() || laddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(laddr) + lfa = &addr + } + if raddr.IsValid() || raddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(raddr) + rfa = &addr + } + return gonet.DialUDP(net.stack, lfa, rfa, pn) +} + +func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { + return net.DialUDPAddrPort(laddr, netip.AddrPort{}) +} + +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + ip, _ := netip.AddrFromSlice(laddr.IP) + la = netip.AddrPortFrom(ip, uint16(laddr.Port)) + } + if raddr != nil { + ip, _ := netip.AddrFromSlice(raddr.IP) + ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + +func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { + return net.DialUDP(laddr, nil) +} + +type PingConn struct { + laddr PingAddr + raddr PingAddr + wq waiter.Queue + ep tcpip.Endpoint + deadline *time.Timer +} + +type PingAddr struct{ addr netip.Addr } + +func (ia PingAddr) String() string { + return ia.addr.String() +} + +func (ia PingAddr) Network() string { + if ia.addr.Is4() { + return "ping4" + } else if ia.addr.Is6() { + return "ping6" + } + return "ping" +} + +func (ia PingAddr) Addr() netip.Addr { + return ia.addr +} + +func PingAddrFromAddr(addr netip.Addr) *PingAddr { + return &PingAddr{addr} +} + +func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { + if !laddr.IsValid() && !raddr.IsValid() { + return nil, errors.New("ping dial: invalid address") + } + v6 := laddr.Is6() || raddr.Is6() + bind := laddr.IsValid() + if !bind { + if v6 { + laddr = netip.IPv6Unspecified() + } else { + laddr = netip.IPv4Unspecified() + } + } + + tn := icmp.ProtocolNumber4 + pn := ipv4.ProtocolNumber + if v6 { + tn = icmp.ProtocolNumber6 + pn = ipv6.ProtocolNumber + } + + pc := &PingConn{ + laddr: PingAddr{laddr}, + deadline: time.NewTimer(time.Hour << 10), + } + pc.deadline.Stop() + + ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) + if tcpipErr != nil { + return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) + } + pc.ep = ep + + if bind { + fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) + if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping bind: %s", tcpipErr) + } + } + + if raddr.IsValid() { + pc.raddr = PingAddr{raddr} + fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) + if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping connect: %s", tcpipErr) + } + } + + return pc, nil +} + +func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { + return net.DialPingAddr(laddr, netip.Addr{}) +} + +func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { + var la, ra netip.Addr + if laddr != nil { + la = laddr.addr + } + if raddr != nil { + ra = raddr.addr + } + return net.DialPingAddr(la, ra) +} + +func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { + var la netip.Addr + if laddr != nil { + la = laddr.addr + } + return net.ListenPingAddr(la) +} + +func (pc *PingConn) LocalAddr() net.Addr { + return pc.laddr +} + +func (pc *PingConn) RemoteAddr() net.Addr { + return pc.raddr +} + +func (pc *PingConn) Close() error { + pc.deadline.Reset(0) + pc.ep.Close() + return nil +} + +func (pc *PingConn) SetWriteDeadline(t time.Time) error { + return errors.New("not implemented") +} + +func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + var na netip.Addr + switch v := addr.(type) { + case *PingAddr: + na = v.addr + case *net.IPAddr: + na, _ = netip.AddrFromSlice(v.IP) + default: + return 0, fmt.Errorf("ping write: wrong net.Addr type") + } + if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { + return 0, fmt.Errorf("ping write: mismatched protocols") + } + + buf := bytes.NewReader(p) + rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) + // won't block, no deadlines + n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ + To: &rfa, + }) + if tcpipErr != nil { + return int(n64), fmt.Errorf("ping write: %s", tcpipErr) + } + + return int(n64), nil +} + +func (pc *PingConn) Write(p []byte) (n int, err error) { + return pc.WriteTo(p, &pc.raddr) +} + +func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) + pc.wq.EventRegister(&e) + defer pc.wq.EventUnregister(&e) + + select { + case <-pc.deadline.C: + return 0, nil, os.ErrDeadlineExceeded + case <-notifyCh: + } + + w := tcpip.SliceWriter(p) + + res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if tcpipErr != nil { + return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) + } + + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) + return res.Count, &PingAddr{remoteAddr}, nil +} + +func (pc *PingConn) Read(p []byte) (n int, err error) { + n, _, err = pc.ReadFrom(p) + return +} + +func (pc *PingConn) SetDeadline(t time.Time) error { + // pc.SetWriteDeadline is unimplemented + + return pc.SetReadDeadline(t) +} + +func (pc *PingConn) SetReadDeadline(t time.Time) error { + pc.deadline.Reset(time.Until(t)) + return nil +} + +var ( + errNoSuchHost = errors.New("no such host") + errLameReferral = errors.New("lame referral") + errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") + errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") + errServerMisbehaving = errors.New("server misbehaving") + errInvalidDNSResponse = errors.New("invalid DNS response") + errNoAnswerFromDNSServer = errors.New("no answer from DNS server") + errServerTemporarilyMisbehaving = errors.New("server misbehaving") + errCanceled = errors.New("operation was canceled") + errTimeout = errors.New("i/o timeout") + errNumericPort = errors.New("port must be numeric") + errNoSuitableAddress = errors.New("no suitable address found") + errMissingAddress = errors.New("missing address") +) + +func (net *Net) LookupHost(host string) (addrs []string, err error) { + return net.LookupContextHost(context.Background(), host) +} + +func isDomainName(s string) bool { + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + last := byte('.') + nonNumeric := false + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + partlen++ + case c == '-': + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + return nonNumeric +} + +func randU16() uint16 { + var b [2]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return binary.LittleEndian.Uint16(b[:]) +} + +func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { + id = randU16() + b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + return 0, nil, nil, err + } + if err := b.Question(q); err != nil { + return 0, nil, nil, err + } + tcpReq, err = b.Finish() + udpReq = tcpReq[2:] + l := len(tcpReq) - 2 + tcpReq[0] = byte(l >> 8) + tcpReq[1] = byte(l) + return id, udpReq, tcpReq, err +} + +func equalASCIIName(x, y dnsmessage.Name) bool { + if x.Length != y.Length { + return false + } + for i := 0; i < int(x.Length); i++ { + a := x.Data[i] + b := y.Data[i] + if 'A' <= a && a <= 'Z' { + a += 0x20 + } + if 'A' <= b && b <= 'Z' { + b += 0x20 + } + if a != b { + return false + } + } + return true +} + +func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { + if !respHdr.Response { + return false + } + if reqID != respHdr.ID { + return false + } + if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { + return false + } + return true +} + +func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 512) + for { + n, err := c.Read(b) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + continue + } + q, err := p.Question() + if err != nil || !checkResponse(id, query, h, q) { + continue + } + return p, h, nil + } +} + +func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 1280) + if _, err := io.ReadFull(c, b[:2]); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + l := int(b[0])<<8 | int(b[1]) + if l > len(b) { + b = make([]byte, l) + } + n, err := io.ReadFull(c, b[:l]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + q, err := p.Question() + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + if !checkResponse(id, query, h, q) { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + return p, h, nil +} + +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { + q.Class = dnsmessage.ClassINET + id, udpReq, tcpReq, err := newRequest(q) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage + } + + for _, useUDP := range []bool{true, false} { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + var c net.Conn + var err error + if useUDP { + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) + } else { + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) + } + + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + err := c.SetDeadline(d) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + } + var p dnsmessage.Parser + var h dnsmessage.Header + if useUDP { + p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } else { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) + } + c.Close() + if err != nil { + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + if h.Truncated { + continue + } + return p, h, nil + } + return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer +} + +func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { + if h.RCode == dnsmessage.RCodeNameError { + return errNoSuchHost + } + _, err := p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + return errCannotUnmarshalDNSMessage + } + if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { + return errLameReferral + } + if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { + if h.RCode == dnsmessage.RCodeServerFailure { + return errServerTemporarilyMisbehaving + } + return errServerMisbehaving + } + return nil +} + +func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + return errNoSuchHost + } + if err != nil { + return errCannotUnmarshalDNSMessage + } + if h.Type == qtype { + return nil + } + if err := p.SkipAnswer(); err != nil { + return errCannotUnmarshalDNSMessage + } + } +} + +func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { + var lastErr error + + n, err := dnsmessage.NewName(name) + if err != nil { + return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage + } + q := dnsmessage.Question{ + Name: n, + Type: qtype, + Class: dnsmessage.ClassINET, + } + + for i := 0; i < 2; i++ { + for _, server := range tnet.dnsServers { + p, h, err := tnet.exchange(ctx, server, q, time.Second*5) + if err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + dnsErr.IsTimeout = true + } + if _, ok := err.(*net.OpError); ok { + dnsErr.IsTemporary = true + } + lastErr = dnsErr + continue + } + + if err := checkHeader(&p, h); err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errServerTemporarilyMisbehaving { + dnsErr.IsTemporary = true + } + if err == errNoSuchHost { + dnsErr.IsNotFound = true + return p, server.String(), dnsErr + } + lastErr = dnsErr + continue + } + + err = skipToAnswer(&p, qtype) + if err == nil { + return p, server.String(), nil + } + lastErr = &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errNoSuchHost { + lastErr.(*net.DNSError).IsNotFound = true + return p, server.String(), lastErr + } + } + } + return dnsmessage.Parser{}, "", lastErr +} + +func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { + if host == "" || (!tnet.hasV6 && !tnet.hasV4) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + zlen := len(host) + if strings.IndexByte(host, ':') != -1 { + if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { + zlen = zidx + } + } + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil + } + + if !isDomainName(host) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + type result struct { + p dnsmessage.Parser + server string + error + } + var addrsV4, addrsV6 []netip.Addr + lanes := 0 + if tnet.hasV4 { + lanes++ + } + if tnet.hasV6 { + lanes++ + } + lane := make(chan result, lanes) + var lastErr error + if tnet.hasV4 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) + lane <- result{p, server, err} + }() + } + if tnet.hasV6 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) + lane <- result{p, server, err} + }() + } + for l := 0; l < lanes; l++ { + result := <-lane + if result.error != nil { + if lastErr == nil { + lastErr = result.error + } + continue + } + + loop: + for { + h, err := result.p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + } + if err != nil { + break + } + switch h.Type { + case dnsmessage.TypeA: + a, err := result.p.AResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) + + case dnsmessage.TypeAAAA: + aaaa, err := result.p.AAAAResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) + + default: + if err := result.p.SkipAnswer(); err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + continue + } + } + } + // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled + var addrs []netip.Addr + if tnet.hasV6 { + addrs = append(addrsV6, addrsV4...) + } else { + addrs = append(addrsV4, addrsV6...) + } + + if len(addrs) == 0 && lastErr != nil { + return nil, lastErr + } + saddrs := make([]string, 0, len(addrs)) + for _, ip := range addrs { + saddrs = append(saddrs, ip.String()) + } + return saddrs, nil +} + +func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { + if deadline.IsZero() { + return deadline, nil + } + timeRemaining := deadline.Sub(now) + if timeRemaining <= 0 { + return time.Time{}, errTimeout + } + timeout := timeRemaining / time.Duration(addrsRemaining) + const saneMinimum = 2 * time.Second + if timeout < saneMinimum { + if timeRemaining < saneMinimum { + timeout = timeRemaining + } else { + timeout = saneMinimum + } + } + return now.Add(timeout), nil +} + +var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) + +func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if ctx == nil { + panic("nil context") + } + var acceptV4, acceptV6 bool + matches := protoSplitter.FindStringSubmatch(network) + if matches == nil { + return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} + } else if len(matches[2]) == 0 { + acceptV4 = true + acceptV6 = true + } else { + acceptV4 = matches[2][0] == '4' + acceptV6 = !acceptV4 + } + var host string + var port int + if matches[1] == "ping" { + host = address + } else { + var sport string + var err error + host, sport, err = net.SplitHostPort(address) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + port, err = strconv.Atoi(sport) + if err != nil || port < 0 || port > 65535 { + return nil, &net.OpError{Op: "dial", Err: errNumericPort} + } + } + allAddr, err := tnet.LookupContextHost(ctx, host) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + var addrs []netip.AddrPort + for _, addr := range allAddr { + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) + } + } + if len(addrs) == 0 && len(allAddr) != 0 { + return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} + } + + var firstErr error + for i, addr := range addrs { + select { + case <-ctx.Done(): + err := ctx.Err() + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return nil, &net.OpError{Op: "dial", Err: err} + default: + } + + dialCtx := ctx + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i) + if err != nil { + if firstErr == nil { + firstErr = &net.OpError{Op: "dial", Err: err} + } + break + } + if partialDeadline.Before(deadline) { + var cancel context.CancelFunc + dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) + defer cancel() + } + } + + var c net.Conn + switch matches[1] { + case "tcp": + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) + case "udp": + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) + case "ping": + c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) + } + if err == nil { + return c, nil + } + if firstErr == nil { + firstErr = err + } + } + if firstErr == nil { + firstErr = &net.OpError{Op: "dial", Err: errMissingAddress} + } + return nil, firstErr +} + +func (tnet *Net) Dial(network, address string) (net.Conn, error) { + return tnet.DialContext(context.Background(), network, address) +} diff --git a/packet_transformer.go b/packet_transformer.go new file mode 100644 index 0000000..acd381e --- /dev/null +++ b/packet_transformer.go @@ -0,0 +1,158 @@ +package connect + +import ( + "context" + "fmt" + "net/netip" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +type PacketTransformer struct { + sm *SourceMapper +} + +// NewPacketTransformer creates a new PacketTransformer. +func NewPacketTransformer(ctx context.Context) *PacketTransformer { + + return &PacketTransformer{ + sm: NewSourceMapper(ctx), + } +} + +func (p *PacketTransformer) RewritePacketFromVPN(buffer []byte, tp TransferPath) ([]byte, error) { + + ipVersion := uint8(buffer[0]) >> 4 + + switch ipVersion { + case 4: + + // gopacket.NewPacket(buffer, layers.LayerTypeIPv4, gopacket.Default) + + ipv4 := &layers.IPv4{} + err := ipv4.DecodeFromBytes(buffer, gopacket.NilDecodeFeedback) + if err != nil { + return nil, fmt.Errorf("failed to decode IPv4 layer: %w", err) + } + + switch ipv4.Protocol { + case layers.IPProtocolTCP: + tcp := &layers.TCP{} + err := tcp.DecodeFromBytes(ipv4.Payload, gopacket.NilDecodeFeedback) + if err != nil { + return nil, fmt.Errorf("failed to decode TCP layer: %w", err) + } + + ipv4.Checksum = 0 + tcp.SetNetworkLayerForChecksum(ipv4) + tcp.Checksum = 0 + + mapped := p.sm.GetSourceAddressMapping(netip.AddrFrom4([4]byte(ipv4.SrcIP)), tp) + + ipv4.SrcIP = mapped.AsSlice() + + rewrittenBuffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + err = gopacket.SerializeLayers(rewrittenBuffer, options, + ipv4, + tcp, + gopacket.Payload(tcp.LayerPayload()), + ) + if err != nil { + return nil, fmt.Errorf("failed to serialize packet: %w", err) + } + + return rewrittenBuffer.Bytes(), nil + + } + + return buffer, nil + + // case 6: + // ipv6 := layers.IPv6{} + // ipv6.DecodeFromBytes(buffer[0:n], gopacket.NilDecodeFeedback) + // switch ipv6.NextHeader { + // case layers.IPProtocolTCP: + // tcp := layers.TCP{} + // tcp.DecodeFromBytes(ipv6.Payload, gopacket.NilDecodeFeedback) + + // sourceId := SourceID{ + // ip: netip.AddrFrom16([16]byte(ipv6.DstIP)), + // port: uint16(tcp.DstPort), + // } + + // sourceMapLock.Lock() + // epa := sourceMap[sourceId] + // sourceMapLock.Unlock() + + // tp = epa.transferPath + + // } + + default: + return buffer, nil + } +} + +func (p *PacketTransformer) RewritePacketToVPN(buffer []byte) ([]byte, *TransferPath, error) { + ipVersion := uint8(buffer[0]) >> 4 + + switch ipVersion { + case 4: + + // gopacket.NewPacket(buffer, layers.LayerTypeIPv4, gopacket.Default) + + ipv4 := &layers.IPv4{} + err := ipv4.DecodeFromBytes(buffer, gopacket.NilDecodeFeedback) + if err != nil { + return nil, nil, fmt.Errorf("failed to decode IPv4 layer: %w", err) + } + + switch ipv4.Protocol { + case layers.IPProtocolTCP: + tcp := &layers.TCP{} + err := tcp.DecodeFromBytes(ipv4.Payload, gopacket.NilDecodeFeedback) + if err != nil { + return nil, nil, fmt.Errorf("failed to decode TCP layer: %w", err) + } + + ipv4.Checksum = 0 + tcp.SetNetworkLayerForChecksum(ipv4) + tcp.Checksum = 0 + + realEndpointAddress, ok := p.sm.GetRealEndpointAddress(netip.AddrFrom4([4]byte(ipv4.DstIP))) + if !ok { + return nil, nil, fmt.Errorf("failed to find real address mapping for %s", netip.AddrFrom4([4]byte(ipv4.DstIP))) + } + + ipv4.DstIP = realEndpointAddress.realSource.AsSlice() + + rewrittenBuffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + err = gopacket.SerializeLayers(rewrittenBuffer, options, + ipv4, + tcp, + gopacket.Payload(tcp.LayerPayload()), + ) + if err != nil { + return nil, nil, fmt.Errorf("failed to serialize packet: %w", err) + } + + return rewrittenBuffer.Bytes(), &realEndpointAddress.transferPath, nil + + } + + } + + return nil, nil, fmt.Errorf("unsupported ip version: %d", ipVersion) + +} diff --git a/packet_transformer_test.go b/packet_transformer_test.go new file mode 100644 index 0000000..6717031 --- /dev/null +++ b/packet_transformer_test.go @@ -0,0 +1,55 @@ +package connect_test + +import ( + "context" + "testing" + + "github.com/go-playground/assert/v2" + "github.com/stretchr/testify/require" + "github.com/urnetwork/connect" +) + +func TestPacketTransformer(t *testing.T) { + + t.Run("RewritePacketFromVPN and RewritePacketToVPN", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + originalFromVPN := []byte{ + 0x45, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0x7f, 0x00, 0x00, 0x01, + 0x7f, 0x00, 0x00, 0x01, 0xe3, 0xa5, 0x1f, 0x90, 0x6b, 0x9d, 0x68, 0x93, 0x00, 0x00, 0x00, 0x00, + 0xb0, 0x02, 0xff, 0xff, 0xfe, 0x34, 0x00, 0x00, 0x02, 0x04, 0x3f, 0xd8, 0x01, 0x03, 0x03, 0x06, + 0x01, 0x01, 0x08, 0x0a, 0x29, 0x02, 0xd5, 0x4c, 0x00, 0x00, 0x00, 0x00, 0x04, 0x02, 0x00, 0x00, + } + + pt := connect.NewPacketTransformer(ctx) + rewritten, err := pt.RewritePacketFromVPN(originalFromVPN, connect.TransferPath{}) + assert.Equal(t, nil, err) + + expectedFromVPN := []byte{ + 0x45, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0xbb, 0xb6, 0x00, 0x00, 0x00, 0x01, + 0x7f, 0x00, 0x00, 0x01, 0xe3, 0xa5, 0x1f, 0x90, 0x6b, 0x9d, 0x68, 0x93, 0x00, 0x00, 0x00, 0x00, + 0xb0, 0x02, 0xff, 0xff, 0xa8, 0x20, 0x00, 0x00, 0x02, 0x04, 0x3f, 0xd8, 0x01, 0x03, 0x03, 0x06, + 0x01, 0x01, 0x08, 0x0a, 0x29, 0x02, 0xd5, 0x4c, 0x00, 0x00, 0x00, 0x00, 0x04, 0x02, 0x00, 0x00, + } + + require.Equal(t, expectedFromVPN, rewritten) + + originalToVPN := []byte{ + 0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0x7f, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x01, 0x1f, 0x90, 0xe3, 0xa5, 0x00, 0x00, 0x00, 0x00, 0x6b, 0x9d, 0x68, 0x94, + 0x50, 0x14, 0x00, 0x00, 0xfe, 0x1c, 0x00, 0x00, + } + + rewrittenToVPN, _, err := pt.RewritePacketToVPN(originalToVPN) + require.NoError(t, err) + + expectedToVPN := []byte{ + 0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0x3c, 0xce, 0x7f, 0x00, 0x00, 0x01, + 0x7f, 0x00, 0x00, 0x01, 0x1f, 0x90, 0xe3, 0xa5, 0x00, 0x00, 0x00, 0x00, 0x6b, 0x9d, 0x68, 0x94, + 0x50, 0x14, 0x00, 0x00, 0xda, 0x66, 0x00, 0x00, + } + + require.Equal(t, expectedToVPN, rewrittenToVPN) + + }) +} diff --git a/pathsource/path_to_source_address.go b/pathsource/path_to_source_address.go new file mode 100644 index 0000000..fcacb5e --- /dev/null +++ b/pathsource/path_to_source_address.go @@ -0,0 +1,38 @@ +package pathsource + +import ( + "encoding/binary" + "net" + "net/netip" + + "github.com/cespare/xxhash/v2" + "github.com/urnetwork/connect" +) + +// PathToSourceAddress generates a new source address and port based on the given source address and port +// and the transfer path. +// The new source address and port are deterministic and unique for the given transfer path. +// Chances of collision are very low (2^-36). +func PathToSourceAddress(tp *connect.TransferPath, sourceAddress net.IP, sourcePort uint16) netip.AddrPort { + xxhash := xxhash.New() + xxhash.Write(tp.SourceId[:]) + xxhash.Write(tp.DestinationId[:]) + xxhash.Write(tp.StreamId[:]) + xxhash.Write([]byte(sourceAddress)) + xxhash.Write(binary.BigEndian.AppendUint16(nil, sourcePort)) + + hash := xxhash.Sum(nil) + + firstByte := byte(hash[0]) + + // make sure the first byte is not in the multicast range + if firstByte&0xe0 == 0xe0 { + firstByte = ^byte(0x10) + } + + sourceIP := netip.AddrFrom4([4]byte{firstByte, hash[1], hash[2], hash[3]}) + + sourcePort = binary.BigEndian.Uint16(hash[4:]) + + return netip.AddrPortFrom(sourceIP, sourcePort) +} diff --git a/pathsource/path_to_source_address_test.go b/pathsource/path_to_source_address_test.go new file mode 100644 index 0000000..39d6c32 --- /dev/null +++ b/pathsource/path_to_source_address_test.go @@ -0,0 +1,25 @@ +package pathsource_test + +import ( + "net/netip" + "testing" + + "github.com/go-playground/assert/v2" + "github.com/urnetwork/connect" + "github.com/urnetwork/connect/pathsource" +) + +func TestPathToSourceAddress(t *testing.T) { + pth := &connect.TransferPath{ + SourceId: connect.Id{}, + DestinationId: connect.Id{}, + StreamId: connect.Id{}, + } + + sourceAddress := []byte{192, 168, 1, 1} + sourcePort := uint16(1234) + + ap := pathsource.PathToSourceAddress(pth, sourceAddress, sourcePort) + assert.Equal(t, ap, netip.AddrPortFrom(netip.AddrFrom4([4]byte{181, 92, 103, 96}), 7285)) + +} diff --git a/souce_mapper.go b/souce_mapper.go new file mode 100644 index 0000000..91bd062 --- /dev/null +++ b/souce_mapper.go @@ -0,0 +1,111 @@ +package connect + +import ( + "context" + "fmt" + "net/netip" + "sync" + "sync/atomic" + "time" + + "github.com/jellydator/ttlcache/v3" +) + +type RealEndpointAddress struct { + realSource netip.Addr + transferPath TransferPath +} + +type SourceMapper struct { + lastAddress *atomic.Uint64 + cache *ttlcache.Cache[netip.Addr, RealEndpointAddress] + forwardMapping map[TransferPath]netip.Addr + mappingLock *sync.RWMutex +} + +// NewSourceMapper creates a new SourceMapper. +// The SourceMapper is used to map the source address of a packet to a new source address. +func NewSourceMapper(ctx context.Context) *SourceMapper { + cache := ttlcache.New[netip.Addr, RealEndpointAddress]( + // if no packet is received within 5 minutes, the mapping will be removed. + ttlcache.WithTTL[netip.Addr, RealEndpointAddress](5*time.Minute), + // the cache will be able to store up to 20000 mappings. + ttlcache.WithCapacity[netip.Addr, RealEndpointAddress](20_000), + ) + + mu := new(sync.RWMutex) + forwardMapping := make(map[TransferPath]netip.Addr) + + cache.OnEviction(func(ctx context.Context, er ttlcache.EvictionReason, i *ttlcache.Item[netip.Addr, RealEndpointAddress]) { + mu.Lock() + defer mu.Unlock() + + // delete the forwardMapping + delete(forwardMapping, i.Value().transferPath) + }) + + cache.OnInsertion(func(ctx context.Context, i *ttlcache.Item[netip.Addr, RealEndpointAddress]) { + mu.Lock() + defer mu.Unlock() + + // add the forwardMapping + forwardMapping[i.Value().transferPath] = i.Key() + }) + + go cache.Start() + + go func() { + <-ctx.Done() + cache.Stop() + }() + + return &SourceMapper{ + cache: cache, + lastAddress: new(atomic.Uint64), + mappingLock: mu, + forwardMapping: forwardMapping, + } +} + +// GetSourceAddressMapping returns the mapped address for the given source address. +// If the source address is not already mapped, a new address will be generated and returned. +func (sm *SourceMapper) GetSourceAddressMapping(sourceAddress netip.Addr, tp TransferPath) netip.Addr { + + sm.mappingLock.RLock() + mappedAddress, found := sm.forwardMapping[tp] + sm.mappingLock.RUnlock() + if found { + return mappedAddress + } + + nextAddressInt := sm.lastAddress.Add(1) + + nextAddress := netip.AddrFrom4([4]byte{ + byte(nextAddressInt >> 24), + byte(nextAddressInt >> 16), + byte(nextAddressInt >> 8), + byte(nextAddressInt), + }) + + fmt.Println("new mapping", sourceAddress, nextAddress, tp) + sm.cache.Set( + nextAddress, + RealEndpointAddress{ + realSource: sourceAddress, + transferPath: tp, + }, + 0, + ) + + return nextAddress +} + +func (sm *SourceMapper) GetRealEndpointAddress(sourceAddress netip.Addr) (RealEndpointAddress, bool) { + v := sm.cache.Get(sourceAddress) + + if v == nil { + return RealEndpointAddress{}, false + } + + return v.Value(), true +}