import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLU(nn.Module):
    """Gated Linear Unit variant that gates with SiLU (Swish with beta=1).

    The input's last dimension is split in half: the second half is passed
    through SiLU and used to gate the first half element-wise, so the output
    has half the input's last-dimension size.
    """

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * F.silu(gate)
class Swish(nn.Module):
    """Swish activation with a learnable slope: ``x * sigmoid(beta * x)``.

    Unlike SiLU (Swish with beta fixed at 1), ``beta`` here is a trainable
    scalar parameter initialized to 1.0, so the steepness of the sigmoid
    gate can be learned during training.
    """

    def __init__(self):
        super().__init__()
        # torch.ones(1) replaces the legacy torch.Tensor([1.]) constructor;
        # initialized to 1 so the module starts out equivalent to SiLU.
        self.beta = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return x * torch.sigmoid(self.beta * x)