diff --git a/vllm/config/model.py b/vllm/config/model.py index 8e6b2f6a8fa8..9ced8f75dea6 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1373,7 +1373,7 @@ def get_num_layers_by_block_type( if hasattr(self.hf_text_config, "model_type") and ( self.hf_text_config.model_type == "zaya" ): - return self.hf_text_config.num_layers + return self.hf_text_config.num_hidden_layers # Hybrid model Minimax attn_type_list = getattr(self.hf_config, "attn_type_list", None) if attn_type_list: diff --git a/vllm/model_executor/layers/mamba/cca.py b/vllm/model_executor/layers/mamba/cca.py index bfa68211098d..87125a869216 100644 --- a/vllm/model_executor/layers/mamba/cca.py +++ b/vllm/model_executor/layers/mamba/cca.py @@ -69,6 +69,7 @@ def __init__( self.head_dim = int(head_dim) self.latent_k_dim = self.num_k_heads * self.head_dim self.latent_q_dim = self.num_q_heads * self.head_dim + self.recurrent_v_dim = self.latent_k_dim // 2 self.sqrt_head_dim = np.sqrt(self.head_dim) self.gqa_groups = self.num_q_heads // self.num_k_heads assert self.num_q_heads % self.num_k_heads == 0, ( @@ -79,59 +80,57 @@ def __init__( ) * self.head_dim # Projections - self.linear_q = ReplicatedLinear( + self.q_proj = ReplicatedLinear( self.hidden_size, self.latent_q_dim, bias=self.config.attention_bias, quant_config=quant_config, return_bias=False, - prefix=f"{prefix}.linear_q", + prefix=f"{prefix}.q_proj", ) - self.linear_k = ReplicatedLinear( + self.k_proj = ReplicatedLinear( self.hidden_size, self.latent_k_dim, bias=self.config.attention_bias, quant_config=quant_config, return_bias=False, - prefix=f"{prefix}.linear_k", + prefix=f"{prefix}.k_proj", ) - self.val_proj1 = ReplicatedLinear( + self.v_proj_current = ReplicatedLinear( self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias, quant_config=quant_config, return_bias=False, - prefix=f"{prefix}.val_proj1", + prefix=f"{prefix}.v_proj_current", ) - self.val_proj2 = ReplicatedLinear( + self.v_proj_delayed = ReplicatedLinear( self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias, quant_config=quant_config, return_bias=False, - prefix=f"{prefix}.val_proj2", + prefix=f"{prefix}.v_proj_delayed", ) # Depthwise + grouped conv along sequence (exactly like Megatron) in_out_ch = self.latent_k_dim + self.latent_q_dim self.in_out_ch = in_out_ch - self.conv_qk = nn.Sequential( - nn.Conv1d( - in_channels=in_out_ch, - out_channels=in_out_ch, - kernel_size=self.cca_time0, - groups=in_out_ch, - padding=0, - stride=1, - ), - nn.Conv1d( - in_channels=in_out_ch, - out_channels=in_out_ch, - kernel_size=self.cca_time1, - groups=(self.num_k_heads + self.num_q_heads), - padding=0, - stride=1, - ), + self.conv_qk_depthwise = nn.Conv1d( + in_channels=in_out_ch, + out_channels=in_out_ch, + kernel_size=self.cca_time0, + groups=in_out_ch, + padding=0, + stride=1, + ) + self.conv_qk_grouped = nn.Conv1d( + in_channels=in_out_ch, + out_channels=in_out_ch, + kernel_size=self.cca_time1, + groups=(self.num_k_heads + self.num_q_heads), + padding=0, + stride=1, ) # Per-k head temperature (Megatron: shape [num_k_heads]) @@ -148,7 +147,7 @@ def forward_native( hidden_states: torch.Tensor, output: torch.Tensor, ): - return + self._forward_no_cache(hidden_states, output) def forward( self, @@ -161,6 +160,57 @@ def forward( self.prefix, ) + def _forward_no_cache( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ) -> None: + """Project an uncached contiguous token sequence into q/k/v.""" + num_tokens = hidden_states.shape[0] + hs = hidden_states.unsqueeze(1) # [S, 1, H] + + q = self.q_proj(hs) + k = self.k_proj(hs) + qk_packed0 = torch.cat([q, k], dim=-1) + del q + del k + + query_pre = qk_packed0[..., :self.latent_q_dim].view( + *qk_packed0.shape[:2], self.num_q_heads, self.head_dim) + key_base = qk_packed0[..., self.latent_q_dim:].view( + *qk_packed0.shape[:2], self.num_k_heads, self.head_dim) + + qk_packed2 = F.pad(qk_packed0.permute(1, 2, 0), + (self.total_padding, 0)) + qk_packed3 = self.conv_qk_grouped( + self.conv_qk_depthwise(qk_packed2)).permute(2, 0, 1) + + query = qk_packed3[..., :self.latent_q_dim].view( + *qk_packed3.shape[:2], self.num_q_heads, self.head_dim) + key = qk_packed3[..., self.latent_q_dim:].view( + *qk_packed3.shape[:2], self.num_k_heads, self.head_dim) + query, key = self._add_grouped_qk_means_inplace( + query, key, query_pre, key_base) + query, key = self._rms_normalize_qk(query.contiguous(), + key.contiguous()) + + value_current = self.v_proj_current(hs) + delayed_v_state = self.v_proj_delayed(hs) + zero_delayed = self.v_proj_delayed( + hidden_states.new_zeros(1, 1, self.hidden_size)) + value_delayed = torch.cat([zero_delayed, delayed_v_state[:-1]], dim=0) + value = torch.cat([value_current, value_delayed], dim=-1).contiguous() + value = value.view(num_tokens, 1, self.num_k_heads, self.head_dim) + + q_end = self.latent_q_dim + k_end = q_end + self.latent_k_dim + output[:num_tokens, :q_end] = query.reshape(num_tokens, + self.latent_q_dim) + output[:num_tokens, q_end:k_end] = key.reshape(num_tokens, + self.latent_k_dim) + output[:num_tokens, k_end:] = value.reshape(num_tokens, + self.latent_k_dim) + def _rms_normalize_qk( self, query: torch.Tensor, key: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: @@ -221,8 +271,8 @@ def _conv_qk_decode(self, x: torch.Tensor) -> torch.Tensor: Output: [N, C, S_out] """ # Stage 1: depthwise conv over sequence. - w0 = self.conv_qk[0].weight.squeeze(1) # [C, K0] - b0 = self.conv_qk[0].bias # [C] or None + w0 = self.conv_qk_depthwise.weight.squeeze(1) # [C, K0] + b0 = self.conv_qk_depthwise.bias # [C] or None x = x.to(w0.dtype) k0 = w0.shape[1] @@ -232,8 +282,8 @@ def _conv_qk_decode(self, x: torch.Tensor) -> torch.Tensor: mid = mid + b0[None, :, None] # Stage 2: grouped conv over the depthwise output. - w1 = self.conv_qk[1].weight # [C, D, K1] - b1 = self.conv_qk[1].bias # [C] or None + w1 = self.conv_qk_grouped.weight # [C, D, K1] + b1 = self.conv_qk_grouped.bias # [C] or None g = self.num_k_heads + self.num_q_heads d = self.head_dim k1 = w1.shape[2] @@ -257,7 +307,7 @@ def forward_cuda( attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, CCAAttentionMetadata) conv_states = self.kv_cache[0] - prev_hs = self.kv_cache[1] + recurrent_states = self.kv_cache[1] state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_d = attn_metadata.state_indices_tensor_d if state_indices_tensor_d is not None and state_indices_tensor_d.dim() > 1: @@ -267,59 +317,8 @@ def forward_cuda( if attn_metadata is None: # V1 profile run - hs = hidden_states.unsqueeze(0).transpose(0, 1).contiguous() - hs_d = F.pad(hs[:-1], pad=(0, 0, 0, 0, 1, 0)) # [S, B, H] - q = self.linear_q(hs) # [S, B, latent_q_dim] - k = self.linear_k(hs) # [S, B, latent_k_dim] - qk_packed0 = torch.cat([q, k], dim=-1) # [S, B, latent_q + latent_k] - del q - del k - - # Pre-mean tensors in head form (for "qk_mean_{q,k}" calc) - query_pre = qk_packed0[..., : self.latent_q_dim].view( - *qk_packed0.shape[:2], self.num_q_heads, self.head_dim - ) # [S, B, qh, dh] - - key_base = qk_packed0[..., self.latent_q_dim :].view( - *qk_packed0.shape[:2], self.num_k_heads, self.head_dim - ) # [S, B, kh, dh] - - qk_packed1 = qk_packed0.permute(1, 2, 0) # [B, E, S] - qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0)) - qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1) # [S, B, E] - - # Build queries/keys from conv output + means - query = ( - qk_packed3[..., : self.latent_q_dim] - .view(*qk_packed3.shape[:2], self.num_q_heads, self.head_dim) - .float() - ) - - key = ( - qk_packed3[..., self.latent_q_dim :] - .view(*qk_packed3.shape[:2], self.num_k_heads, self.head_dim) - .float() - ) - query, key = self._add_grouped_qk_means_inplace( - query, key, query_pre, key_base - ) - del query_pre - del key_base - del qk_packed0 - del qk_packed3 - - # Values from the two time streams - v1 = self.val_proj1(hs) # [S, B, latent_k_dim/2] - v2 = self.val_proj2(hs_d) # [S, B, latent_k_dim/2] - value = ( - torch.cat([v1, v2], dim=-1) - .contiguous() - .view(*hs.shape[:2], self.num_k_heads, self.head_dim) - ) # [S, B, kh, dh] - - query, key = self._rms_normalize_qk(query.contiguous(), key.contiguous()) - - return hs + self._forward_no_cache(hidden_states, output) + return num_prefills = attn_metadata.num_prefills # request count num_decodes = attn_metadata.num_decode_tokens # token count (=request) @@ -328,7 +327,6 @@ def forward_cuda( has_decode = num_decodes > 0 num_actual_tokens = num_decodes + num_prefill_tokens - num_input_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states[:num_actual_tokens] # Batch size is effectively 1 in this path, so insert the singleton @@ -336,8 +334,8 @@ def forward_cuda( hs = hidden_states.unsqueeze(1) # [S, 1, H] batch_size = hs.shape[1] - q = self.linear_q(hs) # [S, B, latent_q_dim] - k = self.linear_k(hs) # [S, B, latent_k_dim] + q = self.q_proj(hs) # [S, B, latent_q_dim] + k = self.k_proj(hs) # [S, B, latent_k_dim] qk_packed0 = torch.cat([q, k], dim=-1) # [S, B, latent_q + latent_k] del q del k @@ -364,14 +362,20 @@ def forward_cuda( [num_decodes, num_prefill_tokens], dim=0, ) + delayed_v_state = self.v_proj_delayed(hs[:num_actual_tokens]) + delayed_v_state_d, delayed_v_state_p = torch.split( + delayed_v_state, + [num_decodes, num_prefill_tokens], + dim=0, + ) qk_packed3 = torch.empty( (num_actual_tokens, batch_size, self.in_out_ch), device=hs.device, dtype=hs.dtype, ) - hs2 = torch.empty( - (num_actual_tokens, batch_size, self.hidden_size), + value_delayed = torch.empty( + (num_actual_tokens, batch_size, self.recurrent_v_dim), device=hs.device, dtype=hs.dtype, ) @@ -382,23 +386,20 @@ def forward_cuda( assert query_start_loc_p is not None # Prefill prefill_slice = slice(num_decodes, num_decodes + num_prefill_tokens) - hs2_prefill = hs2[prefill_slice] + value_delayed_prefill = value_delayed[prefill_slice] qk_packed3_prefill = qk_packed3[prefill_slice] for i in range(len(query_start_loc_p) - 1): start_i, end_i = query_start_loc_p[i], query_start_loc_p[i + 1] - hs2_cur = hs_p[start_i:end_i, :, :] # [S_cur, B, H] qk_packed0_cur = qk_packed0_p[start_i:end_i, :, :] # [S_cur, B, H] + delayed_v_state_cur = delayed_v_state_p[start_i:end_i] qk_packed1_cur = qk_packed0_cur.permute(1, 2, 0) # [1, H, S_cur] if has_initial_states_p[i]: - hs2_cached = ( - prev_hs[state_indices_tensor_p[i]].unsqueeze(0).unsqueeze(0) - ) # [1, 1, H] - if hs2_cached.dtype != hs2_cur.dtype: - hs2_cached = hs2_cached.to(hs2_cur.dtype) - hs2_cur = torch.cat( - [hs2_cached, hs2_cur[:-1]], dim=0 - ) # [S_cur, 1, H] + value_delayed_cached = recurrent_states[ + state_indices_tensor_p[i]].unsqueeze(0).unsqueeze(0) + if value_delayed_cached.dtype != value_delayed.dtype: + value_delayed_cached = value_delayed_cached.to( + value_delayed.dtype) qk_packed0_cached = conv_states[ state_indices_tensor_p[i] ].unsqueeze(0) # [1, H, total_padding] @@ -408,27 +409,31 @@ def forward_cuda( [qk_packed0_cached, qk_packed1_cur], dim=-1 ) # [1, H, S_cur + total_padding] else: - hs2_cur = F.pad(hs2_cur[:-1], pad=(0, 0, 0, 0, 1, 0)) + value_delayed_cached = self.v_proj_delayed( + hs_p.new_zeros(1, 1, self.hidden_size)) qk_packed2_cur = F.pad(qk_packed1_cur, (self.total_padding, 0)) - hs2_prefill[start_i:end_i] = hs2_cur + value_delayed_prefill[start_i:end_i] = torch.cat( + [value_delayed_cached, delayed_v_state_cur[:-1]], dim=0) conv_states_cur = nn.functional.pad( - qk_packed2_cur, (self.cca_time0 - qk_packed2_cur.shape[-1], 0) + qk_packed2_cur, + (self.total_padding - qk_packed2_cur.shape[-1], 0), ) conv_states[state_indices_tensor_p[i]] = conv_states_cur.to( device=conv_states.device, dtype=conv_states.dtype ) # Computing conv - qk_packed3_cur = self.conv_qk(qk_packed2_cur).permute( - 2, 0, 1 - ) # [S, B, E] + qk_packed3_cur = self.conv_qk_grouped( + self.conv_qk_depthwise(qk_packed2_cur) + ).permute(2, 0, 1) # [S, B, E] qk_packed3_prefill[start_i:end_i] = qk_packed3_cur - prev_hs[state_indices_tensor_p] = hs_p[query_start_loc_p[1:] - 1, 0, :].to( - device=prev_hs.device, dtype=prev_hs.dtype - ) + recurrent_states[state_indices_tensor_p] = delayed_v_state_p[ + query_start_loc_p[1:] - 1, 0, :].to( + device=recurrent_states.device, + dtype=recurrent_states.dtype) if has_decode: assert state_indices_tensor_d is not None @@ -490,38 +495,40 @@ def forward_cuda( device=conv_states.device, dtype=conv_states.dtype ) - hs2_decode = prev_hs[safe_decode_indices].unsqueeze(1) # [S, 1, H] - hs2_decode = torch.where( + value_delayed_decode = recurrent_states[safe_decode_indices].unsqueeze(1) + value_delayed_decode = torch.where( decode_is_pad.view(-1, 1, 1), - hs2_decode.new_zeros(()), - hs2_decode, + value_delayed_decode.new_zeros(()), + value_delayed_decode, ) - if hs2_decode.dtype != hs.dtype: - hs2_decode = hs2_decode.to(hs.dtype) - hs2[:num_decodes] = hs2_decode - new_prev_hs = hs_d[:, 0, :].to(prev_hs.dtype) - new_prev_hs = torch.where( + if value_delayed_decode.dtype != value_delayed.dtype: + value_delayed_decode = value_delayed_decode.to(value_delayed.dtype) + value_delayed[:num_decodes] = value_delayed_decode + new_recurrent_state = delayed_v_state_d[:, 0, :].to( + recurrent_states.dtype) + new_recurrent_state = torch.where( decode_is_pad.view(-1, 1), - new_prev_hs.new_zeros(()), - new_prev_hs, - ) - prev_hs[safe_decode_indices] = new_prev_hs.to( - device=prev_hs.device, dtype=prev_hs.dtype + new_recurrent_state.new_zeros(()), + new_recurrent_state, ) + recurrent_states[safe_decode_indices] = new_recurrent_state.to( + device=recurrent_states.device, + dtype=recurrent_states.dtype) del qk_packed0_d del qk_packed0_p del hs_d del hs_p + del delayed_v_state_d + del delayed_v_state_p # Values from the two time streams - v1 = self.val_proj1(hs) # [S, B, latent_k_dim/2] - v2 = self.val_proj2(hs2) - value = torch.cat([v1, v2], dim=-1).contiguous() + v1 = self.v_proj_current(hs) # [S, B, latent_k_dim/2] + value = torch.cat([v1, value_delayed], dim=-1).contiguous() value = value.view( num_actual_tokens, batch_size, self.num_k_heads, self.head_dim ) # [S, B, kh, dh] - del hs2 + del value_delayed # Build queries/keys from conv output + means query = ( @@ -575,7 +582,7 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: num_k_heads=self.num_k_heads, num_q_heads=self.num_q_heads, head_dim=self.head_dim, - hidden_size=self.hidden_size, + recurrent_state_size=self.recurrent_v_dim, ) @property diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index a48358b2fae2..731151347985 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -214,15 +214,15 @@ def cca_state_shape( num_k_heads: int, num_q_heads: int, head_dim: int, - hidden_size: int, + recurrent_state_size: int, ) -> tuple[tuple[int, int], tuple[int]]: latent_k_dim = num_k_heads * head_dim latent_q_dim = num_q_heads * head_dim in_out_ch = latent_k_dim + latent_q_dim conv_states_shape = (in_out_ch, conv_kernel_size) - prev_hs_shape = (hidden_size,) - return (conv_states_shape, prev_hs_shape) + recurrent_state_shape = (recurrent_state_size,) + return (conv_states_shape, recurrent_state_shape) @classmethod def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): diff --git a/vllm/model_executor/models/zaya.py b/vllm/model_executor/models/zaya.py index dbd68a0a1778..c48b38ca69bb 100644 --- a/vllm/model_executor/models/zaya.py +++ b/vllm/model_executor/models/zaya.py @@ -13,7 +13,6 @@ from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import FusedMoE @@ -55,53 +54,31 @@ def apply(self, layer, x, bias=None): return out -class ResidualScaling(nn.Module): - def __init__( - self, - config, - layer_n, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): +class ZayaResidualScaling(nn.Module): + def __init__(self, hidden_size: int): super().__init__() - self.config = config - self.not_first_layer = layer_n != 0 - self.hidden_states_scale = torch.nn.Parameter( - torch.ones(self.config.hidden_size) - ) - self.hidden_states_bias = torch.nn.Parameter( - torch.zeros(self.config.hidden_size) - ) - - if self.not_first_layer: - self.residual_scale = torch.nn.Parameter( - torch.ones(self.config.hidden_size) - ) - self.residual_bias = torch.nn.Parameter( - torch.zeros(self.config.hidden_size) - ) - - def forward(self, residual: torch.Tensor, hidden_states: torch.Tensor): - hs_expand_shape = (1,) * (hidden_states.dim() - 1) + (-1,) - hs_bias = self.hidden_states_bias.to(torch.float32).view(*hs_expand_shape) - hs_scale = self.hidden_states_scale.to(torch.float32).view(*hs_expand_shape) - hidden_states = (hidden_states.float() + hs_bias) * hs_scale - if self.not_first_layer and residual is not None: - res_expand_shape = (1,) * (residual.dim() - 1) + (-1,) - res_bias = self.residual_bias.to(torch.float32).view(*res_expand_shape) - res_scale = self.residual_scale.to(torch.float32).view(*res_expand_shape) - residual = (residual.float() + res_bias) * res_scale - return residual, hidden_states + self.hidden_states_scale = nn.Parameter(torch.ones(hidden_size)) + self.hidden_states_bias = nn.Parameter(torch.zeros(hidden_size)) + self.residual_scale = nn.Parameter(torch.ones(hidden_size)) + self.residual_bias = nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, hidden_states: torch.Tensor, + residual: torch.Tensor) -> torch.Tensor: + hidden_states = ( + hidden_states.float() + self.hidden_states_bias.to(torch.float32) + ) * self.hidden_states_scale.to(torch.float32) + residual = ( + residual.float() + self.residual_bias.to(torch.float32) + ) * self.residual_scale.to(torch.float32) + return hidden_states + residual def _apply_norm_with_fp32_residual( - norm: nn.Module, residual: torch.Tensor, target_dtype: torch.dtype + norm: nn.Module, + residual: torch.Tensor, + target_dtype: torch.dtype, ) -> torch.Tensor: if isinstance(norm, RMSNorm): - # vLLM custom rms_norm requires x and weight dtypes to match. - # When residual stays fp32 for numerical hardening and weights are fp16, - # use the native path to avoid compiled-kernel dtype mismatch. if residual.dtype != norm.weight.dtype: hidden_states = norm.forward_native(residual) else: @@ -110,13 +87,31 @@ def _apply_norm_with_fp32_residual( return norm(residual.to(target_dtype)) +def _rope_parameters_for_layer(config: ZayaConfig, layer_idx: int) -> dict: + layer_types = getattr(config, "layer_types", None) or ["hybrid"] + layer_type = layer_types[layer_idx] + rope_parameters = getattr(config, "rope_parameters", None) + if isinstance(rope_parameters, dict): + params = rope_parameters.get(layer_type, rope_parameters) + if isinstance(params, dict): + params = dict(params) + else: + params = {} + else: + params = {} + params.setdefault("rope_type", "default") + params.setdefault("rope_theta", getattr(config, "rope_theta", 5000000)) + params.setdefault( + "partial_rotary_factor", getattr(config, "partial_rotary_factor", 0.5)) + return params + + class ZayaAttention(nn.Module): def __init__( self, config: ZayaConfig, - layer_idx, - layer_n, - prefix_name: str = "", + layer_idx: int, + prefix: str, model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, @@ -124,378 +119,218 @@ def __init__( super().__init__() self.config = config self.layer_idx = layer_idx - self.layer_n = layer_n self.hidden_size = config.hidden_size - self.attention_dropout = config.attention_dropout - - self.cca_num_k_heads = config.num_query_groups - self.cca_num_q_heads = config.num_attention_heads - self.cca_time0 = config.cca_time0 - self.cca_time1 = config.cca_time1 + self.num_key_value_heads = config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads self.head_dim = config.head_dim self.scale = self.head_dim**-0.5 - self.qkv = CCA( + self.qkv_proj = CCA( config=config, - cca_num_k_heads=self.cca_num_k_heads, - cca_num_q_heads=self.cca_num_q_heads, + cca_num_k_heads=self.num_key_value_heads, + cca_num_q_heads=self.num_attention_heads, hidden_size=self.hidden_size, head_dim=self.head_dim, - cca_time0=self.cca_time0, - cca_time1=self.cca_time1, - layer_number=layer_n, + cca_time0=config.cca_time0, + cca_time1=config.cca_time1, + layer_number=layer_idx, model_config=model_config, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix_name}.cca", + prefix=f"{prefix}.qkv_proj", ) self.o_proj = ReplicatedLinear( - self.cca_num_q_heads * self.head_dim, + self.num_attention_heads * self.head_dim, self.hidden_size, - bias=self.config.attention_bias, + bias=config.attention_bias, quant_config=quant_config, return_bias=False, - prefix=f"{prefix_name}.o_proj", + prefix=f"{prefix}.o_proj", ) - swa_layers = getattr(config, "swa_layers", None) - swa_window = swa_layers[layer_n] if swa_layers is not None else None - is_swa = swa_window is not None and swa_window != 0 - - if is_swa: - swa_window = swa_window + 1 - + layer_type = config.layer_types[layer_idx] + sliding_window = ( + config.sliding_window if layer_type == "hybrid_sliding" else None + ) self.attn = Attention( - self.cca_num_q_heads, + self.num_attention_heads, self.head_dim, self.scale, - self.cca_num_k_heads, - per_layer_sliding_window=swa_window if is_swa else None, + self.num_key_value_heads, + per_layer_sliding_window=sliding_window, cache_config=cache_config, - prefix=f"{prefix_name}.attn", - ) - - rope_theta = ( - getattr(config, "swa_rotary_base", config.rope_theta) - if is_swa - else config.rope_theta + prefix=f"{prefix}.attn", ) - self.rotary_emb = get_rope( head_size=self.head_dim, max_position=config.max_position_embeddings, is_neox_style=True, - rope_parameters={ - "rope_theta": rope_theta, - "rope_type": "default", - "partial_rotary_factor": 0.5, - }, + rope_parameters=_rope_parameters_for_layer(config, layer_idx), ) - self.q_dim = self.cca_num_q_heads * self.head_dim - self.k_dim = self.cca_num_k_heads * self.head_dim - self.v_dim = self.cca_num_k_heads * self.head_dim + self.q_dim = self.num_attention_heads * self.head_dim + self.k_dim = self.num_key_value_heads * self.head_dim + self.v_dim = self.num_key_value_heads * self.head_dim self.qkv_dim = self.q_dim + self.k_dim + self.v_dim def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + ) -> torch.Tensor: output_qkv = torch.zeros( (hidden_states.shape[0], self.qkv_dim), device=hidden_states.device, dtype=hidden_states.dtype, ) - self.qkv(hidden_states, output_qkv) + self.qkv_proj(hidden_states, output_qkv) q, k, v = output_qkv.split([self.q_dim, self.k_dim, self.v_dim], dim=-1) q, k = self.rotary_emb(position_ids, q, k) attn_output = self.attn(q, k, v) - attn_output = self.o_proj(attn_output) + return self.o_proj(attn_output) - return attn_output - -class ZayaDecoderATTLayer(nn.Module): +class ZayaRouterMLP(nn.Module): def __init__( self, - config: ZayaConfig, - layer_idx: str, - layer_n: int, - prefix_name="", - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, + hidden_size: int, + num_experts: int, + eps: float, quant_config: QuantizationConfig | None = None, + prefix: str = "", ): super().__init__() - self.config = config - self.layer_n = layer_n - self.training = self.training - self.self_attn = ZayaAttention( - config, - layer_idx, - layer_n, - prefix_name, - model_config=model_config, - cache_config=cache_config, + self.norm = RMSNorm(hidden_size, eps=eps) + self.fc1 = ReplicatedLinear( + hidden_size, + hidden_size, + bias=True, quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.fc1", ) - - if config.normalization == "RMSNorm": - self.input_norm = RMSNorm(self.config.hidden_size, eps=config.norm_epsilon) - elif config.normalization == "LayerNorm": - self.input_norm = nn.LayerNorm( - self.config.hidden_size, eps=config.norm_epsilon - ) - else: - raise TypeError("Normalization not supported.") - - if self.config.scale_residual_merge: - self.res_scale = ResidualScaling(config, layer_n) - - def forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, - position_ids: torch.LongTensor, - layer_n: int, - prev_router_hidden_states: torch.Tensor | None = None, - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - layer_input_dtype = ( - self.input_norm.weight.dtype - if isinstance(self.input_norm, RMSNorm) - else hidden_states.dtype - ) - if self.config.scale_residual_merge: - residual, hidden_states = self.res_scale(residual, hidden_states) - if residual is not None: - residual = residual.float() + hidden_states.float() - else: - residual = hidden_states.float() - hidden_states = _apply_norm_with_fp32_residual( - self.input_norm, residual, layer_input_dtype + self.fc2 = ReplicatedLinear( + hidden_size, + hidden_size, + bias=True, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.fc2", ) - - hidden_states = self.self_attn( - hidden_states=hidden_states, - position_ids=position_ids, + self.out_proj = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.out_proj", ) + self.act_fn = nn.GELU() - return hidden_states, residual, prev_router_hidden_states + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = _apply_norm_with_fp32_residual( + self.norm, hidden_states, self.norm.weight.dtype) + hidden_states = self.act_fn(self.fc1(hidden_states)) + hidden_states = self.act_fn(self.fc2(hidden_states)) + return self.out_proj(hidden_states) class ZayaRouter(nn.Module): def __init__( self, - config, - layer_n: int, - num_moe_experts: int, - moe_router_topk: int, - mlp_expansion: int, - hidden_size: int | None = None, - layer_number: int | None = None, - cache_config: CacheConfig | None = None, + config: ZayaConfig, + layer_idx: int, quant_config: QuantizationConfig | None = None, prefix: str = "", - ) -> None: + ): super().__init__() - - # ---- Config / shape ---- self.config = config - self.layer_n = layer_n - self.hidden_size = int(hidden_size or config.hidden_size) - self.layer_number = layer_number if layer_number is not None else 0 - # Reuse existing high-precision knob for router numerics. - self.router_softmax_fp32 = bool(getattr(config, "zaya_high_prec", False)) - - # MOD - self.use_mod = bool(getattr(config, "zaya_use_mod", False)) - self.mod_per = int(getattr(config, "zaya_mod_per", 0)) - if (self.mod_per == 0) and (num_moe_experts == 1): - raise ValueError( - "ERROR! The only way in which we can have a single expert is if" - " MOD is enabled." - ) - - # Expert counts (extra 'skip' expert if MOD) - self.num_experts = (num_moe_experts + 1) if self.use_mod else num_moe_experts - self.topk = int(moe_router_topk) - - # Router hidden dim - self.mlp_expansion = int(mlp_expansion) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + 1 + self.topk = config.num_experts_per_tok + self.router_hidden_size = config.router_hidden_size - # ---- Layers ---- self.down_proj = ReplicatedLinear( self.hidden_size, - self.mlp_expansion, + self.router_hidden_size, bias=True, quant_config=quant_config, return_bias=False, + prefix=f"{prefix}.down_proj", ) - - # EDA (depth-wise averaging) - zaya_first_layer = 1 - use_eda_cfg = bool(getattr(config, "zaya_use_eda", False)) - self.use_eda = ( - use_eda_cfg - and (zaya_first_layer is not None) - and (self.layer_number != zaya_first_layer) - ) - - ln_eps = float(getattr(config, "norm_epsilon", 1e-5)) - self.rmsnorm_eda = RMSNorm(self.mlp_expansion, eps=ln_eps) + self.use_eda = layer_idx != 0 if self.use_eda: - # eda - self.router_states_scale = nn.Parameter(torch.ones(self.mlp_expansion)) - - # routermlp - D = self.mlp_expansion - E = self.num_experts - self.non_linearity = nn.GELU() - self.router_mlp = nn.Sequential( - ReplicatedLinear( - D, D, bias=True, quant_config=quant_config, return_bias=False - ), - self.non_linearity, - ReplicatedLinear( - D, D, bias=True, quant_config=quant_config, return_bias=False - ), - self.non_linearity, - ReplicatedLinear( - D, E, bias=False, quant_config=quant_config, return_bias=False - ), + self.router_states_scale = nn.Parameter( + torch.ones(self.router_hidden_size)) + self.router_mlp = ZayaRouterMLP( + self.router_hidden_size, + self.num_experts, + config.rms_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.router_mlp", ) - - # Balancing biases self.register_buffer( - "balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32) - ) - if self.use_mod: - self.balancing_biases[-1] = -1.0 + "balancing_biases", torch.zeros(self.num_experts, dtype=torch.float32)) + self.balancing_biases[-1] = -1.0 def forward( self, - hidden_states: torch.Tensor, # (B, S, H) - prev_router_hidden_states: torch.Tensor - | None = None, # (B, S, D) previous router states for EDA + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Compute per-token expert probabilities and choose top-k experts. - - Args: - hidden_states: (batch, seq, hidden_size) - prev_router_hidden_states: (batch, seq, mlp_expansion) from prior - step/layer (for EDA). Optional. - - Returns: - route_prob: (batch*seq, topk) - expert_choice_t: (batch*seq, topk) int64 - router_hidden_states_next: (batch, seq, mlp_expansion) - """ - S, _ = hidden_states.shape - - # eda - hs = self.down_proj(hidden_states) - if self.use_eda and (prev_router_hidden_states is not None): - hs = hs + prev_router_hidden_states * self.router_states_scale - - # Stash the pre-norm states for the caller (this is what Megatron returns) - router_hidden_states_next = hs[-S:].clone() - - # 2) RMSNorm eda - hs_norm = self.rmsnorm_eda(hs) - - # 3) Expert probability distribution - logits = self.router_mlp(hs_norm) - if self.router_softmax_fp32: - # Keep router selection numerically stable without changing expert - # compute dtype. - expert_prob = torch.softmax(logits, dim=-1, dtype=torch.float32) - else: - expert_prob = torch.softmax(logits, dim=-1) - - # 4) expert choice with balancing biases (biases affect choice only, - # not the probabilities) - biased = expert_prob.detach().to(torch.float32) + self.balancing_biases - _, expert_choice_t = torch.topk(biased, self.topk, dim=-1) # (S, topk) - - # 5) If MOD and topk>1, once skip expert is selected, force all - # subsequent choices to skip as well, but this never happens since we use - # topk=1. - if (self.topk > 1) and self.use_mod: - skip_idx = self.num_experts - 1 - n_mask = expert_choice_t == skip_idx - cumsum_mask = torch.cumsum(n_mask, dim=-1) - expert_choice_t = expert_choice_t.masked_fill(cumsum_mask > 0, skip_idx) - - # Gather the probabilities for the selected experts - route_prob = torch.gather(expert_prob, dim=1, index=expert_choice_t) - if route_prob.dtype != hidden_states.dtype: - route_prob = route_prob.to(hidden_states.dtype) - - expert_choice_flat = expert_choice_t.reshape(-1, self.topk) - route_prob_flat = route_prob.reshape(-1, self.topk) - - return route_prob_flat, expert_choice_flat, router_hidden_states_next - - -class ZayaBlock(nn.Module): + seq_length = hidden_states.shape[0] + router_hidden_states = self.down_proj(hidden_states) + if self.use_eda and prev_router_hidden_states is not None: + router_hidden_states = ( + router_hidden_states + + prev_router_hidden_states * self.router_states_scale) + + router_hidden_states_next = router_hidden_states[-seq_length:].clone() + router_logits = self.router_mlp(router_hidden_states) + router_probs = torch.softmax(router_logits, dim=-1) + biased_router_probs = ( + router_probs.detach().to(torch.float32) + self.balancing_biases) + _, router_indices = torch.topk(biased_router_probs, self.topk, dim=-1) + router_probs = torch.gather(router_probs, dim=1, index=router_indices) + + skip_expert = router_indices == self.config.num_experts + router_probs = router_probs.masked_fill(skip_expert, 0) + router_indices = router_indices.masked_fill(skip_expert, 0) + if router_probs.dtype != hidden_states.dtype: + router_probs = router_probs.to(hidden_states.dtype) + return router_probs, router_indices, router_hidden_states_next + + +class ZayaSparseMoeBlock(nn.Module): def __init__( self, config: ZayaConfig, layer_idx: int, - mlp_expansion: int, - ffn_hidden_size: int, - layer_n: int, - cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() self.config = config - self.layer_n = layer_n - self.hidden_dim = config.hidden_size - self.num_moe_experts = layer_idx - self.mlp_expansion = mlp_expansion - - assert config.activation_func == "swiglu", "Only SwiGLU activation is supported" - assert config.gated_linear_unit, "gated_linear_unit must be True" - assert not config.add_bias_linear, "add_bias_linear must be False" - - self.router = ZayaRouter( - config=self.config, - layer_n=layer_n, - num_moe_experts=self.num_moe_experts, - moe_router_topk=getattr(self.config, "moe_router_topk", 1), - mlp_expansion=self.mlp_expansion, - hidden_size=self.hidden_dim, - layer_number=layer_n, - cache_config=cache_config, + self.layer_idx = layer_idx + self.topk = config.num_experts_per_tok + self.gate = ZayaRouter( + config, + layer_idx, quant_config=quant_config, - prefix=f"{prefix}.router", + prefix=f"{prefix}.gate", ) - self.topk = getattr(self.config, "moe_router_topk", 1) - - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size > self.num_moe_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}." - ) - def _custom_routing_fn(hidden_states, gating_output, topk, renormalize): - # Routing results are packed into gating_output by forward(): - # columns [:topk] = weights (float), columns [topk:] = ids (float-cast) topk_weights = gating_output[:, :topk] topk_ids = gating_output[:, topk : 2 * topk].to(torch.int64) return topk_weights, topk_ids self.experts = FusedMoE( - num_experts=self.num_moe_experts, + num_experts=config.num_experts, top_k=self.topk, hidden_size=config.hidden_size, - intermediate_size=ffn_hidden_size // 2, + intermediate_size=config.moe_intermediate_size, renormalize=False, custom_routing_function=_custom_routing_fn, activation="silu", @@ -507,96 +342,74 @@ def forward( self, hidden_states: torch.Tensor, prev_router_hidden_states: torch.Tensor | None = None, - ): - probs, indices, router_hidden_states_out = self.router( + ) -> tuple[torch.Tensor, torch.Tensor]: + probs, indices, router_hidden_states_out = self.gate( hidden_states, prev_router_hidden_states=prev_router_hidden_states, ) - - if self.config.zaya_use_mod: - clamped_indices = torch.clamp(indices, min=0, max=self.num_moe_experts - 1) - packed_logits = torch.cat([probs, clamped_indices.to(probs.dtype)], dim=-1) - hidden_states_experts = self.experts(hidden_states, packed_logits) - hidden_states_mod = hidden_states * probs - if self.tp_size > 1: - hidden_states_mod = tensor_model_parallel_all_reduce(hidden_states_mod) - mod_mask = indices != self.num_moe_experts - hidden_states = (mod_mask * hidden_states_experts) + ( - (~mod_mask) * hidden_states_mod - ) - else: - packed_logits = torch.cat([probs, indices.to(probs.dtype)], dim=-1) - hidden_states = self.experts(hidden_states, packed_logits) - + packed_routing = torch.cat([probs, indices.to(probs.dtype)], dim=-1) + hidden_states = self.experts(hidden_states, packed_routing) return hidden_states, router_hidden_states_out -class ZayaDecoderMLPLayer(nn.Module): +class ZayaDecoderLayer(nn.Module): def __init__( self, config: ZayaConfig, layer_idx: int, - mlp_expansion: int, - ffn_hidden_size: int, - layer_n: int, + prefix: str, + model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, - prefix: str = "", ): super().__init__() - self.config = config - self.layer_n = layer_n - self.zaya_block = ZayaBlock( + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = ZayaAttention( config, layer_idx, - mlp_expansion, - ffn_hidden_size, - layer_n, + f"{prefix}.self_attn", + model_config=model_config, cache_config=cache_config, quant_config=quant_config, - prefix=prefix, ) - - if config.normalization == "RMSNorm": - self.input_norm = RMSNorm(self.config.hidden_size, eps=config.norm_epsilon) - elif config.normalization == "LayerNorm": - self.input_norm = nn.LayerNorm( - self.config.hidden_size, eps=config.norm_epsilon - ) - else: - raise TypeError("Normalization not supported.") - - if self.config.scale_residual_merge: - self.res_scale = ResidualScaling(config, layer_n) + self.post_attention_residual_scale = ZayaResidualScaling( + config.hidden_size) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + self.mlp = ZayaSparseMoeBlock( + config, + layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.post_mlp_residual_scale = ZayaResidualScaling(config.hidden_size) def forward( self, hidden_states: torch.Tensor, - residual: torch.Tensor, position_ids: torch.LongTensor, - layer_n: int, prev_router_hidden_states: torch.Tensor | None = None, - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - layer_input_dtype = ( - self.input_norm.weight.dtype - if isinstance(self.input_norm, RMSNorm) - else hidden_states.dtype - ) - if self.config.scale_residual_merge: - residual, hidden_states = self.res_scale(residual, hidden_states) - if residual is not None: - residual = residual.float() + hidden_states.float() - else: - residual = hidden_states.float() + ) -> tuple[torch.Tensor, torch.Tensor]: + residual = hidden_states + layer_input_dtype = self.input_layernorm.weight.dtype hidden_states = _apply_norm_with_fp32_residual( - self.input_norm, residual, layer_input_dtype - ) + self.input_layernorm, residual, layer_input_dtype) - hidden_states, prev_router_hidden_states = self.zaya_block( - hidden_states, prev_router_hidden_states + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, ) - return hidden_states, residual, prev_router_hidden_states + residual = self.post_attention_residual_scale(hidden_states, residual) + hidden_states = _apply_norm_with_fp32_residual( + self.post_attention_layernorm, + residual, + self.post_attention_layernorm.weight.dtype, + ) + hidden_states, prev_router_hidden_states = self.mlp( + hidden_states, prev_router_hidden_states) + hidden_states = self.post_mlp_residual_scale(hidden_states, residual) + return hidden_states, prev_router_hidden_states @support_torch_compile @@ -609,85 +422,34 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - is_lora_enabled = bool(lora_config) - assert not is_lora_enabled - - self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + assert not lora_config self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.layers = [] - - # Initialize token embeddings self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, ) - - for layer_n in range(config.num_hidden_layers): - if layer_n % 2 == 1: - prefix_name = f"{prefix}.layers.{layer_n}.moe" - self.layers.append( - ZayaDecoderMLPLayer( - config, - config.num_experts, - config.zaya_mlp_expansion, - config.ffn_hidden_size, - layer_n, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix_name, - ) - ) - else: - prefix_name = f"{prefix}.layers.{layer_n}.self_attn" - self.layers.append( - ZayaDecoderATTLayer( - config, - "a", - layer_n, - prefix_name, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - ) - ) - self.layers = nn.ModuleList(self.layers) - - if self.config.scale_residual_merge: - self.res_scale = ResidualScaling(config, config.num_hidden_layers) - - if config.normalization == "RMSNorm": - self.final_norm = RMSNorm(self.config.hidden_size, eps=config.norm_epsilon) - elif config.normalization == "LayerNorm": - self.final_norm = nn.LayerNorm( - self.config.hidden_size, eps=config.norm_epsilon - ) - else: - raise TypeError("Normalization not supported.") + self.layers = nn.ModuleList([ + ZayaDecoderLayer( + config, + layer_idx, + f"{prefix}.layers.{layer_idx}", + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ) for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) + self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) + ["hidden_states"], config.hidden_size) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - """Convert input token IDs to embeddings. - - Args: - input_ids: Tensor of input token IDs - - Returns: - Embedded representation of the input tokens - """ return self.embed_tokens(input_ids) def forward( @@ -699,34 +461,21 @@ def forward( ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - residual = None - hidden_states = inputs_embeds + hidden_states = ( + inputs_embeds.float() + + self.input_hidden_states_bias.to(torch.float32) + ) * self.input_hidden_states_scale.to(torch.float32) prev_router_hidden_states = None - for layer_n, decoder_layer in enumerate(self.layers): - hidden_states, residual, prev_router_hidden_states = decoder_layer( + for decoder_layer in self.layers: + hidden_states, prev_router_hidden_states = decoder_layer( hidden_states, - residual, positions, - layer_n, prev_router_hidden_states, ) - if self.config.scale_residual_merge: - residual, hidden_states = self.res_scale(residual, hidden_states) - final_input_dtype = ( - self.final_norm.weight.dtype - if isinstance(self.final_norm, RMSNorm) - else hidden_states.dtype - ) - if residual is not None: - hidden_states = hidden_states.float() + residual.float() - else: - hidden_states = hidden_states.float() hidden_states = _apply_norm_with_fp32_residual( - self.final_norm, hidden_states, final_input_dtype - ) - + self.norm, hidden_states, self.norm.weight.dtype) return hidden_states @@ -745,23 +494,19 @@ def get_mamba_state_dtype_from_config( def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - ) -> tuple[tuple[int, int], tuple[int, int, int]]: + ) -> tuple[tuple[int, int], tuple[int]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config - conv_kernel_size = hf_config.cca_time0 - num_k_heads = hf_config.num_query_groups - num_q_heads = hf_config.num_attention_heads - head_dim = hf_config.head_dim - hidden_size = hf_config.hidden_size - return MambaStateShapeCalculator.cca_state_shape( tp_world_size=parallel_config.tensor_parallel_size, - conv_kernel_size=conv_kernel_size, - num_k_heads=num_k_heads, - num_q_heads=num_q_heads, - head_dim=head_dim, - hidden_size=hidden_size, + conv_kernel_size=(hf_config.cca_time0 - 1) + (hf_config.cca_time1 - 1), + num_k_heads=hf_config.num_key_value_heads, + num_q_heads=hf_config.num_attention_heads, + head_dim=hf_config.head_dim, + recurrent_state_size=( + hf_config.num_key_value_heads * hf_config.head_dim // 2 + ), ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -769,18 +514,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert config.moe_router_topk == 1, "Only topk=1 is supported in Zaya!" + assert config.num_experts_per_tok == 1, "Only topk=1 is supported in Zaya!" assert not cache_config.enable_prefix_caching, ( - "Zaya currently does not support prefix caching" - ) + "Zaya currently does not support prefix caching") tp_world_size = get_tensor_model_parallel_world_size() if tp_world_size > 1: logger.warning( - "WARNING: TP>1 detected, CCA does not support TP at the moment," - " but it's still going to work without actual splits, meaning " - "every rank will run as if TP=1" - ) + "TP>1 detected; CCA currently replicates heads on every rank.") super().__init__() self.config = config @@ -791,8 +532,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model_config = vllm_config.model_config self.model = ZayaModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -802,28 +542,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, padding_size=( DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), + if not lora_config else lora_config.lora_vocab_padding_size), quant_config=None, bias=config.lm_head_bias, ) - # Tie weights with input embeddings if using same dimensions - if self.config.tie_word_embeddings: + if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) - + self.unpadded_vocab_size, config.vocab_size) if bool(getattr(config, "zaya_high_prec", False)): self.lm_head.quant_method = _FP32EmbeddingMethod() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) + self.model.make_empty_intermediate_tensors) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -836,25 +567,10 @@ def forward( inputs_embeds: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - """Compute logits for next token prediction. - - Args: - hidden_states: Hidden states from model forward pass + return self.model(input_ids, positions, intermediate_tensors, inputs_embeds) - Returns: - Logits for next token prediction - """ - logits = self.logits_processor(self.lm_head, hidden_states) - return logits + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) @@ -863,93 +579,97 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if "cos_sin_cache" in key: continue params_dict[key] = buffer - - weights_dict = {} - for key, loaded_weight in weights: - if "lora" in key: - if "_A.weight" in key: - key = key.replace("_A.weight", ".A.weight") - elif "_B.weight" in key: - key = key.replace("_B.weight", ".B.weight") - weights_dict[key] = loaded_weight - - # Build a map from prefix → FusedMoE module for expert weight loading - fused_moe_modules: dict[str, FusedMoE] = {} - for name, module in self.named_modules(): - if isinstance(module, FusedMoE): - fused_moe_modules[name] = module + fused_moe_modules = { + name: module + for name, module in self.named_modules() + if isinstance(module, FusedMoE) + } loaded_params: set[str] = set() - import re + tp_rank = get_tensor_model_parallel_rank() + disable_tqdm = tp_rank != 0 import tqdm - tp_rank = get_tensor_model_parallel_rank() - disable_tqdm = tp_rank != 0 + skipped_weights: list[str] = [] for chkpt_weight_name, loaded_weight in tqdm.tqdm( - weights_dict.items(), + weights, desc="Loading weights", unit_scale=True, unit="weights", disable=disable_tqdm, ): - if "local_experts" in chkpt_weight_name: - parts = chkpt_weight_name.split(".") + weight_name = chkpt_weight_name + if weight_name.endswith(".self_attn.qk_norm.temp"): + weight_name = weight_name.replace( + ".self_attn.qk_norm.temp", + ".self_attn.qkv_proj.temp", + ) - m = re.search(r"\.local_experts\.(\d+)\.", chkpt_weight_name) - if not m: - raise ValueError( - f"Could not parse expert id from {chkpt_weight_name}" - ) - expert_id = int(m.group(1)) + if "lora" in weight_name: + if "_A.weight" in weight_name: + weight_name = weight_name.replace("_A.weight", ".A.weight") + elif "_B.weight" in weight_name: + weight_name = weight_name.replace("_B.weight", ".B.weight") - # Determine FusedMoE param name and shard_id - # linear_fc1 = merged gate+up → w13_weight (split into w1, w3) - # linear_fc2 = down proj → w2_weight (shard_id w2) - fused_moe_prefix = ".".join(parts[:5]) + if weight_name.endswith(".mlp.experts.gate_up_proj"): + fused_moe_prefix = weight_name.removesuffix(".gate_up_proj") fused_moe_module = fused_moe_modules.get(fused_moe_prefix) if fused_moe_module is None: - logger.warning( - "No FusedMoE module found at %s, skipping %s", - fused_moe_prefix, - chkpt_weight_name, - ) + skipped_weights.append(chkpt_weight_name) continue - if parts[-2] == "linear_fc1": - param_name = f"{fused_moe_prefix}.w13_weight" - param = params_dict[param_name] - half = loaded_weight.shape[0] // 2 - gate_weight = loaded_weight[:half, :] - up_weight = loaded_weight[half:, :] + param_name = f"{fused_moe_prefix}.w13_weight" + param = params_dict[param_name] + gate_weight, up_weight = loaded_weight.chunk(2, dim=1) + for expert_id, (gate_expert, up_expert) in enumerate( + zip(gate_weight, up_weight)): fused_moe_module.weight_loader( - param, gate_weight, chkpt_weight_name, "w1", expert_id + param, + gate_expert, + param_name, + "w1", + expert_id, ) fused_moe_module.weight_loader( - param, up_weight, chkpt_weight_name, "w3", expert_id + param, + up_expert, + param_name, + "w3", + expert_id, ) - loaded_params.add(param_name) - elif parts[-2] == "linear_fc2": - param_name = f"{fused_moe_prefix}.w2_weight" - param = params_dict[param_name] + loaded_params.add(param_name) + continue + + if weight_name.endswith(".mlp.experts.down_proj"): + fused_moe_prefix = weight_name.removesuffix(".down_proj") + fused_moe_module = fused_moe_modules.get(fused_moe_prefix) + if fused_moe_module is None: + skipped_weights.append(chkpt_weight_name) + continue + + param_name = f"{fused_moe_prefix}.w2_weight" + param = params_dict[param_name] + for expert_id, down_expert in enumerate(loaded_weight): fused_moe_module.weight_loader( - param, loaded_weight, chkpt_weight_name, "w2", expert_id - ) - loaded_params.add(param_name) - else: - logger.warning( - "Unknown expert weight kind in %s", chkpt_weight_name + param, + down_expert, + param_name, + "w2", + expert_id, ) + loaded_params.add(param_name) continue - # Loading other parameters - if chkpt_weight_name not in params_dict: - logger.info( - "WARNING: key {chkpt_weight_name} not in params! Skipping loading" - ) + if weight_name not in params_dict: + skipped_weights.append(weight_name) continue - param = params_dict[chkpt_weight_name] + param = params_dict[weight_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(chkpt_weight_name) + loaded_params.add(weight_name) + if skipped_weights: + raise RuntimeError( + "Unexpected Zaya checkpoint weights were not loaded: " + f"{sorted(skipped_weights)}") return loaded_params diff --git a/vllm/transformers_utils/configs/zaya.py b/vllm/transformers_utils/configs/zaya.py index 6f90553e0902..c9e1e5163647 100644 --- a/vllm/transformers_utils/configs/zaya.py +++ b/vllm/transformers_utils/configs/zaya.py @@ -6,113 +6,117 @@ class ZayaConfig(PretrainedConfig): model_type = "zaya" keys_to_ignore_at_inference = ["past_key_values"] + ignore_keys_at_rope_validation = {"hybrid", "hybrid_sliding"} def __init__( self, - cca=True, - num_query_groups=2, use_cache=True, attention_bias=False, lm_head_bias=False, vocab_size=262272, hidden_size=2048, - ffn_hidden_size=4096, - num_hidden_layers=80, + num_hidden_layers=40, num_experts=16, num_attention_heads=8, head_dim=128, - activation_func="swiglu", + hidden_act="silu", max_position_embeddings=131072, - norm_epsilon=1e-05, + initializer_range=0.02, + rms_norm_eps=1e-05, pad_token_id=0, bos_token_id=2, eos_token_id=106, tie_word_embeddings=True, - rope_theta=5000000, attention_dropout=0.0, - moe_router_topk=1, - normalization="RMSNorm", - zaya_mlp_expansion=256, - zaya_use_mod=True, - zaya_high_prec=True, - zaya_use_eda=True, - add_bias_linear=False, - gated_linear_unit=True, - scale_residual_merge=True, - fused_add_norm=False, - residual_in_fp32=True, - apply_rope_fusion=True, - bias_activation_fusion=True, - activation_func_fp8_input_store=False, + moe_intermediate_size=2048, + num_experts_per_tok=1, + output_router_logits=False, + layer_types=None, sliding_window=None, - rope_scaling=None, rope_parameters=None, + rope_scaling=None, partial_rotary_factor=0.5, num_key_value_heads=2, - clamp_temp=False, cca_time0=2, cca_time1=2, - swa_layers=None, - swa_rotary_base=None, _attn_implementation="eager", **kwargs, ): - self.cca = cca - self.num_query_groups = num_query_groups self.use_cache = use_cache self.attention_bias = attention_bias self.lm_head_bias = lm_head_bias self.vocab_size = vocab_size self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size self.num_hidden_layers = num_hidden_layers self.num_experts = num_experts self.num_attention_heads = num_attention_heads self.head_dim = head_dim assert self.head_dim is not None - assert self.num_query_groups == num_key_value_heads self.num_key_value_heads = num_key_value_heads - self.activation_func = activation_func + self.num_query_groups = num_key_value_heads + self.hidden_act = hidden_act self.max_position_embeddings = max_position_embeddings - self.norm_epsilon = norm_epsilon - self.normalization = normalization + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.tie_word_embeddings = tie_word_embeddings self.attention_dropout = attention_dropout - self.moe_router_topk = moe_router_topk - self.zaya_mlp_expansion = zaya_mlp_expansion - self.zaya_use_mod = zaya_use_mod - self.zaya_high_prec = zaya_high_prec - self.zaya_use_eda = zaya_use_eda - self.add_bias_linear = add_bias_linear - self.gated_linear_unit = gated_linear_unit - self.scale_residual_merge = scale_residual_merge - self.residual_in_fp32 = residual_in_fp32 - self.bias_activation_fusion = bias_activation_fusion - self.activation_func_fp8_input_store = activation_func_fp8_input_store + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.output_router_logits = output_router_logits + self.layer_types = ( + ["hybrid"] * num_hidden_layers if layer_types is None else list(layer_types) + ) self.sliding_window = sliding_window self.partial_rotary_factor = partial_rotary_factor - self.rope_theta = rope_theta if isinstance(rope_parameters, dict): rope_parameters = dict(rope_parameters) elif isinstance(rope_scaling, dict): rope_parameters = dict(rope_scaling) else: - rope_parameters = {"rope_type": "default"} + rope_parameters = { + "hybrid": { + "rope_type": "default", + "rope_theta": 5000000, + "partial_rotary_factor": partial_rotary_factor, + }, + "hybrid_sliding": { + "rope_type": "default", + "rope_theta": 10000.0, + "partial_rotary_factor": partial_rotary_factor, + }, + } if "type" in rope_parameters: rope_parameters.setdefault("rope_type", rope_parameters.pop("type")) - rope_parameters.setdefault("rope_theta", rope_theta) - rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor) + if "hybrid" in rope_parameters or "hybrid_sliding" in rope_parameters: + rope_parameters.pop("rope_type", None) self.rope_parameters = rope_parameters - self.num_key_value_heads = num_key_value_heads - self.clamp_temp = clamp_temp self.cca_time0 = cca_time0 self.cca_time1 = cca_time1 - self.swa_layers = swa_layers - self.swa_rotary_base = swa_rotary_base self._attn_implementation = _attn_implementation + self.rope_theta = self._rope_theta_for_layer_type("hybrid") + + # Compatibility aliases used by existing vLLM helper code. + self.cca = True + self.ffn_hidden_size = 2 * moe_intermediate_size + self.activation_func = "swiglu" + self.norm_epsilon = rms_norm_eps + self.normalization = "RMSNorm" + self.moe_router_topk = num_experts_per_tok + self.zaya_mlp_expansion = router_hidden_size = kwargs.pop( + "router_hidden_size", 256 + ) + self.router_hidden_size = router_hidden_size + self.zaya_use_mod = True + self.zaya_high_prec = True + self.zaya_use_eda = True + self.add_bias_linear = False + self.gated_linear_unit = True + self.scale_residual_merge = True + self.residual_in_fp32 = True + self.clamp_temp = False super().__init__( pad_token_id=pad_token_id, @@ -121,3 +125,9 @@ def __init__( tie_word_embeddings=self.tie_word_embeddings, **kwargs, ) + + def _rope_theta_for_layer_type(self, layer_type: str) -> float: + layer_rope = self.rope_parameters.get(layer_type, self.rope_parameters) + if isinstance(layer_rope, dict): + return layer_rope.get("rope_theta", 5000000) + return 5000000