fix device-mismatch bug in training
@@ -38,7 +38,7 @@ class TransformerSequenceEncoder(nn.Module):
 
         # Prepare mask for padding
         max_len = max(lengths)
-        padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool)
+        padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
 
         # Transformer encoding
         transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
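The change moves the boolean padding mask onto the same device as `combined_tensor`. Before the fix, the mask was always created on the CPU, so training on a GPU would raise a device-mismatch RuntimeError inside the encoder. Below is a minimal sketch of the fixed call in isolation; the encoder configuration (`d_model=16`, `nhead=4`, `batch_first=True`) and the random input are assumptions for illustration, not taken from this repository.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical encoder config; only the mask-device handling mirrors the commit.
encoder_layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2).to(device)

lengths = [3, 5, 2]
max_len = max(lengths)
combined_tensor = torch.randn(len(lengths), max_len, 16, device=device)

# True marks padded positions. Without .to(combined_tensor.device), this
# mask lives on the CPU and mixing it with a GPU input fails at runtime.
padding_mask = torch.tensor(
    [[0] * length + [1] * (max_len - length) for length in lengths],
    dtype=torch.bool,
).to(combined_tensor.device)

transformer_output = transformer_encoder(
    combined_tensor, src_key_padding_mask=padding_mask
)
print(transformer_output.shape)  # torch.Size([3, 5, 16])
```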