@@ -223,3 +223,37 @@ def forward(
         o = RegularizeDP.apply(o, self.p)
         out = out + o
         return out
+
+
+class CrossAttention(torch.nn.Module):
+    """Cross-attention module.
+
+    Args:
+        dim (int): Dimension of the query (and of the output).
+        dim_kv (int | None): Dimension of the key and value. If None, it is set to the same as the query.
+        num_heads (int): Number of attention heads.
+    """
+
+    def __init__(self, dim: int, dim_kv: int | None, num_heads: int = 1):
+        super().__init__()
+        self.dim = dim
+        self.dim_kv = dim_kv or dim
+        self.num_heads = num_heads
+        self.scale = (self.dim // self.num_heads) ** -0.5
+        self.q = torch.nn.Linear(self.dim, self.dim, bias=False)
+        self.k = torch.nn.Linear(self.dim_kv, self.dim, bias=False)
+        self.v = torch.nn.Linear(self.dim_kv, self.dim, bias=False)
+        self.o = torch.nn.Linear(self.dim, self.dim, bias=False)
+
+    def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
+        b, n, c = q.shape
+        # Project and split into heads: (b, len, dim) -> (b, num_heads, len, head_dim).
+        q = self.q(q).reshape(b, n, self.num_heads, -1).transpose(1, 2)
+        k = self.k(kv).reshape(b, kv.size(1), self.num_heads, -1).transpose(1, 2)
+        v = self.v(kv).reshape(b, kv.size(1), self.num_heads, -1).transpose(1, 2)
+        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+            # Fused kernel path (available since PyTorch 2.0).
+            o = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
+        else:
+            # Fallback: explicit scaled dot-product attention.
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            attn = attn.softmax(dim=-1)
+            o = attn @ v
+        # Merge heads and apply the output projection: (b, n, c).
+        return self.o(o.transpose(1, 2).reshape(b, n, c))
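For reference, a minimal usage sketch of the new module (the sizes and tensor shapes below are illustrative, not taken from this change):

```python
import torch

# Hypothetical sizes: 256-d query tokens attending over a 128-d key/value sequence.
attn = CrossAttention(dim=256, dim_kv=128, num_heads=8)
q = torch.randn(2, 16, 256)   # (batch, query_len, dim)
kv = torch.randn(2, 64, 128)  # (batch, kv_len, dim_kv)
out = attn(q, kv)             # -> (2, 16, 256), same shape as the query
```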