diff --git a/.gitignore b/.gitignore index c85bf12..3c665fb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,6 @@ __pycache__/ /*.egg-info -/test_data/file0.txt -/test_data/file1.bin /test_data/test.a /dist/ diff --git a/ar/substream.py b/ar/substream.py index b3e371b..9e02c37 100644 --- a/ar/substream.py +++ b/ar/substream.py @@ -1,3 +1,4 @@ +import errno import io class Substream(io.RawIOBase): @@ -8,16 +9,22 @@ def __init__(self, file: io.RawIOBase, start: int, size: int): self.size = size self.position = 0 - def seek(self, offset, origin=0): + def seek(self, offset, origin=0) -> int: if origin == 0: - self.position = offset + position = offset elif origin == 1: - self.position += offset + position = self.position + offset elif origin == 2: - self.position = self.size + offset + position = self.size + offset else: raise ValueError(f"Unexpected origin: {origin}") + if position < 0 or position > self.size: + raise OSError(errno.EINVAL, "Invalid argument") + + self.position = position + return self.position + def seekable(self): return True diff --git a/ar/tests/test_roundtrip.py b/ar/tests/test_roundtrip.py index 91bf802..7620014 100644 --- a/ar/tests/test_roundtrip.py +++ b/ar/tests/test_roundtrip.py @@ -1,5 +1,6 @@ # pylint: disable=redefined-outer-name import subprocess +import tempfile from pathlib import Path import pytest @@ -12,15 +13,23 @@ @pytest.fixture def simple_archive(): - # Create archive - TEST_DATA.mkdir(exist_ok=True) - - (TEST_DATA / 'file0.txt').write_text('Hello') - (TEST_DATA / 'file1.bin').write_bytes(b'\xc3\x28') # invalid utf-8 characters - (TEST_DATA / 'long_file_name_test0.txt').write_text('Hello2') - (TEST_DATA / 'long_file_name_test1.bin').write_bytes(b'\xc3\x28') - subprocess.check_call('ar r test.a file0.txt file1.bin long_file_name_test0.txt long_file_name_test1.bin'.split(), cwd=str(TEST_DATA)) - return TEST_DATA / 'test.a' + archive_path = TEST_DATA / 'test.a' + if archive_path.exists(): + return archive_path + + archive_full_path = archive_path.resolve() + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) + (tmp_path / 'file0.txt').write_text('Hello') + (tmp_path / 'file1.bin').write_bytes(b'\xc3\x28') # invalid utf-8 characters + (tmp_path / 'long_file_name_test0.txt').write_text('Hello2') + (tmp_path / 'long_file_name_test1.bin').write_bytes(b'\xc3\x28') + subprocess.check_call( + ['ar', 'r', str(archive_full_path), 'file0.txt', 'file1.bin', 'long_file_name_test0.txt', 'long_file_name_test1.bin'], + cwd=tmpdir, + ) + + return archive_path def test_list(simple_archive): diff --git a/ar/tests/test_seek.py b/ar/tests/test_seek.py new file mode 100644 index 0000000..08cb052 --- /dev/null +++ b/ar/tests/test_seek.py @@ -0,0 +1,77 @@ +# pylint: disable=redefined-outer-name +import errno +from pathlib import Path + +import pytest + +from ar import Archive + + +ARCHIVE = Path('test_data/linux.a') + + +def test_seekable(): + with ARCHIVE.open('rb') as f: + archive = Archive(f) + file0 = archive.open('file0.txt', 'r') + assert file0.seekable() + + +def test_seek_from_start(): + with ARCHIVE.open('rb') as f: + archive = Archive(f) + file0 = archive.open('file0.txt', 'r') + file0.seek(2) + assert file0.tell() == 2 + assert file0.read(1) == 'l' + + +def test_seek_from_current(): + with ARCHIVE.open('rb') as f: + archive = Archive(f) + file0 = archive.open('file0.txt', 'r') + assert file0.read(1) == 'H' + file0.seek(1, 1) + assert file0.tell() == 2 + assert file0.read(2) == 'll' + + +def test_seek_from_end(): + with ARCHIVE.open('rb') as f: + archive = Archive(f) + file0 = archive.open('file0.txt', 'r') + file_size = 5 + file0.seek(-2, 2) + assert file0.tell() == file_size - 2 + assert file0.read() == 'lo' + + +def test_seek_before_start_raises_oserror(): + with ARCHIVE.open('rb') as f: + archive = Archive(f) + file0 = archive.open('file0.txt', 'r') + with pytest.raises(OSError) as excinfo: + file0.seek(-1) + assert excinfo.value.errno == errno.EINVAL + assert file0.tell() == 0 + + +def test_seek_beyond_end_raises_oserror(): + with ARCHIVE.open('rb') as f: + archive = Archive(f) + file0 = archive.open('file0.txt', 'r') + with pytest.raises(OSError) as excinfo: + file0.seek(6, 0) + assert excinfo.value.errno == errno.EINVAL + assert file0.tell() == 0 + + +def test_seek_from_current_outside_bounds_raises_oserror(): + with ARCHIVE.open('rb') as f: + archive = Archive(f) + file0 = archive.open('file0.txt', 'r') + assert file0.read(1) == 'H' + with pytest.raises(OSError) as excinfo: + file0.seek(-2, 1) + assert excinfo.value.errno == errno.EINVAL + assert file0.tell() == 1