5 changes: 4 additions & 1 deletion .gitignore
@@ -109,4 +109,7 @@ venv.bak/
 /site
 
 # mypy
-.mypy_cache/
+.mypy_cache/
+
+# Claude
+.claude/*
1,630 changes: 1,630 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

99 changes: 99 additions & 0 deletions pyproject.toml
@@ -0,0 +1,99 @@
[tool.poetry]
name = "gan-models"
version = "0.1.0"
description = "A collection of GAN models for text generation"
authors = ["Your Name <your.email@example.com>"]
readme = "README.md"
license = "LICENSE"
packages = [
    { include = "models" },
    { include = "metrics" },
    { include = "utils" },
    { include = "instructor" },
    { include = "run" },
    { include = "visual" }
]

[tool.poetry.dependencies]
python = "^3.8"
torch = ">=1.0.0"
numpy = "^1.21.0"
nltk = ">=3.4.5"
tqdm = "4.32.1"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.0"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"

[tool.pytest.ini_options]
minversion = "7.0"
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py", "tests.py"]
python_classes = ["Test*", "*Tests"]
python_functions = ["test_*"]
addopts = [
    "-ra",
    "--strict-markers",
    "--strict-config",
    "--cov=models",
    "--cov=metrics",
    "--cov=utils",
    "--cov=instructor",
    "--cov=run",
    "--cov=visual",
    "--cov-branch",
    "--cov-report=term-missing:skip-covered",
    "--cov-report=html",
    "--cov-report=xml",
    "--cov-fail-under=0",
]
markers = [
    "unit: Unit tests",
    "integration: Integration tests",
    "slow: Slow tests",
]

[tool.coverage.run]
source = ["models", "metrics", "utils", "instructor", "run", "visual"]
omit = [
    "*/tests/*",
    "*/__pycache__/*",
    "*/site-packages/*",
    "setup.py",
    "config.py",
    "main.py",
]

[tool.coverage.report]
precision = 2
show_missing = true
skip_covered = false
fail_under = 0
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "if self.debug:",
    "if settings.DEBUG",
    "raise AssertionError",
    "raise NotImplementedError",
    "if 0:",
    "if __name__ == .__main__.:",
    "if TYPE_CHECKING:",
    "class .*\\bProtocol\\):",
    "@(abc\\.)?abstractmethod",
]

[tool.coverage.html]
directory = "htmlcov"

[tool.coverage.xml]
output = "coverage.xml"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
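
A note on the [tool.poetry.scripts] entries above: "pytest:main" points at the main callable in the pytest package, so poetry run test and poetry run tests both launch the suite. A minimal sketch of what the generated console script effectively does (the exact wrapper Poetry emits may differ):

import sys
import pytest

# "pytest:main" resolves to pytest.main; the entry point calls it
# and exits with the return code it produces.
sys.exit(pytest.main())
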
Empty file added tests/__init__.py
Empty file.
167 changes: 167 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,167 @@
import pytest
import tempfile
import shutil
import torch
import numpy as np
from pathlib import Path


@pytest.fixture
def temp_dir():
    """Create a temporary directory that is cleaned up after the test."""
    temp_path = tempfile.mkdtemp()
    yield Path(temp_path)
    shutil.rmtree(temp_path)


@pytest.fixture
def mock_config():
    """Provide a mock configuration object for testing."""
    class MockConfig:
        def __init__(self):
            self.vocab_size = 5000
            self.gen_embed_dim = 32
            self.gen_hidden_dim = 32
            self.dis_embed_dim = 32
            self.dis_hidden_dim = 32
            self.max_seq_len = 20
            self.batch_size = 64
            self.num_rep = 64
            self.gen_lr = 0.01
            self.dis_lr = 0.01
            self.update_rate = 0.8
            self.temperature = 1.0
            self.training_data = "oracle"
            self.dataset = "oracle"
            self.model_type = "vanilla"
            self.gen_init = "normal"
            self.dis_init = "uniform"
            self.eval_type = "standard"
            self.tips = 2000
            self.temp_adpt = "exp"
            self.oracle_pretrain = True
            self.dis_pretrain = True
            self.adv_g_step = 1
            self.rollout_num = 16
            self.gen_pretrain_steps = 120
            self.dis_pretrain_steps = 50
            self.log_file = "log/test_log.txt"
            self.save_root = "save/test/"
            self.signal_file = "run_signal.txt"

    return MockConfig()


@pytest.fixture
def sample_tensor():
    """Provide a sample tensor for testing."""
    return torch.randn(64, 20)  # batch_size x seq_len


@pytest.fixture
def sample_numpy_array():
    """Provide a sample numpy array for testing."""
    return np.random.randn(64, 20)


@pytest.fixture
def mock_data_loader():
    """Provide a mock data loader for testing."""
    class MockDataLoader:
        def __init__(self, batch_size=64):
            self.batch_size = batch_size
            self.num_batch = 100

        def __iter__(self):
            for _ in range(self.num_batch):
                yield torch.randint(0, 5000, (self.batch_size, 20))

        def __len__(self):
            return self.num_batch

        def reset_pointer(self):
            pass

    return MockDataLoader()


@pytest.fixture
def mock_generator():
    """Provide a mock generator model for testing."""
    class MockGenerator(torch.nn.Module):
        def __init__(self, vocab_size=5000, embed_dim=32, hidden_dim=32):
            super().__init__()
            self.vocab_size = vocab_size
            self.embed_dim = embed_dim
            self.hidden_dim = hidden_dim

        def forward(self, x):
            batch_size = x.size(0) if x.dim() > 0 else 1
            return torch.randn(batch_size, 20, self.vocab_size)

        def sample(self, batch_size, seq_len):
            return torch.randint(0, self.vocab_size, (batch_size, seq_len))

    return MockGenerator()


@pytest.fixture
def mock_discriminator():
    """Provide a mock discriminator model for testing."""
    class MockDiscriminator(torch.nn.Module):
        def __init__(self, vocab_size=5000, embed_dim=32, hidden_dim=32):
            super().__init__()
            self.vocab_size = vocab_size
            self.embed_dim = embed_dim
            self.hidden_dim = hidden_dim

        def forward(self, x):
            batch_size = x.size(0)
            return torch.rand(batch_size, 1)

    return MockDiscriminator()


@pytest.fixture
def clean_logs(temp_dir):
    """Ensure log directories are clean for testing."""
    log_dir = temp_dir / "log"
    save_dir = temp_dir / "save"
    log_dir.mkdir(exist_ok=True)
    save_dir.mkdir(exist_ok=True)
    yield log_dir, save_dir


@pytest.fixture
def mock_oracle_data(temp_dir):
    """Create mock oracle data for testing."""
    oracle_file = temp_dir / "oracle.txt"
    with open(oracle_file, 'w') as f:
        for _ in range(10000):
            # Generate random sequences of integers
            seq = ' '.join(str(np.random.randint(0, 5000)) for _ in range(20))
            f.write(seq + '\n')
    return oracle_file


@pytest.fixture(autouse=True)
def setup_random_seeds():
    """Set random seeds for reproducibility in tests."""
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)


@pytest.fixture
def gpu_available():
    """Check if GPU is available for testing."""
    return torch.cuda.is_available()


@pytest.fixture
def device():
    """Provide the appropriate device for testing."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
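
A usage sketch (illustrative, not part of the diff): any test under tests/ can request these fixtures by argument name and pytest injects them from conftest.py. For example, a unit test exercising the mock generator could look like:

import pytest

@pytest.mark.unit
def test_generator_sampling(mock_generator, mock_config, device):
    # Fixture values come from conftest.py; the autouse seed fixture
    # makes the sampled tensor reproducible across runs.
    samples = mock_generator.sample(mock_config.batch_size, mock_config.max_seq_len)
    assert samples.shape == (mock_config.batch_size, mock_config.max_seq_len)
    assert int(samples.max()) < mock_config.vocab_size
    assert device.type in ("cuda", "cpu")
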
Empty file added tests/integration/__init__.py
Empty file.
109 changes: 109 additions & 0 deletions tests/test_setup_validation.py
@@ -0,0 +1,109 @@
import pytest
import sys
from pathlib import Path


def test_pytest_is_installed():
    """Test that pytest is properly installed."""
    assert pytest.__version__


def test_project_structure_exists():
    """Test that the basic project structure exists."""
    project_root = Path(__file__).parent.parent

    # Check main directories
    assert (project_root / "models").exists()
    assert (project_root / "metrics").exists()
    assert (project_root / "utils").exists()
    assert (project_root / "instructor").exists()
    assert (project_root / "run").exists()
    assert (project_root / "visual").exists()

    # Check test directories
    assert (project_root / "tests").exists()
    assert (project_root / "tests" / "unit").exists()
    assert (project_root / "tests" / "integration").exists()


def test_config_imports():
    """Test that the required third-party packages are importable."""
    try:
        import torch
        import numpy
        import nltk
        import tqdm
    except ImportError as e:
        pytest.fail(f"Failed to import required package: {e}")


def test_fixtures_available(temp_dir, mock_config, sample_tensor):
    """Test that pytest fixtures from conftest.py are available."""
    assert temp_dir.exists()
    assert mock_config.vocab_size == 5000
    assert sample_tensor.shape == (64, 20)


@pytest.mark.unit
def test_unit_marker():
    """Test that the unit marker works."""
    assert True


@pytest.mark.integration
def test_integration_marker():
    """Test that the integration marker works."""
    assert True


@pytest.mark.slow
def test_slow_marker():
    """Test that the slow marker works."""
    import time
    time.sleep(0.1)
    assert True


def test_coverage_target_modules_importable():
    """Test that modules targeted for coverage are importable."""
    modules_to_test = [
        'models',
        'metrics',
        'utils',
        'instructor',
        'run',
        'visual',
    ]

    project_root = Path(__file__).parent.parent
    sys.path.insert(0, str(project_root))

    for module in modules_to_test:
        try:
            __import__(module)
        except Exception:
            # Some modules may have unmet dependencies; that's acceptable
            # for this smoke test.
            pass


def test_mock_fixtures(mock_generator, mock_discriminator, mock_data_loader):
    """Test that mock fixtures work properly."""
    # Test generator
    gen_output = mock_generator.sample(32, 20)
    assert gen_output.shape == (32, 20)

    # Test discriminator
    dis_input = gen_output
    dis_output = mock_discriminator(dis_input)
    assert dis_output.shape == (32, 1)

    # Test data loader
    batch_count = 0
    for batch in mock_data_loader:
        batch_count += 1
        assert batch.shape == (64, 20)
        if batch_count >= 5:  # Just test a few batches
            break
    assert batch_count == 5
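
With the markers registered in pyproject.toml, subsets of the suite can be selected with -m on the command line or programmatically. A hedged sketch using pytest's public API (--no-cov comes from pytest-cov and skips coverage collection for a faster local pass):

import pytest

# Run only tests tagged @pytest.mark.unit, without coverage collection.
exit_code = pytest.main(["-m", "unit", "--no-cov"])
raise SystemExit(exit_code)
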
Empty file added tests/unit/__init__.py
Empty file.