diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index dfaaa6f1..87ee42d1 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -4,11 +4,18 @@
 # torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" model/
 
 import argparse
+import os
+from pathlib import Path
 
 from gpt_oss.tokenizer import get_tokenizer
 
 
-def main(args):
+def main(args: argparse.Namespace) -> None:
+    # Validate checkpoint path exists
+    checkpoint_path = Path(args.checkpoint)
+    if not checkpoint_path.exists():
+        raise FileNotFoundError(f"Checkpoint path does not exist: {args.checkpoint}")
+
     match args.backend:
         case "torch":
             from gpt_oss.torch.utils import init_distributed
diff --git a/gpt_oss/tokenizer.py b/gpt_oss/tokenizer.py
index 866077f5..05c1080d 100644
--- a/gpt_oss/tokenizer.py
+++ b/gpt_oss/tokenizer.py
@@ -1,6 +1,7 @@
 import tiktoken
 
-def get_tokenizer():
+
+def get_tokenizer() -> tiktoken.Encoding:
     o200k_base = tiktoken.get_encoding("o200k_base")
     tokenizer = tiktoken.Encoding(
         name="o200k_harmony",
diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py
index ce87a85d..47edf5c9 100644
--- a/gpt_oss/torch/utils.py
+++ b/gpt_oss/torch/utils.py
@@ -3,7 +3,7 @@
 import torch.distributed as dist
 
 
-def suppress_output(rank):
+def suppress_output(rank: int) -> None:
     """Suppress printing on the current device. Force printing with `force=True`."""
     import builtins as __builtin__
     builtin_print = __builtin__.print
@@ -20,21 +20,60 @@ def print(*args, **kwargs):
 def init_distributed() -> torch.device:
     """Initialize the model for distributed inference."""
+    # Check CUDA availability
+    if not torch.cuda.is_available():
+        raise RuntimeError(
+            "CUDA is not available. Please ensure CUDA is installed and accessible."
+        )
+
     # Initialize distributed inference
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     rank = int(os.environ.get("RANK", 0))
-    if world_size > 1:
-        dist.init_process_group(
-            backend="nccl", init_method="env://", world_size=world_size, rank=rank
+
+    # Validate rank against available devices
+    if rank >= torch.cuda.device_count():
+        raise RuntimeError(
+            f"Rank {rank} exceeds available CUDA devices ({torch.cuda.device_count()}). "
+            f"Please set RANK to a value between 0 and {torch.cuda.device_count() - 1}."
         )
-    torch.cuda.set_device(rank)
-    device = torch.device(f"cuda:{rank}")
+
+    try:
+        if world_size > 1:
+            dist.init_process_group(
+                backend="nccl", init_method="env://", world_size=world_size, rank=rank
+            )
+
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        # Test device accessibility
+        try:
+            torch.cuda.get_device_properties(device)
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Failed to access CUDA device {rank}: {e}. "
+                "Please check device availability and permissions."
+            ) from e
 
-    # Warm up NCCL to avoid first-time latency
-    if world_size > 1:
-        x = torch.ones(1, device=device)
-        dist.all_reduce(x)
-        torch.cuda.synchronize(device)
+        # Warm up NCCL to avoid first-time latency
+        if world_size > 1:
+            try:
+                x = torch.ones(1, device=device)
+                dist.all_reduce(x)
+                torch.cuda.synchronize(device)
+            except RuntimeError as e:
+                raise RuntimeError(
+                    f"Failed to initialize distributed communication on device {rank}: {e}"
+                ) from e
 
-    suppress_output(rank)
-    return device
+        suppress_output(rank)
+        return device
+
+    except Exception as e:
+        # Clean up distributed process group if initialization failed
+        if world_size > 1 and dist.is_initialized():
+            try:
+                dist.destroy_process_group()
+            except Exception:
+                pass  # Ignore cleanup errors
+        raise
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
new file mode 100644
index 00000000..645e13b0
--- /dev/null
+++ b/tests/test_tokenizer.py
@@ -0,0 +1,212 @@
+"""Unit tests for tokenizer encoding/decoding functionality."""
+
+import pytest
+from gpt_oss.tokenizer import get_tokenizer
+
+
+class TestTokenizerBasics:
+    """Test basic tokenizer functionality."""
+
+    def test_get_tokenizer_returns_encoding(self):
+        """Test that get_tokenizer returns a valid encoding."""
+        tokenizer = get_tokenizer()
+        assert tokenizer is not None
+        assert tokenizer.name == "o200k_harmony"
+
+    def test_tokenizer_has_harmony_special_tokens(self):
+        """Test that tokenizer includes Harmony special tokens."""
+        tokenizer = get_tokenizer()
+        special_tokens = tokenizer._special_tokens
+
+        # Verify key Harmony tokens are present
+        assert "<|channel|>" in special_tokens
+        assert special_tokens["<|channel|>"] == 200005
+        assert "<|start|>" in special_tokens
+        assert special_tokens["<|start|>"] == 200006
+        assert "<|end|>" in special_tokens
+        assert special_tokens["<|end|>"] == 200007
+        assert "<|message|>" in special_tokens
+        assert special_tokens["<|message|>"] == 200008
+        assert "<|call|>" in special_tokens
+        assert special_tokens["<|call|>"] == 200012
+        assert "<|return|>" in special_tokens
+        assert special_tokens["<|return|>"] == 200002
+
+    def test_tokenizer_has_reserved_tokens(self):
+        """Test that tokenizer includes reserved token range."""
+        tokenizer = get_tokenizer()
+        special_tokens = tokenizer._special_tokens
+
+        # Check reserved tokens exist in range
+        assert "<|reserved_200013|>" in special_tokens
+        assert special_tokens["<|reserved_200013|>"] == 200013
+        assert "<|reserved_201087|>" in special_tokens
+        assert special_tokens["<|reserved_201087|>"] == 201087
+
+
+class TestTokenizerEncoding:
+    """Test tokenizer encoding functionality."""
+
+    def test_encode_simple_text(self):
+        """Test encoding simple text."""
+        tokenizer = get_tokenizer()
+        text = "Hello, world!"
+        tokens = tokenizer.encode(text)
+
+        assert isinstance(tokens, list)
+        assert len(tokens) > 0
+        assert all(isinstance(t, int) for t in tokens)
+
+    def test_encode_special_tokens(self):
+        """Test encoding text with special tokens."""
+        tokenizer = get_tokenizer()
+        text = "<|channel|>final<|message|>Hello<|return|>"
+        tokens = tokenizer.encode(text, allowed_special="all")
+
+        assert 200005 in tokens  # <|channel|>
+        assert 200008 in tokens  # <|message|>
+        assert 200002 in tokens  # <|return|>
+
+    def test_encode_without_special_allowed_raises(self):
+        """Test that encoding special tokens without permission raises error."""
+        tokenizer = get_tokenizer()
+        text = "<|channel|>test"
+
+        with pytest.raises(ValueError):
+            tokenizer.encode(text)
+
+    def test_encode_empty_string(self):
+        """Test encoding empty string."""
+        tokenizer = get_tokenizer()
+        tokens = tokenizer.encode("")
+
+        assert isinstance(tokens, list)
+        assert len(tokens) == 0
+
+    def test_encode_unicode_text(self):
+        """Test encoding unicode text."""
+        tokenizer = get_tokenizer()
+        text = "Hello δΈ–η•Œ 🌍"
+        tokens = tokenizer.encode(text)
+
+        assert isinstance(tokens, list)
+        assert len(tokens) > 0
+
+
+class TestTokenizerDecoding:
+    """Test tokenizer decoding functionality."""
+
+    def test_decode_simple_tokens(self):
+        """Test decoding simple tokens."""
+        tokenizer = get_tokenizer()
+        text = "Hello, world!"
+        tokens = tokenizer.encode(text)
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+    def test_decode_with_special_tokens(self):
+        """Test decoding tokens including special tokens."""
+        tokenizer = get_tokenizer()
+        text = "<|channel|>final<|message|>Hello<|return|>"
+        tokens = tokenizer.encode(text, allowed_special="all")
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+    def test_decode_empty_list(self):
+        """Test decoding empty token list."""
+        tokenizer = get_tokenizer()
+        decoded = tokenizer.decode([])
+
+        assert decoded == ""
+
+    def test_decode_single_token(self):
+        """Test decoding single token."""
+        tokenizer = get_tokenizer()
+        tokens = tokenizer.encode("a")
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == "a"
+
+    def test_decode_unicode(self):
+        """Test decoding unicode tokens."""
+        tokenizer = get_tokenizer()
+        text = "Hello δΈ–η•Œ 🌍"
+        tokens = tokenizer.encode(text)
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+
+class TestTokenizerRoundTrip:
+    """Test encode/decode round-trip consistency."""
+
+    @pytest.mark.parametrize("text", [
+        "Simple text",
+        "Text with numbers: 123456",
+        "Special chars: !@#$%^&*()",
+        "Unicode: δ½ ε₯½δΈ–η•Œ",
+        "Emoji: πŸš€πŸŒŸπŸ’‘",
+        "Mixed: Hello δΈ–η•Œ 123 πŸŽ‰",
+        "Newlines:\nand\ttabs",
+    ])
+    def test_roundtrip_consistency(self, text):
+        """Test that encode->decode returns original text."""
+        tokenizer = get_tokenizer()
+        tokens = tokenizer.encode(text)
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+    def test_roundtrip_with_harmony_format(self):
+        """Test round-trip with Harmony message format."""
+        tokenizer = get_tokenizer()
+        text = "<|channel|>analysis<|start|><|message|>Thinking...<|end|><|channel|>final<|message|>Answer<|return|>"
+        tokens = tokenizer.encode(text, allowed_special="all")
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+
+class TestTokenizerEdgeCases:
+    """Test edge cases and error handling."""
+
+    def test_encode_very_long_text(self):
+        """Test encoding very long text."""
+        tokenizer = get_tokenizer()
+        text = "a" * 10000
+        tokens = tokenizer.encode(text)
+
+        assert isinstance(tokens, list)
+        assert len(tokens) > 0
+
+    def test_decode_invalid_token_ids(self):
+        """Test decoding with potentially invalid token IDs."""
+        tokenizer = get_tokenizer()
+        # Use valid token IDs from the special tokens
+        tokens = [200005, 200006, 200007]
+        decoded = tokenizer.decode(tokens)
+
+        assert isinstance(decoded, str)
+
+    def test_multiple_tokenizer_instances_consistent(self):
+        """Test that multiple tokenizer instances behave consistently."""
+        tokenizer1 = get_tokenizer()
+        tokenizer2 = get_tokenizer()
+
+        text = "Test consistency"
+        tokens1 = tokenizer1.encode(text)
+        tokens2 = tokenizer2.encode(text)
+
+        assert tokens1 == tokens2
+
+    def test_special_token_ids_immutable(self):
+        """Test that special token IDs are consistent."""
+        tokenizer = get_tokenizer()
+
+        # Get special tokens multiple times
+        channel_id_1 = tokenizer.encode("<|channel|>", allowed_special="all")[0]
+        channel_id_2 = tokenizer.encode("<|channel|>", allowed_special="all")[0]
+
+        assert channel_id_1 == channel_id_2 == 200005
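
A quick way to sanity-check the new tokenizer behavior outside pytest is a short round-trip script. This is a minimal sketch, assuming `gpt_oss` and `tiktoken` are installed and importable; it uses only `get_tokenizer`, `encode`, and `decode` as they appear in the diff above:

    # Minimal round-trip check for the o200k_harmony tokenizer.
    # Assumes `gpt_oss` is installed; the token ID matches the tests above.
    from gpt_oss.tokenizer import get_tokenizer

    tokenizer = get_tokenizer()
    text = "<|channel|>final<|message|>Hello<|return|>"
    tokens = tokenizer.encode(text, allowed_special="all")

    assert tokens[0] == 200005  # <|channel|> is matched as a special token
    assert tokenizer.decode(tokens) == text
    print(f"{len(tokens)} tokens: {tokens}")

The full suite runs with `pytest tests/test_tokenizer.py`.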