18 lines
614 B
Python
18 lines
614 B
Python
|
import torch
|
||
|
import numpy as np
|
||
|
import torch.nn as nn
|
||
|
|
||
|
|
||
|
class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps.

    Maps each scalar in the input to ``[sin(2*pi*x*W), cos(2*pi*x*W)]``,
    where ``W`` holds ``embed_dim // 2`` frequencies drawn once from a
    zero-mean Gaussian with standard deviation ``scale`` at construction time.

    Args:
        embed_dim: Requested size of the feature vector. The projection uses
            ``embed_dim // 2`` frequencies and emits a sine and a cosine for
            each, so an odd ``embed_dim`` yields ``embed_dim - 1`` features.
        scale: Standard deviation of the sampled frequencies.
    """

    def __init__(self, embed_dim: int, scale: float = 30.) -> None:
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable. They are kept as an
        # nn.Parameter (rather than a registered buffer) so that what
        # ``parameters()`` yields and the state_dict layout stay unchanged;
        # like any parameter it follows the module across ``.to(device)``.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` with the fixed random Fourier features.

        Args:
            x: Tensor of scalar values (e.g. time steps) of any shape ``S``,
                including 0-d and the usual 1-D ``(batch,)``.

        Returns:
            Tensor of shape ``S + (2 * (embed_dim // 2),)``: the sine features
            followed by the cosine features along the last dimension.
        """
        # ``x[..., None]`` appends a frequency axis for any input rank; for a
        # 1-D ``x`` it is identical to ``x[:, None] * W[None, :]``, and the
        # multiplication order is unchanged so results match bit-for-bit.
        x_proj = x[..., None] * self.W * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|