| | """Comprehensive tests for model architectures.""" |
| |
|
| | import pytest |
| | import torch |
| | import torch.nn as nn |
| | from models.simple_classifier import SimpleClassifier |
| | from models.cnn_classifier import CNNClassifier |
| |
|
| |
|
class TestSimpleClassifier:
    """Comprehensive tests for SimpleClassifier."""

    @staticmethod
    def _build(use_snippet):
        # Shared factory: every test here uses the same base configuration
        # (vocab 1000, embedding 100, output 50) and varies only the flag.
        return SimpleClassifier(
            vocab_size=1000,
            embedding_dim=100,
            output_dim=50,
            use_snippet=use_snippet,
        )

    def test_initialization_title_only(self):
        """A title-only model must not own any snippet-specific layers."""
        clf = self._build(use_snippet=False)

        assert hasattr(clf, 'title_embedding')
        assert hasattr(clf, 'fc')
        assert not hasattr(clf, 'snippet_embedding')
        assert not hasattr(clf, 'linear1')
        assert clf.use_snippet is False

    def test_initialization_with_snippet(self):
        """Enabling snippets creates the snippet branch and both linear layers."""
        clf = self._build(use_snippet=True)

        assert hasattr(clf, 'title_embedding')
        assert hasattr(clf, 'snippet_embedding')
        assert hasattr(clf, 'linear1')
        assert hasattr(clf, 'linear2')
        assert clf.use_snippet is True

    def test_forward_title_only(self):
        """Forward pass on titles alone yields finite (batch, 50) outputs."""
        clf = self._build(use_snippet=False)
        n_rows, n_tokens = 4, 20
        titles = torch.randint(0, 1000, (n_rows, n_tokens))

        logits = clf(titles)

        assert logits.shape == (n_rows, 50)
        assert not torch.isnan(logits).any()
        assert not torch.isinf(logits).any()

    def test_forward_with_snippet(self):
        """Forward pass with title and snippet yields finite (batch, 50) outputs."""
        clf = self._build(use_snippet=True)
        n_rows = 4
        titles = torch.randint(0, 1000, (n_rows, 20))
        snippets = torch.randint(0, 1000, (n_rows, 50))

        logits = clf(titles, snippets)

        assert logits.shape == (n_rows, 50)
        assert not torch.isnan(logits).any()
        assert not torch.isinf(logits).any()

    def test_forward_gradient_flow(self):
        """Backprop populates gradients on the embedding and output layers."""
        clf = self._build(use_snippet=False)
        titles = torch.randint(0, 1000, (2, 20))

        clf(titles).sum().backward()

        assert clf.title_embedding.weight.grad is not None
        assert clf.fc.weight.grad is not None

    def test_different_batch_sizes(self):
        """The batch dimension is carried through unchanged for several sizes."""
        clf = self._build(use_snippet=False)

        for n_rows in (1, 4, 16, 32):
            logits = clf(torch.randint(0, 1000, (n_rows, 20)))
            assert logits.shape[0] == n_rows

    def test_different_sequence_lengths(self):
        """Output shape is independent of the title sequence length."""
        clf = self._build(use_snippet=False)

        for n_tokens in (5, 10, 20, 50):
            logits = clf(torch.randint(0, 1000, (4, n_tokens)))
            assert logits.shape == (4, 50)
| |
|
| |
|
class TestCNNClassifier:
    """Comprehensive tests for CNNClassifier.

    Every test builds the model with vocab_size=1000, embedding_dim=100,
    output_dim=50 and fixed max lengths (title=20, snippet=50), so inputs
    are always shaped (batch, 20) and (batch, 50).
    """

    @staticmethod
    def _build(**overrides):
        """Construct a CNNClassifier from the shared base config.

        Keyword overrides (e.g. ``dropout=0.5``, ``conv_channels=[...]``)
        are merged on top of the defaults so each test states only what
        it varies.
        """
        config = dict(
            vocab_size=1000,
            embedding_dim=100,
            output_dim=50,
            max_title_len=20,
            max_snippet_len=50,
        )
        config.update(overrides)
        return CNNClassifier(**config)

    def test_initialization(self):
        """Test that construction creates both branches and the classifier head."""
        model = self._build()

        assert hasattr(model, 'title_embedding')
        assert hasattr(model, 'snippet_embedding')
        assert hasattr(model, 'title_conv_layers')
        assert hasattr(model, 'snippet_conv_layers')
        assert hasattr(model, 'classifier')

    def test_forward_pass(self):
        """Test that the CNN forward pass yields finite (batch, 50) outputs."""
        model = self._build()

        batch_size = 4
        title = torch.randint(0, 1000, (batch_size, 20))
        snippet = torch.randint(0, 1000, (batch_size, 50))

        output = model(title, snippet)

        assert output.shape == (batch_size, 50)
        assert not torch.isnan(output).any()
        assert not torch.isinf(output).any()

    def test_custom_conv_layers(self):
        """Test CNN with a custom convolution stack (channels and kernels)."""
        model = self._build(
            conv_channels=[64, 128, 256],
            kernel_sizes=[3, 5, 3],
        )

        title = torch.randint(0, 1000, (2, 20))
        snippet = torch.randint(0, 1000, (2, 50))

        output = model(title, snippet)
        assert output.shape == (2, 50)

    def test_different_input_sizes(self):
        """Test CNN with expected sequence lengths across batch sizes.

        Sequence lengths are pinned by max_title_len/max_snippet_len, so the
        batch dimension is the only axis expected to vary. (The previous
        version of this test ran the exact same shapes twice — a copy-paste
        duplicate; varying the batch size restores its intent.)
        """
        model = self._build()

        for batch_size in (2, 8):
            title = torch.randint(0, 1000, (batch_size, 20))
            snippet = torch.randint(0, 1000, (batch_size, 50))
            output = model(title, snippet)
            assert output.shape == (batch_size, 50)

    def test_gradient_flow(self):
        """Test that backprop reaches both embeddings and the classifier head."""
        model = self._build()

        title = torch.randint(0, 1000, (2, 20))
        snippet = torch.randint(0, 1000, (2, 50))
        output = model(title, snippet)

        output.sum().backward()

        assert model.title_embedding.weight.grad is not None
        assert model.snippet_embedding.weight.grad is not None
        # classifier[1] is assumed to be the first Linear in the Sequential
        # head (index 0 presumably dropout/flatten) — confirm against model.
        assert model.classifier[1].weight.grad is not None

    def test_dropout_training_mode(self):
        """Test that dropout is active in training mode.

        With dropout=0.5 in train(), repeated forward passes on identical
        inputs should not all be numerically equal.
        """
        model = self._build(dropout=0.5)

        model.train()
        title = torch.randint(0, 1000, (4, 20))
        snippet = torch.randint(0, 1000, (4, 50))

        outputs = [model(title, snippet) for _ in range(5)]

        assert not all(torch.allclose(outputs[0], out) for out in outputs[1:])

    def test_dropout_eval_mode(self):
        """Test that dropout is disabled in eval mode.

        In eval(), the model must be deterministic: identical inputs yield
        identical outputs across repeated forward passes.
        """
        model = self._build(dropout=0.5)

        model.eval()
        title = torch.randint(0, 1000, (4, 20))
        snippet = torch.randint(0, 1000, (4, 50))

        outputs = [model(title, snippet) for _ in range(5)]

        assert all(torch.allclose(outputs[0], out) for out in outputs[1:])
| |
|
| |
|
class TestModelConsistency:
    """Tests for model consistency and edge cases."""

    @staticmethod
    def _title_only(vocab_size=1000, embedding_dim=100, output_dim=50):
        # All edge-case tests use the title-only SimpleClassifier variant.
        return SimpleClassifier(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            output_dim=output_dim,
            use_snippet=False,
        )

    def test_models_handle_empty_batch(self):
        """An empty batch is either processed (empty output) or rejected."""
        clf = self._title_only()

        empty_titles = torch.randint(0, 1000, (0, 20))
        try:
            result = clf(empty_titles)
            # Forward pass succeeded: the batch axis must stay empty.
            assert result.shape[0] == 0
        except (RuntimeError, IndexError):
            # Rejecting an empty batch outright is also acceptable.
            pass

    def test_models_handle_single_sample(self):
        """A batch of exactly one sample produces a (1, 50) output."""
        clf = self._title_only()

        result = clf(torch.randint(0, 1000, (1, 20)))

        assert result.shape == (1, 50)

    def test_models_handle_large_batch(self):
        """A large batch of 128 samples produces a (128, 50) output."""
        clf = self._title_only()

        result = clf(torch.randint(0, 1000, (128, 20)))

        assert result.shape == (128, 50)

    def test_model_parameters_count(self):
        """The parameter count lands in the expected 3M–4M window.

        With vocab 10000 x dim 300 the embedding alone contributes 3.0M
        weights; the remaining layers presumably add a few hundred thousand
        more, keeping the total under 4M.
        """
        clf = self._title_only(
            vocab_size=10000,
            embedding_dim=300,
            output_dim=1000,
        )

        n_params = sum(p.numel() for p in clf.parameters())

        assert 3_000_000 < n_params < 4_000_000
| |
|
| |
|