Fix training bug: move the transformer padding mask to the input tensor's device

This commit is contained in:
2024-09-12 15:11:09 +08:00
parent a79ca7749d
commit 4c69ed777b
15 changed files with 201 additions and 120 deletions

View File

@@ -38,7 +38,7 @@ class TransformerSequenceEncoder(nn.Module):
# Prepare mask for padding
max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
# Transformer encoding
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)