From eb33a0471e58c9d3d1246ccedb62135fd448d109 Mon Sep 17 00:00:00 2001 From: Ronald G Minnich Date: Tue, 3 Jun 2025 09:41:20 -0700 Subject: [PATCH] Plan 9 support Since Plan 9 does not have symlinks, these problems do not occur. Therefore, SecureJoinVFS and SecureJoin can map to filepath.Join, along with the test for rootpath containing .. Split tests so some can run on Plan 9 Move common variables and functions into common.go Signed-off-by: Ronald G Minnich --- common.go | 48 +++++++++ join.go | 37 +------ join_plan9.go | 27 +++++ join_test.go | 259 +-------------------------------------------- symlink_test.go | 276 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 354 insertions(+), 293 deletions(-) create mode 100644 common.go create mode 100644 join_plan9.go create mode 100644 symlink_test.go diff --git a/common.go b/common.go new file mode 100644 index 0000000..8f5e2c0 --- /dev/null +++ b/common.go @@ -0,0 +1,48 @@ +// Copyright (C) 2014-2015 Docker Inc & Go Authors. All rights reserved. +// Copyright (C) 2017-2025 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "errors" + "os" + "path/filepath" + "strings" + "syscall" +) + +// IsNotExist tells you if err is an error that implies that either the path +// accessed does not exist (or path components don't exist). This is +// effectively a more broad version of [os.IsNotExist]. +func IsNotExist(err error) bool { + // Check that it's not actually an ENOTDIR, which in some cases is a more + // convoluted case of ENOENT (usually involving weird paths). + return errors.Is(err, os.ErrNotExist) || errors.Is(err, syscall.ENOTDIR) || errors.Is(err, syscall.ENOENT) +} + +// errUnsafeRoot is returned if the user provides SecureJoinVFS with a path +// that contains ".." components. +var errUnsafeRoot = errors.New("root path provided to SecureJoin contains '..' components") + +// hasDotDot checks if the path contains ".." components in a platform-agnostic +// way. +func hasDotDot(path string) bool { + // If we are on Windows, strip any volume letters. It turns out that + // C:..\foo may (or may not) be a valid pathname and we need to handle that + // leading "..". + path = stripVolume(path) + // Look for "/../" in the path, but we need to handle leading and trailing + // ".."s by adding separators. Doing this with filepath.Separator is ugly + // so just convert to Unix-style "/" first. + path = filepath.ToSlash(path) + return strings.Contains("/"+path+"/", "/../") +} + +// stripVolume just gets rid of the Windows volume included in a path. Based on +// some godbolt tests, the Go compiler is smart enough to make this a no-op on +// Linux. +func stripVolume(path string) string { + return path[len(filepath.VolumeName(path)):] +} diff --git a/join.go b/join.go index 52cc485..18860c9 100644 --- a/join.go +++ b/join.go @@ -3,10 +3,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !plan9 + package securejoin import ( - "errors" "os" "path/filepath" "strings" @@ -15,40 +16,6 @@ import ( const maxSymlinkLimit = 255 -// IsNotExist tells you if err is an error that implies that either the path -// accessed does not exist (or path components don't exist). This is -// effectively a more broad version of [os.IsNotExist]. -func IsNotExist(err error) bool { - // Check that it's not actually an ENOTDIR, which in some cases is a more - // convoluted case of ENOENT (usually involving weird paths). - return errors.Is(err, os.ErrNotExist) || errors.Is(err, syscall.ENOTDIR) || errors.Is(err, syscall.ENOENT) -} - -// errUnsafeRoot is returned if the user provides SecureJoinVFS with a path -// that contains ".." components. -var errUnsafeRoot = errors.New("root path provided to SecureJoin contains '..' components") - -// stripVolume just gets rid of the Windows volume included in a path. Based on -// some godbolt tests, the Go compiler is smart enough to make this a no-op on -// Linux. -func stripVolume(path string) string { - return path[len(filepath.VolumeName(path)):] -} - -// hasDotDot checks if the path contains ".." components in a platform-agnostic -// way. -func hasDotDot(path string) bool { - // If we are on Windows, strip any volume letters. It turns out that - // C:..\foo may (or may not) be a valid pathname and we need to handle that - // leading "..". - path = stripVolume(path) - // Look for "/../" in the path, but we need to handle leading and trailing - // ".."s by adding separators. Doing this with filepath.Separator is ugly - // so just convert to Unix-style "/" first. - path = filepath.ToSlash(path) - return strings.Contains("/"+path+"/", "/../") -} - // SecureJoinVFS joins the two given path components (similar to [filepath.Join]) except // that the returned path is guaranteed to be scoped inside the provided root // path (when evaluated). Any symbolic links in the path are evaluated with the diff --git a/join_plan9.go b/join_plan9.go new file mode 100644 index 0000000..8165a78 --- /dev/null +++ b/join_plan9.go @@ -0,0 +1,27 @@ +// Copyright (C) 2014-2015 Docker Inc & Go Authors. All rights reserved. +// Copyright (C) 2017-2025 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import "path/filepath" + +// SecureJoin is equivalent to filepath.Join, as plan9 doesn't have symlinks. +func SecureJoin(root, unsafePath string) (string, error) { + // The root path must not contain ".." components, otherwise when we join + // the subpath we will end up with a weird path. We could work around this + // in other ways but users shouldn't be giving us non-lexical root paths in + // the first place. + if hasDotDot(root) { + return "", errUnsafeRoot + } + + unsafePath = filepath.Join(string(filepath.Separator), unsafePath) + return filepath.Join(root, unsafePath), nil +} + +// SecureJoinVFS is equivalent to filepath.Join, as plan9 doesn't have symlinks. +func SecureJoinVFS(root, unsafePath string, _ VFS) (string, error) { + return SecureJoin(root, unsafePath) +} diff --git a/join_test.go b/join_test.go index 7ec788c..b9894bb 100644 --- a/join_test.go +++ b/join_test.go @@ -13,7 +13,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // TODO: These tests won't work on plan9 because it doesn't have symlinks, and @@ -24,64 +23,9 @@ type input struct { expected string } -func expandedTempDir(t *testing.T) string { - dir := t.TempDir() - dir, err := filepath.EvalSymlinks(dir) - require.NoError(t, err) - return dir -} - -// Test basic handling of symlink expansion. -func TestSymlink(t *testing.T) { - dir := expandedTempDir(t) - - symlink(t, "somepath", filepath.Join(dir, "etc")) - symlink(t, "../../../../../../../../../../../../../etc", filepath.Join(dir, "etclink")) - symlink(t, "/../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "passwd")) - - rootOrVol := string(filepath.Separator) - if vol := filepath.VolumeName(dir); vol != "" { - rootOrVol = vol + rootOrVol - } - - tc := []input{ - // Make sure that expansion with a root of '/' proceeds in the expected fashion. - {rootOrVol, filepath.Join(dir, "passwd"), filepath.Join(rootOrVol, "etc", "passwd")}, - {rootOrVol, filepath.Join(dir, "etclink"), filepath.Join(rootOrVol, "etc")}, - - {rootOrVol, filepath.Join(dir, "etc"), filepath.Join(dir, "somepath")}, - // Now test scoped expansion. - {dir, "passwd", filepath.Join(dir, "somepath", "passwd")}, - {dir, "etclink", filepath.Join(dir, "somepath")}, - {dir, "etc", filepath.Join(dir, "somepath")}, - {dir, "etc/test", filepath.Join(dir, "somepath", "test")}, - {dir, "etc/test/..", filepath.Join(dir, "somepath")}, - } - - for _, test := range tc { - got, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - // This is only for OS X, where /etc is a symlink to /private/etc. In - // principle, SecureJoin(/, pth) is the same as EvalSymlinks(pth) in - // the case where the path exists. - if test.root == "/" { - if expected, err := filepath.EvalSymlinks(test.expected); err == nil { - test.expected = expected - } - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue - } - } -} - // In a path without symlinks, SecureJoin is equivalent to Clean+Join. func TestNoSymlink(t *testing.T) { - dir := expandedTempDir(t) + dir := t.TempDir() tc := []input{ {dir, "somepath", filepath.Join(dir, "somepath")}, @@ -112,92 +56,6 @@ func TestNoSymlink(t *testing.T) { } } -// Make sure that .. is **not** expanded lexically. -func TestNonLexical(t *testing.T) { - dir := expandedTempDir(t) - - mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) - mkdirAll(t, filepath.Join(dir, "cousinparent", "cousin"), 0o755) - symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link")) - symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) - symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) - - for _, test := range []input{ - {dir, "subdir", filepath.Join(dir, "subdir")}, - {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/../test", filepath.Join(dir, "test")}, - // This is the divergence from a simple filepath.Clean implementation. - {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")}, - {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")}, - {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, - } { - got, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue - } - } -} - -// Make sure that symlink loops result in errors. -func TestSymlinkLoop(t *testing.T) { - dir := expandedTempDir(t) - - mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) - symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "subdir", "link")) - symlink(t, "/subdir/link", filepath.Join(dir, "path")) - symlink(t, "/../../../../../../../../../../../../../../../../self", filepath.Join(dir, "self")) - - for _, test := range []struct { - root, unsafe string - }{ - {dir, "subdir/link"}, - {dir, "path"}, - {dir, "../../path"}, - {dir, "subdir/link/../.."}, - {dir, "../../../../../../../../../../../../../../../../subdir/link/../../../../../../../../../../../../../../../.."}, - {dir, "self"}, - {dir, "self/.."}, - {dir, "/../../../../../../../../../../../../../../../../self/.."}, - {dir, "/self/././.."}, - } { - got, err := SecureJoin(test.root, test.unsafe) - if !errors.Is(err, syscall.ELOOP) { - t.Errorf("securejoin(%q, %q): expected ELOOP, got %q & %v", test.root, test.unsafe, got, err) - continue - } - } -} - -// Make sure that ENOTDIR is correctly handled. -func TestEnotdir(t *testing.T) { - dir := expandedTempDir(t) - - mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) - writeFile(t, filepath.Join(dir, "notdir"), []byte("I am not a directory!"), 0o755) - symlink(t, "/../../../notdir/somechild", filepath.Join(dir, "subdir", "link")) - - for _, test := range []struct { - root, unsafe string - }{ - {dir, "subdir/link"}, - {dir, "notdir"}, - {dir, "notdir/child"}, - } { - _, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - } -} - // Some silly tests to make sure that all error types are correctly handled. func TestIsNotExist(t *testing.T) { for _, test := range []struct { @@ -222,121 +80,6 @@ func TestIsNotExist(t *testing.T) { } } -type mockVFS struct { - lstat func(path string) (os.FileInfo, error) - readlink func(path string) (string, error) -} - -func (m mockVFS) Lstat(path string) (os.FileInfo, error) { return m.lstat(path) } -func (m mockVFS) Readlink(path string) (string, error) { return m.readlink(path) } - -// Make sure that SecureJoinVFS actually does use the given VFS interface. -func TestSecureJoinVFS(t *testing.T) { - dir := expandedTempDir(t) - - mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) - mkdirAll(t, filepath.Join(dir, "cousinparent", "cousin"), 0o755) - symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link")) - symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) - symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) - - for _, test := range []input{ - {dir, "subdir", filepath.Join(dir, "subdir")}, - {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/../test", filepath.Join(dir, "test")}, - // This is the divergence from a simple filepath.Clean implementation. - {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")}, - {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")}, - {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, - } { - var nLstat, nReadlink int - mock := mockVFS{ - lstat: func(path string) (os.FileInfo, error) { nLstat++; return os.Lstat(path) }, - readlink: func(path string) (string, error) { nReadlink++; return os.Readlink(path) }, - } - - got, err := SecureJoinVFS(test.root, test.unsafe, mock) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue - } - if nLstat == 0 && nReadlink == 0 { - t.Errorf("securejoin(%q, %q): expected to use either lstat or readlink, neither were used", test.root, test.unsafe) - } - } -} - -// Make sure that SecureJoinVFS actually does use the given VFS interface, and -// that errors are correctly propagated. -func TestSecureJoinVFSErrors(t *testing.T) { - var ( - lstatErr = errors.New("lstat error") - readlinkErr = errors.New("readlink err") - ) - - dir := expandedTempDir(t) - - // Make a link. - symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "link")) - - // Define some fake mock functions. - lstatFailFn := func(string) (os.FileInfo, error) { return nil, lstatErr } - readlinkFailFn := func(string) (string, error) { return "", readlinkErr } - - // Make sure that the set of {lstat, readlink} failures do propagate. - for idx, test := range []struct { - vfs VFS - expected []error - }{ - { - expected: []error{nil}, - vfs: mockVFS{ - lstat: os.Lstat, - readlink: os.Readlink, - }, - }, - { - expected: []error{lstatErr}, - vfs: mockVFS{ - lstat: lstatFailFn, - readlink: os.Readlink, - }, - }, - { - expected: []error{readlinkErr}, - vfs: mockVFS{ - lstat: os.Lstat, - readlink: readlinkFailFn, - }, - }, - { - expected: []error{lstatErr, readlinkErr}, - vfs: mockVFS{ - lstat: lstatFailFn, - readlink: readlinkFailFn, - }, - }, - } { - _, err := SecureJoinVFS(dir, "link", test.vfs) - - success := false - for _, exp := range test.expected { - if errors.Is(err, exp) { - success = true - } - } - if !success { - t.Errorf("SecureJoinVFS.mock%d: expected to get lstatError, got %v", idx, err) - } - } -} - func TestUncleanRoot(t *testing.T) { root := t.TempDir() diff --git a/symlink_test.go b/symlink_test.go new file mode 100644 index 0000000..9739639 --- /dev/null +++ b/symlink_test.go @@ -0,0 +1,276 @@ +// Copyright (C) 2017-2025 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !plan9 + +package securejoin + +import ( + "errors" + "os" + "path/filepath" + "syscall" + "testing" + + "github.com/stretchr/testify/require" +) + +// TODO: These tests won't work on plan9 because it doesn't have symlinks, and +// also we use '/' here explicitly which probably won't work on Windows. + +func expandedTempDir(t *testing.T) string { + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) + return dir +} + +// Test basic handling of symlink expansion. +func TestSymlink(t *testing.T) { + dir := expandedTempDir(t) + + symlink(t, "somepath", filepath.Join(dir, "etc")) + symlink(t, "../../../../../../../../../../../../../etc", filepath.Join(dir, "etclink")) + symlink(t, "/../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "passwd")) + + rootOrVol := string(filepath.Separator) + if vol := filepath.VolumeName(dir); vol != "" { + rootOrVol = vol + rootOrVol + } + + tc := []input{ + // Make sure that expansion with a root of '/' proceeds in the expected fashion. + {rootOrVol, filepath.Join(dir, "passwd"), filepath.Join(rootOrVol, "etc", "passwd")}, + {rootOrVol, filepath.Join(dir, "etclink"), filepath.Join(rootOrVol, "etc")}, + + {rootOrVol, filepath.Join(dir, "etc"), filepath.Join(dir, "somepath")}, + // Now test scoped expansion. + {dir, "passwd", filepath.Join(dir, "somepath", "passwd")}, + {dir, "etclink", filepath.Join(dir, "somepath")}, + {dir, "etc", filepath.Join(dir, "somepath")}, + {dir, "etc/test", filepath.Join(dir, "somepath", "test")}, + {dir, "etc/test/..", filepath.Join(dir, "somepath")}, + } + + for _, test := range tc { + got, err := SecureJoin(test.root, test.unsafe) + if err != nil { + t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + continue + } + // This is only for OS X, where /etc is a symlink to /private/etc. In + // principle, SecureJoin(/, pth) is the same as EvalSymlinks(pth) in + // the case where the path exists. + if test.root == "/" { + if expected, err := filepath.EvalSymlinks(test.expected); err == nil { + test.expected = expected + } + } + if got != test.expected { + t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) + continue + } + } +} + +// Make sure that .. is **not** expanded lexically. +func TestNonLexical(t *testing.T) { + dir := expandedTempDir(t) + + mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) + mkdirAll(t, filepath.Join(dir, "cousinparent", "cousin"), 0o755) + symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link")) + symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) + symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) + + for _, test := range []input{ + {dir, "subdir", filepath.Join(dir, "subdir")}, + {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/../test", filepath.Join(dir, "test")}, + // This is the divergence from a simple filepath.Clean implementation. + {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")}, + {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")}, + {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, + } { + got, err := SecureJoin(test.root, test.unsafe) + if err != nil { + t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + continue + } + if got != test.expected { + t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) + continue + } + } +} + +// Make sure that symlink loops result in errors. +func TestSymlinkLoop(t *testing.T) { + dir := expandedTempDir(t) + + mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) + symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "subdir", "link")) + symlink(t, "/subdir/link", filepath.Join(dir, "path")) + symlink(t, "/../../../../../../../../../../../../../../../../self", filepath.Join(dir, "self")) + + for _, test := range []struct { + root, unsafe string + }{ + {dir, "subdir/link"}, + {dir, "path"}, + {dir, "../../path"}, + {dir, "subdir/link/../.."}, + {dir, "../../../../../../../../../../../../../../../../subdir/link/../../../../../../../../../../../../../../../.."}, + {dir, "self"}, + {dir, "self/.."}, + {dir, "/../../../../../../../../../../../../../../../../self/.."}, + {dir, "/self/././.."}, + } { + got, err := SecureJoin(test.root, test.unsafe) + if !errors.Is(err, syscall.ELOOP) { + t.Errorf("securejoin(%q, %q): expected ELOOP, got %q & %v", test.root, test.unsafe, got, err) + continue + } + } +} + +// Make sure that ENOTDIR is correctly handled. +func TestEnotdir(t *testing.T) { + dir := expandedTempDir(t) + + mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) + writeFile(t, filepath.Join(dir, "notdir"), []byte("I am not a directory!"), 0o755) + symlink(t, "/../../../notdir/somechild", filepath.Join(dir, "subdir", "link")) + + for _, test := range []struct { + root, unsafe string + }{ + {dir, "subdir/link"}, + {dir, "notdir"}, + {dir, "notdir/child"}, + } { + _, err := SecureJoin(test.root, test.unsafe) + if err != nil { + t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + continue + } + } +} + +type mockVFS struct { + lstat func(path string) (os.FileInfo, error) + readlink func(path string) (string, error) +} + +func (m mockVFS) Lstat(path string) (os.FileInfo, error) { return m.lstat(path) } +func (m mockVFS) Readlink(path string) (string, error) { return m.readlink(path) } + +// Make sure that SecureJoinVFS actually does use the given VFS interface. +func TestSecureJoinVFS(t *testing.T) { + dir := expandedTempDir(t) + + mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) + mkdirAll(t, filepath.Join(dir, "cousinparent", "cousin"), 0o755) + symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link")) + symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) + symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) + + for _, test := range []input{ + {dir, "subdir", filepath.Join(dir, "subdir")}, + {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/../test", filepath.Join(dir, "test")}, + // This is the divergence from a simple filepath.Clean implementation. + {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")}, + {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")}, + {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, + } { + var nLstat, nReadlink int + mock := mockVFS{ + lstat: func(path string) (os.FileInfo, error) { nLstat++; return os.Lstat(path) }, + readlink: func(path string) (string, error) { nReadlink++; return os.Readlink(path) }, + } + + got, err := SecureJoinVFS(test.root, test.unsafe, mock) + if err != nil { + t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + continue + } + if got != test.expected { + t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) + continue + } + if nLstat == 0 && nReadlink == 0 { + t.Errorf("securejoin(%q, %q): expected to use either lstat or readlink, neither were used", test.root, test.unsafe) + } + } +} + +// Make sure that SecureJoinVFS actually does use the given VFS interface, and +// that errors are correctly propagated. +func TestSecureJoinVFSErrors(t *testing.T) { + var ( + lstatErr = errors.New("lstat error") + readlinkErr = errors.New("readlink err") + ) + + dir := expandedTempDir(t) + + // Make a link. + symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "link")) + + // Define some fake mock functions. + lstatFailFn := func(string) (os.FileInfo, error) { return nil, lstatErr } + readlinkFailFn := func(string) (string, error) { return "", readlinkErr } + + // Make sure that the set of {lstat, readlink} failures do propagate. + for idx, test := range []struct { + vfs VFS + expected []error + }{ + { + expected: []error{nil}, + vfs: mockVFS{ + lstat: os.Lstat, + readlink: os.Readlink, + }, + }, + { + expected: []error{lstatErr}, + vfs: mockVFS{ + lstat: lstatFailFn, + readlink: os.Readlink, + }, + }, + { + expected: []error{readlinkErr}, + vfs: mockVFS{ + lstat: os.Lstat, + readlink: readlinkFailFn, + }, + }, + { + expected: []error{lstatErr, readlinkErr}, + vfs: mockVFS{ + lstat: lstatFailFn, + readlink: readlinkFailFn, + }, + }, + } { + _, err := SecureJoinVFS(dir, "link", test.vfs) + + success := false + for _, exp := range test.expected { + if errors.Is(err, exp) { + success = true + } + } + if !success { + t.Errorf("SecureJoinVFS.mock%d: expected to get lstatError, got %v", idx, err) + } + } +}