diff --git a/modules.py b/modules.py index 763dcae..75513ca 100644 --- a/modules.py +++ b/modules.py @@ -272,7 +272,7 @@ def forward(self, inputs): if __name__ == '__main__': num_units = 512 inputs = Variable(torch.randn((100, 10))) - outputs = position_encoding(num_units)(inputs) + outputs = positional_encoding(num_units)(inputs) outputs = multihead_attention(num_units)(outputs, outputs, outputs) outputs = feedforward(num_units)(outputs)