apply_query_key_layer_scaling ignored in PyTorch ≥ 2.0 scaled_dot_product_attention path #1363

@Qi-Zhan

System Info

None

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Reproduction

Description

In the ChatGLM implementation, the configuration flag:

apply_query_key_layer_scaling = True

is ignored when running on PyTorch 2.0 or later.

Background

In the original THUDM ChatGLM3 implementation, layer-wise attention scaling is applied:

if self.apply_query_key_layer_scaling:
    coeff = layer_number
    self.norm_factor *= coeff

This effectively scales attention scores by 1 / (sqrt(head_dim) * layer_number), stabilizing deep layers.

Current Behavior

For PyTorch < 2.0, the implementation uses baddbmm with alpha = 1 / norm_factor:

matmul_result = torch.baddbmm(
    ...,
    alpha=1.0 / self.norm_factor
)

In this path, layer-wise scaling works as intended.

For PyTorch ≥ 2.0, the implementation uses torch.nn.functional.scaled_dot_product_attention:

context_layer = torch.nn.functional.scaled_dot_product_attention(
    query_layer, key_layer, value_layer, attention_mask, is_causal=True
)

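Called this way, the function uses its default scale of 1 / sqrt(head_dim), and nothing in the call carries layer_number (a custom scale can only be supplied through the scale keyword, which is available from PyTorch 2.1).
As a result, apply_query_key_layer_scaling=True has no effect on this path.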
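
To make the gap between the two paths concrete, here is a small pure-Python sketch that follows the description above. The values of head_dim, layer_number and raw_scores are made up for illustration and are not taken from any ChatGLM checkpoint.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative values only (assumptions, not from the issue or a real config).
head_dim = 128
layer_number = 4
raw_scores = [8.0, 4.0, 0.0]  # made-up q.k dot products for one query

# Scale on the baddbmm path as described above: alpha = 1 / norm_factor,
# where norm_factor = sqrt(head_dim) * layer_number.
legacy_scale = 1.0 / (math.sqrt(head_dim) * layer_number)

# Default scale used by scaled_dot_product_attention when no scale is passed.
default_sdpa_scale = 1.0 / math.sqrt(head_dim)

legacy_probs = softmax([s * legacy_scale for s in raw_scores])
sdpa_probs = softmax([s * default_sdpa_scale for s in raw_scores])

print("scale ratio (sdpa / legacy):", default_sdpa_scale / legacy_scale)
print("legacy path probabilities:  ", legacy_probs)
print("sdpa path probabilities:    ", sdpa_probs)
```

Under these assumptions the two scales differ by exactly a factor of layer_number, so the attention distribution on the scaled_dot_product_attention path is sharper than on the baddbmm path for the same inputs.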

Expected behavior

Suggested Fix

• Pass an explicit scale derived from norm_factor on the PyTorch ≥ 2.0 path instead of relying on the default
• Multiply norm_factor by layer_number when apply_query_key_layer_scaling=True, as the baddbmm path does
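
A minimal sketch of what such a fix might look like, not a patch against the actual model code. The helper names sdpa_scale and attention_with_layer_scaling are hypothetical, and the max(1, layer_number) guard is an assumption. The query is pre-scaled so that the sketch does not depend on the scale keyword and should therefore also run on PyTorch 2.0.

```python
import math


def sdpa_scale(head_dim, layer_number, apply_query_key_layer_scaling):
    """Return 1 / norm_factor, mirroring the baddbmm path described above.

    Hypothetical helper; the max(1, ...) guard against layer 0 is an assumption.
    """
    norm_factor = math.sqrt(head_dim)
    if apply_query_key_layer_scaling:
        norm_factor *= max(1, layer_number)
    return 1.0 / norm_factor


def attention_with_layer_scaling(query_layer, key_layer, value_layer,
                                 layer_number,
                                 apply_query_key_layer_scaling=True):
    """Hypothetical wrapper around scaled_dot_product_attention."""
    # Imported lazily so sdpa_scale stays usable without PyTorch installed.
    import torch.nn.functional as F

    head_dim = query_layer.size(-1)
    scale = sdpa_scale(head_dim, layer_number, apply_query_key_layer_scaling)

    # SDPA applies 1 / sqrt(head_dim) by default. Multiplying the query by
    # scale * sqrt(head_dim) makes the overall factor equal to `scale`.
    query_layer = query_layer * (scale * math.sqrt(head_dim))

    return F.scaled_dot_product_attention(
        query_layer, key_layer, value_layer, is_causal=True
    )
```

On PyTorch 2.1 or later, the same effect could instead be obtained by leaving the query untouched and passing scale=scale to scaled_dot_product_attention.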
