From ba9bfa97726aa41d30d634605a70fc897e1160d5 Mon Sep 17 00:00:00 2001
From: ansschh
Date: Sat, 9 Aug 2025 11:25:51 -0700
Subject: [PATCH 1/3] feat: Add type hints and validation to core utility modules

- Add return type hint to get_tokenizer()
- Add type hints and checkpoint validation to generate.py main()
- Add parameter type hints to suppress_output() in torch/utils.py

Improves IDE support and catches potential bugs early.
---
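Notes: the path check runs before any backend-specific import, so a bad
checkpoint argument fails fast with a readable error instead of a deep
loader traceback. A minimal sketch of the intended behavior; the
"model/missing" path is hypothetical:

    # Hypothetical usage sketch; assumes no checkpoint exists at this path.
    import argparse
    from gpt_oss.generate import main

    args = argparse.Namespace(checkpoint="model/missing", backend="torch")
    try:
        main(args)  # raises before any model weights are touched
    except FileNotFoundError as exc:
        print(exc)  # Checkpoint path does not exist: model/missing
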
 gpt_oss/generate.py    | 8 +++++++-
 gpt_oss/tokenizer.py   | 3 ++-
 gpt_oss/torch/utils.py | 2 +-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index dfaaa6f1..87ee42d1 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -4,11 +4,17 @@
 # torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" model/
 
 import argparse
+from pathlib import Path
 
 from gpt_oss.tokenizer import get_tokenizer
 
 
-def main(args):
+def main(args: argparse.Namespace) -> None:
+    # Validate checkpoint path exists
+    checkpoint_path = Path(args.checkpoint)
+    if not checkpoint_path.exists():
+        raise FileNotFoundError(f"Checkpoint path does not exist: {args.checkpoint}")
+
     match args.backend:
         case "torch":
             from gpt_oss.torch.utils import init_distributed
diff --git a/gpt_oss/tokenizer.py b/gpt_oss/tokenizer.py
index 866077f5..05c1080d 100644
--- a/gpt_oss/tokenizer.py
+++ b/gpt_oss/tokenizer.py
@@ -1,6 +1,7 @@
 import tiktoken
 
-def get_tokenizer():
+
+def get_tokenizer() -> tiktoken.Encoding:
     o200k_base = tiktoken.get_encoding("o200k_base")
     tokenizer = tiktoken.Encoding(
         name="o200k_harmony",
diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py
index ce87a85d..680f4afb 100644
--- a/gpt_oss/torch/utils.py
+++ b/gpt_oss/torch/utils.py
@@ -3,7 +3,7 @@
 import torch.distributed as dist
 
 
-def suppress_output(rank):
+def suppress_output(rank: int) -> None:
     """Suppress printing on the current device. Force printing with `force=True`."""
     import builtins as __builtin__
     builtin_print = __builtin__.print

From 006cba74f68dd5c38b66de05ee558168f3e04da0 Mon Sep 17 00:00:00 2001
From: ansschh
Date: Fri, 15 Aug 2025 02:02:54 -0700
Subject: [PATCH 2/3] feat: Add comprehensive error handling for CUDA device initialization

- Add CUDA availability check before device initialization
- Validate rank against available CUDA device count
- Add device accessibility testing with clear error messages
- Add error handling for distributed communication setup
- Add cleanup for failed distributed process group initialization
- Provide helpful error messages with troubleshooting guidance

This prevents cryptic CUDA errors and provides clear feedback when:
- CUDA is not available
- An invalid device rank is specified
- Device access fails
- Distributed communication fails
---
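Notes: init_distributed() reads WORLD_SIZE and RANK from the environment,
so these checks surface misconfiguration before NCCL produces an opaque
failure. A minimal sketch of the rank check; it assumes a CUDA-capable
machine with exactly one visible device:

    # Hypothetical repro sketch; assumes torch.cuda.device_count() == 1.
    import os
    from gpt_oss.torch.utils import init_distributed

    os.environ["WORLD_SIZE"] = "1"
    os.environ["RANK"] = "4"  # out of range on a single-GPU machine
    try:
        init_distributed()
    except RuntimeError as exc:
        print(exc)  # Rank 4 exceeds available CUDA devices (1). ...
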
 gpt_oss/torch/utils.py | 63 ++++++++++++++++++++++++++++++++++--------
 1 file changed, 51 insertions(+), 12 deletions(-)

diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py
index 680f4afb..47edf5c9 100644
--- a/gpt_oss/torch/utils.py
+++ b/gpt_oss/torch/utils.py
@@ -20,21 +20,60 @@ def print(*args, **kwargs):
 
 def init_distributed() -> torch.device:
     """Initialize the model for distributed inference."""
+    # Check CUDA availability
+    if not torch.cuda.is_available():
+        raise RuntimeError(
+            "CUDA is not available. Please ensure CUDA is installed and accessible."
+        )
+
     # Initialize distributed inference
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     rank = int(os.environ.get("RANK", 0))
-    if world_size > 1:
-        dist.init_process_group(
-            backend="nccl", init_method="env://", world_size=world_size, rank=rank
+
+    # Validate rank against available devices
+    if rank >= torch.cuda.device_count():
+        raise RuntimeError(
+            f"Rank {rank} exceeds available CUDA devices ({torch.cuda.device_count()}). "
+            f"Please set RANK to a value between 0 and {torch.cuda.device_count() - 1}."
         )
-    torch.cuda.set_device(rank)
-    device = torch.device(f"cuda:{rank}")
+
+    try:
+        if world_size > 1:
+            dist.init_process_group(
+                backend="nccl", init_method="env://", world_size=world_size, rank=rank
+            )
+
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        # Test device accessibility
+        try:
+            torch.cuda.get_device_properties(device)
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Failed to access CUDA device {rank}: {e}. "
+                "Please check device availability and permissions."
+            ) from e
 
-    # Warm up NCCL to avoid first-time latency
-    if world_size > 1:
-        x = torch.ones(1, device=device)
-        dist.all_reduce(x)
-        torch.cuda.synchronize(device)
+        # Warm up NCCL to avoid first-time latency
+        if world_size > 1:
+            try:
+                x = torch.ones(1, device=device)
+                dist.all_reduce(x)
+                torch.cuda.synchronize(device)
+            except RuntimeError as e:
+                raise RuntimeError(
+                    f"Failed to initialize distributed communication on device {rank}: {e}"
+                ) from e
 
-    suppress_output(rank)
-    return device
+        suppress_output(rank)
+        return device
+
+    except Exception:
+        # Clean up distributed process group if initialization failed
+        if world_size > 1 and dist.is_initialized():
+            try:
+                dist.destroy_process_group()
+            except Exception:
+                pass  # Ignore cleanup errors
+        raise

From 17bef25aa1362863a2fb5f2bb53f5d19c91b0db7 Mon Sep 17 00:00:00 2001
From: ansschh
Date: Thu, 2 Oct 2025 16:29:51 -0700
Subject: [PATCH 3/3] Add comprehensive unit tests for tokenizer

- Test basic encoding/decoding functionality
- Test Harmony special token handling
- Test round-trip consistency
- Test edge cases and error handling
- Verify reserved token range
---
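Notes: the suite depends only on pytest and tiktoken and needs no GPU.
Assuming it is run from the repository root:

    pytest tests/test_tokenizer.py -v
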
 tests/test_tokenizer.py | 212 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 212 insertions(+)
 create mode 100644 tests/test_tokenizer.py

diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
new file mode 100644
index 00000000..645e13b0
--- /dev/null
+++ b/tests/test_tokenizer.py
@@ -0,0 +1,212 @@
+"""Unit tests for tokenizer encoding/decoding functionality."""
+
+import pytest
+from gpt_oss.tokenizer import get_tokenizer
+
+
+class TestTokenizerBasics:
+    """Test basic tokenizer functionality."""
+
+    def test_get_tokenizer_returns_encoding(self):
+        """Test that get_tokenizer returns a valid encoding."""
+        tokenizer = get_tokenizer()
+        assert tokenizer is not None
+        assert tokenizer.name == "o200k_harmony"
+
+    def test_tokenizer_has_harmony_special_tokens(self):
+        """Test that tokenizer includes Harmony special tokens."""
+        tokenizer = get_tokenizer()
+        special_tokens = tokenizer._special_tokens
+
+        # Verify key Harmony tokens are present
+        assert "<|channel|>" in special_tokens
+        assert special_tokens["<|channel|>"] == 200005
+        assert "<|start|>" in special_tokens
+        assert special_tokens["<|start|>"] == 200006
+        assert "<|end|>" in special_tokens
+        assert special_tokens["<|end|>"] == 200007
+        assert "<|message|>" in special_tokens
+        assert special_tokens["<|message|>"] == 200008
+        assert "<|call|>" in special_tokens
+        assert special_tokens["<|call|>"] == 200012
+        assert "<|return|>" in special_tokens
+        assert special_tokens["<|return|>"] == 200002
+
+    def test_tokenizer_has_reserved_tokens(self):
+        """Test that tokenizer includes reserved token range."""
+        tokenizer = get_tokenizer()
+        special_tokens = tokenizer._special_tokens
+
+        # Check reserved tokens exist in range
+        assert "<|reserved_200013|>" in special_tokens
+        assert special_tokens["<|reserved_200013|>"] == 200013
+        assert "<|reserved_201087|>" in special_tokens
+        assert special_tokens["<|reserved_201087|>"] == 201087
+
+
+class TestTokenizerEncoding:
+    """Test tokenizer encoding functionality."""
+
+    def test_encode_simple_text(self):
+        """Test encoding simple text."""
+        tokenizer = get_tokenizer()
+        text = "Hello, world!"
+        tokens = tokenizer.encode(text)
+
+        assert isinstance(tokens, list)
+        assert len(tokens) > 0
+        assert all(isinstance(t, int) for t in tokens)
+
+    def test_encode_special_tokens(self):
+        """Test encoding text with special tokens."""
+        tokenizer = get_tokenizer()
+        text = "<|channel|>final<|message|>Hello<|return|>"
+        tokens = tokenizer.encode(text, allowed_special="all")
+
+        assert 200005 in tokens  # <|channel|>
+        assert 200008 in tokens  # <|message|>
+        assert 200002 in tokens  # <|return|>
+
+    def test_encode_without_special_allowed_raises(self):
+        """Test that encoding special tokens without permission raises error."""
+        tokenizer = get_tokenizer()
+        text = "<|channel|>test"
+
+        with pytest.raises(ValueError):
+            tokenizer.encode(text)
+
+    def test_encode_empty_string(self):
+        """Test encoding empty string."""
+        tokenizer = get_tokenizer()
+        tokens = tokenizer.encode("")
+
+        assert isinstance(tokens, list)
+        assert len(tokens) == 0
+
+    def test_encode_unicode_text(self):
+        """Test encoding unicode text."""
+        tokenizer = get_tokenizer()
+        text = "Hello δΈ–η•Œ 🌍"
+        tokens = tokenizer.encode(text)
+
+        assert isinstance(tokens, list)
+        assert len(tokens) > 0
+
+
+class TestTokenizerDecoding:
+    """Test tokenizer decoding functionality."""
+
+    def test_decode_simple_tokens(self):
+        """Test decoding simple tokens."""
+        tokenizer = get_tokenizer()
+        text = "Hello, world!"
+        tokens = tokenizer.encode(text)
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+    def test_decode_with_special_tokens(self):
+        """Test decoding tokens including special tokens."""
+        tokenizer = get_tokenizer()
+        text = "<|channel|>final<|message|>Hello<|return|>"
+        tokens = tokenizer.encode(text, allowed_special="all")
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+    def test_decode_empty_list(self):
+        """Test decoding empty token list."""
+        tokenizer = get_tokenizer()
+        decoded = tokenizer.decode([])
+
+        assert decoded == ""
+
+    def test_decode_single_token(self):
+        """Test decoding single token."""
+        tokenizer = get_tokenizer()
+        tokens = tokenizer.encode("a")
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == "a"
+
+    def test_decode_unicode(self):
+        """Test decoding unicode tokens."""
+        tokenizer = get_tokenizer()
+        text = "Hello δΈ–η•Œ 🌍"
+        tokens = tokenizer.encode(text)
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+
+class TestTokenizerRoundTrip:
+    """Test encode/decode round-trip consistency."""
+
+    @pytest.mark.parametrize("text", [
+        "Simple text",
+        "Text with numbers: 123456",
+        "Special chars: !@#$%^&*()",
+        "Unicode: δ½ ε₯½δΈ–η•Œ",
+        "Emoji: πŸš€πŸŒŸπŸ’‘",
+        "Mixed: Hello δΈ–η•Œ 123 πŸŽ‰",
+        "Newlines:\nand\ttabs",
+    ])
+    def test_roundtrip_consistency(self, text):
+        """Test that encode->decode returns original text."""
+        tokenizer = get_tokenizer()
+        tokens = tokenizer.encode(text)
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+    def test_roundtrip_with_harmony_format(self):
+        """Test round-trip with Harmony message format."""
+        tokenizer = get_tokenizer()
+        text = "<|channel|>analysis<|start|><|message|>Thinking...<|end|><|channel|>final<|message|>Answer<|return|>"
+        tokens = tokenizer.encode(text, allowed_special="all")
+        decoded = tokenizer.decode(tokens)
+
+        assert decoded == text
+
+
+class TestTokenizerEdgeCases:
+    """Test edge cases and error handling."""
+
+    def test_encode_very_long_text(self):
+        """Test encoding very long text."""
+        tokenizer = get_tokenizer()
+        text = "a" * 10000
+        tokens = tokenizer.encode(text)
+
+        assert isinstance(tokens, list)
+        assert len(tokens) > 0
+
+    def test_decode_special_token_ids(self):
+        """Test decoding raw special token IDs."""
+        tokenizer = get_tokenizer()
+        # Use valid token IDs from the special tokens
+        tokens = [200005, 200006, 200007]
+        decoded = tokenizer.decode(tokens)
+
+        assert isinstance(decoded, str)
+
+    def test_multiple_tokenizer_instances_consistent(self):
+        """Test that multiple tokenizer instances behave consistently."""
+        tokenizer1 = get_tokenizer()
+        tokenizer2 = get_tokenizer()
+
+        text = "Test consistency"
+        tokens1 = tokenizer1.encode(text)
+        tokens2 = tokenizer2.encode(text)
+
+        assert tokens1 == tokens2
+
+    def test_special_token_ids_immutable(self):
+        """Test that special token IDs are consistent."""
+        tokenizer = get_tokenizer()
+
+        # Get special tokens multiple times
+        channel_id_1 = tokenizer.encode("<|channel|>", allowed_special="all")[0]
+        channel_id_2 = tokenizer.encode("<|channel|>", allowed_special="all")[0]
+
+        assert channel_id_1 == channel_id_2 == 200005