update modules and pipeline
This commit is contained in:
12
modules/seq_encoder/abstract_seq_encoder.py
Normal file
12
modules/seq_encoder/abstract_seq_encoder.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SequenceEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(SequenceEncoder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def encode_sequence(self, pts_embedding_list, pose_embedding_list):
|
||||
pass
|
10
modules/seq_encoder/transformer_seq_encoder.py
Normal file
10
modules/seq_encoder/transformer_seq_encoder.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from torch import nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
@stereotype.module("transformer_seq_encoder")
|
||||
class TransformerSequenceEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TransformerSequenceEncoder, self).__init__()
|
||||
self.config = config
|
||||
def encode_sequence(self, pts_embedding_list, pose_embedding_list):
|
||||
pass
|
Reference in New Issue
Block a user