Skip to content

Commit d732831

Browse files
committed
cross attention
1 parent 40f7bf8 commit d732831

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

src/sensa/layers/attention.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,37 @@ def forward(
223223
o = RegularizeDP.apply(o, self.p)
224224
out = out + o
225225
return out
226+
227+
228+
class CrossAttention(torch.nn.Module):
229+
"""Cross-attention module.
230+
231+
Args:
232+
dim (int): Dimension of the query and key.
233+
dim_kv (int | None): Dimension of the key and value. If None, it is set to the same as the query.
234+
num_heads (int): Number of attention heads.
235+
"""
236+
237+
def __init__(self, dim: int, dim_kv: int | None, num_heads: int = 1):
238+
super().__init__()
239+
self.dim = dim
240+
self.dim_kv = dim_kv or dim
241+
self.num_heads = num_heads
242+
self.scale = (self.dim // self.num_heads) ** -0.5
243+
self.q = torch.nn.Linear(self.dim, self.dim, bias=False)
244+
self.k = torch.nn.Linear(self.dim_kv, self.dim, bias=False)
245+
self.v = torch.nn.Linear(self.dim_kv, self.dim, bias=False)
246+
self.o = torch.nn.Linear(self.dim, self.dim, bias=False)
247+
248+
def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
249+
b, n, c = q.shape
250+
q = self.q(q).reshape(b, n, self.num_heads, -1).transpose(1, 2)
251+
k = self.k(kv).reshape(b, kv.size(1), self.num_heads, -1).transpose(1, 2)
252+
v = self.v(kv).reshape(b, kv.size(1), self.num_heads, -1).transpose(1, 2)
253+
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
254+
o = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
255+
else:
256+
attn = (q @ k.transpose(-2, -1)) * self.scale
257+
attn = attn.softmax(dim=-1)
258+
o = attn @ v
259+
return self.o(o.transpose(1, 2).reshape(b, n, c))

tests/test_layers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ def test_attention():
2727
assert attn(tensor_masked, indices_to_keep=keep_idx).shape == tensor_masked.shape
2828

2929

30+
def test_cross_attention():
    """Check that CrossAttention maps queries of shape (B, N, C) back to the same shape."""
    batch, channels, n_heads = 2, 32, 4
    query_grid = (4, 6)
    context_grid = (5, 7)

    query = torch.randn(batch, query_grid[0] * query_grid[1], channels)
    context = torch.randn(batch, context_grid[0] * context_grid[1], channels)

    layer = sensa.layers.attention.CrossAttention(dim=channels, dim_kv=channels, num_heads=n_heads)
    assert layer(query, context).shape == query.shape
43+
44+
3045
def test_dyt():
3146
r"""Test sensa.layers.DyT."""
3247
module = sensa.layers.DyT(16)

0 commit comments

Comments
 (0)