@@ -11,7 +11,6 @@
 
 from __future__ import annotations
 
-from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -41,7 +40,7 @@ def __init__(
         causal: bool = False,
         sequence_length: int | None = None,
         rel_pos_embedding: str | None = None,
-        input_size: Tuple | None = None,
+        input_size: tuple | None = None,
         attention_dtype: torch.dtype | None = None,
         include_fc: bool = True,
         use_combined_linear: bool = True,
@@ -101,16 +100,16 @@ def __init__(
 
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
-        self.out_proj: Union[nn.Linear, nn.Identity]
+        self.out_proj: nn.Linear | nn.Identity
         if include_fc:
             self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
         else:
             self.out_proj = nn.Identity()
 
-        self.qkv: Union[nn.Linear, nn.Identity]
-        self.to_q: Union[nn.Linear, nn.Identity]
-        self.to_k: Union[nn.Linear, nn.Identity]
-        self.to_v: Union[nn.Linear, nn.Identity]
+        self.qkv: nn.Linear | nn.Identity
+        self.to_q: nn.Linear | nn.Identity
+        self.to_k: nn.Linear | nn.Identity
+        self.to_v: nn.Linear | nn.Identity
 
         if use_combined_linear:
             self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
@@ -153,7 +152,7 @@ def __init__(
         )
         self.input_size = input_size
 
-    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
+    def forward(self, x, attn_mask: torch.Tensor | None = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
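All four hunks apply the same modernization: because the file already has from __future__ import annotations (PEP 563) at its top, every annotation is stored as a string and never evaluated at runtime, so the PEP 604 X | Y union syntax and the built-in tuple generic work even on Python versions older than 3.10, and the Optional/Tuple/Union imports become dead code, which is why the first hunk deletes them. Below is a minimal sketch of the pattern; the Projection module is hypothetical, not the class being edited here.

from __future__ import annotations  # PEP 563: annotations become strings

import torch
import torch.nn as nn


class Projection(nn.Module):
    """Hypothetical module mirroring the diff's annotation style."""

    def __init__(self, dim: int, include_fc: bool = True, input_size: tuple | None = None) -> None:
        super().__init__()
        # A bare attribute annotation inside a function body is never
        # evaluated, so the union syntax costs nothing at runtime.
        self.out_proj: nn.Linear | nn.Identity
        if include_fc:
            self.out_proj = nn.Linear(dim, dim)
        else:
            self.out_proj = nn.Identity()  # pass-through when the FC layer is disabled
        self.input_size = input_size

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
        # attn_mask is unused; it only demonstrates the Optional -> `| None` rewrite.
        return self.out_proj(x)


x = torch.randn(2, 4, 16)
print(Projection(16)(x).shape)                    # torch.Size([2, 4, 16])
print(Projection(16, include_fc=False)(x).shape)  # identity path, same shape

One caveat: the future import only covers annotations, so a union used in a runtime expression such as isinstance(m, nn.Linear | nn.Identity) still requires Python 3.10.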