diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py index 9180d493..713082cf 100644 --- a/gpt_oss/torch/model.py +++ b/gpt_oss/torch/model.py @@ -312,7 +312,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: t = self.norm(x) g = self.gate(t) - experts = torch.topk(g, k=self.experts_per_token, dim=-1, sorted=True) + experts = torch.topk(g, k=self.experts_per_token, dim=-1) expert_weights = torch.nn.functional.softmax(experts.values, dim=1) expert_indices = experts.indices