From dc4493544838317bc332fe399b92b276739079b2 Mon Sep 17 00:00:00 2001 From: Aleksa Sarai Date: Sun, 30 Jun 2024 16:54:12 +1000 Subject: [PATCH 1/2] join: switch tests to use testify Signed-off-by: Aleksa Sarai --- join_test.go | 189 ++++++++++++++++----------------------------------- 1 file changed, 59 insertions(+), 130 deletions(-) diff --git a/join_test.go b/join_test.go index ef081d6..00ab810 100644 --- a/join_test.go +++ b/join_test.go @@ -6,21 +6,24 @@ package securejoin import ( "errors" + "fmt" "io/ioutil" "os" "path/filepath" "runtime" "syscall" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TODO: These tests won't work on plan9 because it doesn't have symlinks, and // also we use '/' here explicitly which probably won't work on Windows. func symlink(t *testing.T, oldname, newname string) { - if err := os.Symlink(oldname, newname); err != nil { - t.Fatal(err) - } + err := os.Symlink(oldname, newname) + require.NoError(t, err) } type input struct { @@ -30,15 +33,9 @@ type input struct { // Test basic handling of symlink expansion. func TestSymlink(t *testing.T) { - dir, err := ioutil.TempDir("", "TestSymlink") - if err != nil { - t.Fatal(err) - } - dir, err = filepath.EvalSymlinks(dir) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) symlink(t, "somepath", filepath.Join(dir, "etc")) symlink(t, "../../../../../../../../../../../../../etc", filepath.Join(dir, "etclink")) @@ -65,8 +62,7 @@ func TestSymlink(t *testing.T) { for _, test := range tc { got, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + if !assert.NoErrorf(t, err, "securejoin(%q, %q)", test.root, test.unsafe) { continue } // This is only for OS X, where /etc is a symlink to /private/etc. In @@ -77,24 +73,15 @@ func TestSymlink(t *testing.T) { test.expected = expected } } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue - } + assert.Equalf(t, test.expected, got, "securejoin(%q, %q)", test.root, test.unsafe) } } // In a path without symlinks, SecureJoin is equivalent to Clean+Join. func TestNoSymlink(t *testing.T) { - dir, err := ioutil.TempDir("", "TestNoSymlink") - if err != nil { - t.Fatal(err) - } - dir, err = filepath.EvalSymlinks(dir) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) tc := []input{ {dir, "somepath", filepath.Join(dir, "somepath")}, @@ -116,26 +103,17 @@ func TestNoSymlink(t *testing.T) { for _, test := range tc { got, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) + if assert.NoErrorf(t, err, "securejoin(%q, %q)", test.root, test.unsafe) { + assert.Equalf(t, test.expected, got, "securejoin(%q, %q)", test.root, test.unsafe) } } } // Make sure that .. is **not** expanded lexically. func TestNonLexical(t *testing.T) { - dir, err := ioutil.TempDir("", "TestNonLexical") - if err != nil { - t.Fatal(err) - } - dir, err = filepath.EvalSymlinks(dir) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) os.MkdirAll(filepath.Join(dir, "subdir"), 0755) os.MkdirAll(filepath.Join(dir, "cousinparent", "cousin"), 0755) @@ -155,28 +133,17 @@ func TestNonLexical(t *testing.T) { {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, } { got, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue + if assert.NoErrorf(t, err, "securejoin(%q, %q)", test.root, test.unsafe) { + assert.Equalf(t, test.expected, got, "securejoin(%q, %q)", test.root, test.unsafe) } } } // Make sure that symlink loops result in errors. func TestSymlinkLoop(t *testing.T) { - dir, err := ioutil.TempDir("", "TestSymlinkLoop") - if err != nil { - t.Fatal(err) - } - dir, err = filepath.EvalSymlinks(dir) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) os.MkdirAll(filepath.Join(dir, "subdir"), 0755) symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "subdir", "link")) @@ -196,41 +163,29 @@ func TestSymlinkLoop(t *testing.T) { {dir, "/../../../../../../../../../../../../../../../../self/.."}, {dir, "/self/././.."}, } { - got, err := SecureJoin(test.root, test.unsafe) - if !errors.Is(err, syscall.ELOOP) { - t.Errorf("securejoin(%q, %q): expected ELOOP, got %q & %v", test.root, test.unsafe, got, err) - continue - } + _, err := SecureJoin(test.root, test.unsafe) + assert.ErrorIsf(t, err, syscall.ELOOP, "securejoin(%q, %q)", test.root, test.unsafe) } } // Make sure that ENOTDIR is correctly handled. func TestEnotdir(t *testing.T) { - dir, err := ioutil.TempDir("", "TestEnotdir") - if err != nil { - t.Fatal(err) - } - dir, err = filepath.EvalSymlinks(dir) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) os.MkdirAll(filepath.Join(dir, "subdir"), 0755) ioutil.WriteFile(filepath.Join(dir, "notdir"), []byte("I am not a directory!"), 0755) symlink(t, "/../../../notdir/somechild", filepath.Join(dir, "subdir", "link")) - for _, test := range []struct { - root, unsafe string - }{ - {dir, "subdir/link"}, - {dir, "notdir"}, - {dir, "notdir/child"}, + for _, test := range []input{ + {dir, "subdir/link", filepath.Join(dir, "notdir/somechild")}, + {dir, "notdir", filepath.Join(dir, "notdir")}, + {dir, "notdir/child", filepath.Join(dir, "notdir/child")}, } { - _, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue + got, err := SecureJoin(test.root, test.unsafe) + if assert.NoErrorf(t, err, "securejoin(%q, %q)", test.root, test.unsafe) { + assert.Equalf(t, test.expected, got, "securejoin(%q, %q)", test.root, test.unsafe) } } } @@ -253,9 +208,7 @@ func TestIsNotExist(t *testing.T) { {errors.New("not a proper error"), false}, } { got := IsNotExist(test.err) - if got != test.expected { - t.Errorf("IsNotExist(%#v): expected %v, got %v", test.err, test.expected, got) - } + assert.Equalf(t, test.expected, got, "IsNotExist(%#v)", test.err) } } @@ -269,15 +222,9 @@ func (m mockVFS) Readlink(path string) (string, error) { return m.readlink(pat // Make sure that SecureJoinVFS actually does use the given VFS interface. func TestSecureJoinVFS(t *testing.T) { - dir, err := ioutil.TempDir("", "TestNonLexical") - if err != nil { - t.Fatal(err) - } - dir, err = filepath.EvalSymlinks(dir) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) os.MkdirAll(filepath.Join(dir, "subdir"), 0755) os.MkdirAll(filepath.Join(dir, "cousinparent", "cousin"), 0755) @@ -303,16 +250,9 @@ func TestSecureJoinVFS(t *testing.T) { } got, err := SecureJoinVFS(test.root, test.unsafe, mock) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue - } - if nLstat == 0 && nReadlink == 0 { - t.Errorf("securejoin(%q, %q): expected to use either lstat or readlink, neither were used", test.root, test.unsafe) + if assert.NoErrorf(t, err, "securejoin(%q, %q)", test.root, test.unsafe) { + assert.Equalf(t, test.expected, got, "securejoin(%q, %q)", test.root, test.unsafe) + assert.Truef(t, nLstat+nReadlink > 0, "securejoin(%q, %q): expected either lstat or readlink to be called", test.root, test.unsafe) } } } @@ -321,20 +261,14 @@ func TestSecureJoinVFS(t *testing.T) { // that errors are correctly propagated. func TestSecureJoinVFSErrors(t *testing.T) { var ( - lstatErr = errors.New("lstat error") - readlinkErr = errors.New("readlink err") + fakeErr = errors.New("FAKE ERROR") + lstatErr = fmt.Errorf("%w: lstat", fakeErr) + readlinkErr = fmt.Errorf("%w: readlink", fakeErr) ) - // Set up directory. - dir, err := ioutil.TempDir("", "TestSecureJoinVFSErrors") - if err != nil { - t.Fatal(err) - } - dir, err = filepath.EvalSymlinks(dir) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) // Make a link. symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "link")) @@ -345,32 +279,32 @@ func TestSecureJoinVFSErrors(t *testing.T) { // Make sure that the set of {lstat, readlink} failures do propagate. for idx, test := range []struct { - vfs VFS - expected []error + vfs VFS + expectErr bool }{ { - expected: []error{nil}, + expectErr: false, vfs: mockVFS{ lstat: os.Lstat, readlink: os.Readlink, }, }, { - expected: []error{lstatErr}, + expectErr: true, vfs: mockVFS{ lstat: lstatFailFn, readlink: os.Readlink, }, }, { - expected: []error{readlinkErr}, + expectErr: true, vfs: mockVFS{ lstat: os.Lstat, readlink: readlinkFailFn, }, }, { - expected: []error{lstatErr, readlinkErr}, + expectErr: true, vfs: mockVFS{ lstat: lstatFailFn, readlink: readlinkFailFn, @@ -378,15 +312,10 @@ func TestSecureJoinVFSErrors(t *testing.T) { }, } { _, err := SecureJoinVFS(dir, "link", test.vfs) - - success := false - for _, exp := range test.expected { - if err == exp { - success = true - } - } - if !success { - t.Errorf("SecureJoinVFS.mock%d: expected to get lstatError, got %v", idx, err) + if test.expectErr { + assert.ErrorIsf(t, err, fakeErr, "SecureJoinVFS.mock%d", idx) + } else { + assert.NoErrorf(t, err, "SecureJoinVFS.mock%d", idx) } } } From 22faec1e9d2c8cf621c752d2b6acaa4aaf2ee5d0 Mon Sep 17 00:00:00 2001 From: Aleksa Sarai Date: Sun, 30 Jun 2024 14:22:31 +1000 Subject: [PATCH 2/2] WIP: port SecureJoin to partialLookupInRoot Signed-off-by: Aleksa Sarai --- join.go => join_generic.go | 30 +------------- join_linux.go | 81 ++++++++++++++++++++++++++++++++++++++ join_linux_test.go | 13 ++++++ join_nonlinux.go | 33 ++++++++++++++++ join_test.go | 2 +- lookup_linux.go | 35 ++++++++++++++-- lookup_linux_test.go | 6 +-- mkdir_linux.go | 2 +- mkdir_linux_test.go | 2 +- open_linux.go | 2 +- openat2_linux.go | 30 ++++++++++++-- 11 files changed, 194 insertions(+), 42 deletions(-) rename join.go => join_generic.go (66%) create mode 100644 join_linux.go create mode 100644 join_linux_test.go create mode 100644 join_nonlinux.go diff --git a/join.go b/join_generic.go similarity index 66% rename from join.go rename to join_generic.go index bd86a48..59de8ea 100644 --- a/join.go +++ b/join_generic.go @@ -29,29 +29,7 @@ func IsNotExist(err error) bool { return errors.Is(err, os.ErrNotExist) || errors.Is(err, syscall.ENOTDIR) || errors.Is(err, syscall.ENOENT) } -// SecureJoinVFS joins the two given path components (similar to Join) except -// that the returned path is guaranteed to be scoped inside the provided root -// path (when evaluated). Any symbolic links in the path are evaluated with the -// given root treated as the root of the filesystem, similar to a chroot. The -// filesystem state is evaluated through the given VFS interface (if nil, the -// standard os.* family of functions are used). -// -// Note that the guarantees provided by this function only apply if the path -// components in the returned string are not modified (in other words are not -// replaced with symlinks on the filesystem) after this function has returned. -// Such a symlink race is necessarily out-of-scope of SecureJoin. -// -// NOTE: Due to the above limitation, Linux users are strongly encouraged to -// use OpenInRoot instead, which does safely protect against these kinds of -// attacks. There is no way to solve this problem with SecureJoinVFS because -// the API is fundamentally wrong (you cannot return a "safe" path string and -// guarantee it won't be modified afterwards). -// -// Volume names in unsafePath are always discarded, regardless if they are -// provided via direct input or when evaluating symlinks. Therefore: -// -// "C:\Temp" + "D:\path\to\file.txt" results in "C:\Temp\path\to\file.txt" -func SecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) { +func legacySecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) { // Use the os.* VFS implementation if none was specified. if vfs == nil { vfs = osVFS{} @@ -122,9 +100,3 @@ func SecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) { finalPath := filepath.Join(string(filepath.Separator), currentPath) return filepath.Join(root, finalPath), nil } - -// SecureJoin is a wrapper around SecureJoinVFS that just uses the os.* library -// of functions as the VFS. If in doubt, use this function over SecureJoinVFS. -func SecureJoin(root, unsafePath string) (string, error) { - return SecureJoinVFS(root, unsafePath, nil) -} diff --git a/join_linux.go b/join_linux.go new file mode 100644 index 0000000..3c40fd2 --- /dev/null +++ b/join_linux.go @@ -0,0 +1,81 @@ +//go:build linux + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" +) + +func isLexicallyInRoot(root, path string) bool { + if root != "/" { + root += "/" + } + if path != "/" { + path += "/" + } + return strings.HasPrefix(path, root) +} + +// SecureJoin is a wrapper around SecureJoinVFS that just uses the os.* library +// of functions as the VFS. If in doubt, use this function over SecureJoinVFS. +func SecureJoin(root, unsafePath string) (string, error) { + rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + if err != nil { + return "", err + } + defer rootDir.Close() + + handle, remainingPath, err := partialLookupInRoot(rootDir, unsafePath, true) + if err != nil { + return "", err + } + defer handle.Close() + + handlePath, err := procSelfFdReadlink(handle) + if err != nil { + return "", fmt.Errorf("verify actual path of %q handle: %w", handle.Name(), err) + } + // Make sure the path is inside the root. + if !isLexicallyInRoot(root, handlePath) { + return "", fmt.Errorf("%w: handle path %q is outside root %q", errPossibleBreakout, handlePath, root) + } + + // remainingPath should be cleaned and safe to append, due to how + // unsafeHallucinateDirectories works. But do an additional cleanup, just + // to be sure. + remainingPath = filepath.Join("/", remainingPath) + return filepath.Join(handlePath, remainingPath), nil +} + +// SecureJoinVFS joins the two given path components (similar to Join) except +// that the returned path is guaranteed to be scoped inside the provided root +// path (when evaluated). Any symbolic links in the path are evaluated with the +// given root treated as the root of the filesystem, similar to a chroot. The +// filesystem state is evaluated through the given VFS interface (if nil, the +// standard os.* family of functions are used). +// +// Note that the guarantees provided by this function only apply if the path +// components in the returned string are not modified (in other words are not +// replaced with symlinks on the filesystem) after this function has returned. +// Such a symlink race is necessarily out-of-scope of SecureJoin. +// +// Volume names in unsafePath are always discarded, regardless if they are +// provided via direct input or when evaluating symlinks. Therefore: +// +// "C:\Temp" + "D:\path\to\file.txt" results in "C:\Temp\path\to\file.txt" +func SecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) { + if vfs == nil || vfs == (osVFS{}) { + return SecureJoin(root, unsafePath) + } + // TODO: Make it possible for partialLookupInRoot to work with VFS. + return legacySecureJoinVFS(root, unsafePath, vfs) +} diff --git a/join_linux_test.go b/join_linux_test.go new file mode 100644 index 0000000..1710b01 --- /dev/null +++ b/join_linux_test.go @@ -0,0 +1,13 @@ +// Copyright (C) 2017-2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "testing" +) + +func TestSymlink(t *testing.T) { + withWithoutOpenat2(t, true, testSymlink) +} diff --git a/join_nonlinux.go b/join_nonlinux.go new file mode 100644 index 0000000..2b0ab02 --- /dev/null +++ b/join_nonlinux.go @@ -0,0 +1,33 @@ +//go:build !linux + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +// SecureJoin is a wrapper around SecureJoinVFS that just uses the os.* library +// of functions as the VFS. If in doubt, use this function over SecureJoinVFS. +func SecureJoin(root, unsafePath string) (string, error) { + return SecureJoinVFS(root, unsafePath, nil) +} + +// SecureJoinVFS joins the two given path components (similar to Join) except +// that the returned path is guaranteed to be scoped inside the provided root +// path (when evaluated). Any symbolic links in the path are evaluated with the +// given root treated as the root of the filesystem, similar to a chroot. The +// filesystem state is evaluated through the given VFS interface (if nil, the +// standard os.* family of functions are used). +// +// Note that the guarantees provided by this function only apply if the path +// components in the returned string are not modified (in other words are not +// replaced with symlinks on the filesystem) after this function has returned. +// Such a symlink race is necessarily out-of-scope of SecureJoin. +// +// Volume names in unsafePath are always discarded, regardless if they are +// provided via direct input or when evaluating symlinks. Therefore: +// +// "C:\Temp" + "D:\path\to\file.txt" results in "C:\Temp\path\to\file.txt" +func SecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) { + return legacySecureJoinVFS(root, unsafePath, vfs) +} diff --git a/join_test.go b/join_test.go index 00ab810..0c61abf 100644 --- a/join_test.go +++ b/join_test.go @@ -32,7 +32,7 @@ type input struct { } // Test basic handling of symlink expansion. -func TestSymlink(t *testing.T) { +func testSymlink(t *testing.T) { dir := t.TempDir() dir, err := filepath.EvalSymlinks(dir) require.NoError(t, err) diff --git a/lookup_linux.go b/lookup_linux.go index d242e97..dacf5d1 100644 --- a/lookup_linux.go +++ b/lookup_linux.go @@ -153,11 +153,23 @@ func (s *symlinkStack) PopTopSymlink() (*os.File, string, bool) { return tailEntry.dir, tailEntry.remainingPath, true } +const maxUnsafeHallucinateDirectoryTries = 20 + +var errTooManyFakeDirectories = errors.New("encountered too many non-existent paths") + // partialLookupInRoot tries to lookup as much of the request path as possible // within the provided root (a-la RESOLVE_IN_ROOT) and opens the final existing // component of the requested path, returning a file handle to the final // existing component and a string containing the remaining path components. -func partialLookupInRoot(root *os.File, unsafePath string) (_ *os.File, _ string, Err error) { +// +// If unsafeHallucinateDirectories is true, partialLookupInRoot will try to +// emulate the legacy SecureJoin behaviour of treating non-existent paths as +// though they are directories to try to resolve as much of the path as +// possible. In practice, this means that a path like "a/b/doesnotexist/../c" +// will end up being resolved as "a/b/c" if possible. Note that dangling +// symlinks (a symlink that points to a non-existent path) will still result in +// an error being returned, due to how openat2 handles symlinks. +func partialLookupInRoot(root *os.File, unsafePath string, unsafeHallucinateDirectories bool) (_ *os.File, _ string, Err error) { unsafePath = filepath.ToSlash(unsafePath) // noop // This is very similar to SecureJoin, except that we operate on the @@ -166,7 +178,7 @@ func partialLookupInRoot(root *os.File, unsafePath string) (_ *os.File, _ string // Try to use openat2 if possible. if hasOpenat2() { - return partialLookupOpenat2(root, unsafePath) + return partialLookupOpenat2(root, unsafePath, unsafeHallucinateDirectories) } // Get the "actual" root path from /proc/self/fd. This is necessary if the @@ -204,7 +216,9 @@ func partialLookupInRoot(root *os.File, unsafePath string) (_ *os.File, _ string defer symlinkStack.Close() var ( - linksWalked int + linksWalked int + hallucinateDirectoryTries int + currentPath string remainingPath = unsafePath ) @@ -354,6 +368,21 @@ func partialLookupInRoot(root *os.File, unsafePath string) (_ *os.File, _ string _ = currentDir.Close() return oldDir, remainingPath, nil } + // If we were asked to "hallucinate" non-existent paths as though + // they are directories, take the remainingPath and clean it so + // that any ".." components that would lead us back to real paths + // can get resolved. + if oldRemainingPath != "" && unsafeHallucinateDirectories { + if newRemainingPath := path.Clean(oldRemainingPath); newRemainingPath != oldRemainingPath { + hallucinateDirectoryTries++ + if hallucinateDirectoryTries > maxUnsafeHallucinateDirectoryTries { + return nil, "", fmt.Errorf("%w: trying to reconcile non-existent subpath %q", errTooManyFakeDirectories, oldRemainingPath) + } + // Continue the lookup using the new remaining path. + remainingPath = newRemainingPath + continue + } + } // We have hit a final component that doesn't exist, so we have our // partial open result. Note that we have to use the OLD remaining // path, since the lookup failed. diff --git a/lookup_linux_test.go b/lookup_linux_test.go index f755366..9a9266f 100644 --- a/lookup_linux_test.go +++ b/lookup_linux_test.go @@ -19,7 +19,7 @@ import ( "golang.org/x/sys/unix" ) -type partialLookupFunc func(root *os.File, unsafePath string) (*os.File, string, error) +type partialLookupFunc func(root *os.File, unsafePath string, hallucinateDirectoryTries bool) (*os.File, string, error) type lookupResult struct { handlePath, remainingPath string @@ -28,7 +28,7 @@ type lookupResult struct { } func checkPartialLookup(t *testing.T, partialLookupFn partialLookupFunc, rootDir *os.File, unsafePath string, expected lookupResult) { - handle, remainingPath, err := partialLookupFn(rootDir, unsafePath) + handle, remainingPath, err := partialLookupFn(rootDir, unsafePath, false) if handle != nil { defer handle.Close() } @@ -325,7 +325,7 @@ func newRacingLookupMeta(pauseCh chan struct{}) *racingLookupMeta { func (m *racingLookupMeta) checkPartialLookup(t *testing.T, rootDir *os.File, unsafePath string, skipErrs []error, allowedResults []lookupResult) { // Similar to checkPartialLookup, but with extra logic for // handling the lookup stopping partly through the lookup. - handle, remainingPath, err := partialLookupInRoot(rootDir, unsafePath) + handle, remainingPath, err := partialLookupInRoot(rootDir, unsafePath, false) if err != nil { for _, skipErr := range skipErrs { if errors.Is(err, skipErr) { diff --git a/mkdir_linux.go b/mkdir_linux.go index 05e0bde..989e16a 100644 --- a/mkdir_linux.go +++ b/mkdir_linux.go @@ -48,7 +48,7 @@ func MkdirAllHandle(root *os.File, unsafePath string, mode int) (_ *os.File, Err } // Try to open as much of the path as possible. - currentDir, remainingPath, err := partialLookupInRoot(root, unsafePath) + currentDir, remainingPath, err := partialLookupInRoot(root, unsafePath, false) if err != nil { return nil, fmt.Errorf("find existing subpath of %q: %w", unsafePath, err) } diff --git a/mkdir_linux_test.go b/mkdir_linux_test.go index bdad1c9..c0628ab 100644 --- a/mkdir_linux_test.go +++ b/mkdir_linux_test.go @@ -110,7 +110,7 @@ func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePa // Before trying to make the tree, figure out what // components don't exist yet so we can check them later. - handle, remainingPath, err := partialLookupInRoot(rootDir, test.unsafePath) + handle, remainingPath, err := partialLookupInRoot(rootDir, test.unsafePath, false) handleName := "" if handle != nil { handleName = handle.Name() diff --git a/open_linux.go b/open_linux.go index 2170061..6bb914a 100644 --- a/open_linux.go +++ b/open_linux.go @@ -16,7 +16,7 @@ import ( // OpenatInRoot is equivalent to OpenInRoot, except that the root is provided // using an *os.File handle, to ensure that the correct root directory is used. func OpenatInRoot(root *os.File, unsafePath string) (*os.File, error) { - handle, remainingPath, err := partialLookupInRoot(root, unsafePath) + handle, remainingPath, err := partialLookupInRoot(root, unsafePath, false) if err != nil { return nil, err } diff --git a/openat2_linux.go b/openat2_linux.go index fc93db8..56b0d7a 100644 --- a/openat2_linux.go +++ b/openat2_linux.go @@ -90,10 +90,16 @@ func openat2File(dir *os.File, path string, how *unix.OpenHow) (*os.File, error) // partialLookupOpenat2 is an alternative implementation of // partialLookupInRoot, using openat2(RESOLVE_IN_ROOT) to more safely get a // handle to the deepest existing child of the requested path within the root. -func partialLookupOpenat2(root *os.File, unsafePath string) (*os.File, string, error) { +func partialLookupOpenat2(root *os.File, unsafePath string, unsafeHallucinateDirectories bool) (*os.File, string, error) { + unsafePath = filepath.ToSlash(unsafePath) // noop + + if !hasOpenat2() { + return nil, "", fmt.Errorf("openat2: %w", unix.ENOTSUP) + } + // TODO: Implement this as a git-bisect-like binary search. - unsafePath = filepath.ToSlash(unsafePath) // noop + var hallucinateDirectoryTries int endIdx := len(unsafePath) for endIdx > 0 { subpath := unsafePath[:endIdx] @@ -108,7 +114,25 @@ func partialLookupOpenat2(root *os.File, unsafePath string) (*os.File, string, e endIdx += 1 } // We found a subpath! - return handle, unsafePath[endIdx:], nil + remainingPath := unsafePath[endIdx:] + // If we were asked to "hallucinate" non-existent paths as though + // they are directories, take the remainingPath and clean it so + // that any ".." components that would lead us back to real paths + // can get resolved. + if remainingPath != "" && unsafeHallucinateDirectories { + if newRemainingPath := filepath.Clean(remainingPath); newRemainingPath != remainingPath { + hallucinateDirectoryTries++ + if hallucinateDirectoryTries > maxUnsafeHallucinateDirectoryTries { + return nil, "", fmt.Errorf("%w: trying to reconcile non-existent subpath %q", errTooManyFakeDirectories, remainingPath) + } + // Start the lookup from the end again using the new + // remaining path. + unsafePath = subpath + "/" + newRemainingPath + endIdx = len(unsafePath) + continue + } + } + return handle, remainingPath, nil } if errors.Is(err, unix.ENOENT) || errors.Is(err, unix.ENOTDIR) { // That path doesn't exist, let's try the next directory up.