Continuous codebook in HYB is not updated during training.  #19

Description

@badeok0716

In Page 8 of the QTIP paper, the authors state:

Important

This codebook is differentiable, so we can finetune it: to evaluate this, we fine-tune using QuIP#’s methodology, tuning both the codebook entries and the as-yet-unquantized weights in a blockwise fashion.

And, according to the instructions in this repo, the --ft_train_lut flag in both finetune_e2e_llama.py and quantize_finetune_llama.py is intended to enable fine-tuning of the codebook entries.

Issue:

However, I observed that gradients do not flow into the tlut parameter of the QuantizedLinear layer. This indicates that the current implementation does not actually update the codebook parameters during fine-tuning.

In the forward pass of the QuantizedLinear layer, the code calls BitShiftLinearKernelAG for "decompress and matmul" during fine-tuning, and its backward function does not return a gradient for tlut. Please refer to the following minimal code to reproduce the reported behavior in the backward pass.

Reproducing the Behavior:

from lib.utils.unsafe_import import model_from_hf_path
import torch

def get_quantized_layer(path2qmodel="YOURPATH"): 
    # load a quantized model
    quant_model = model_from_hf_path(path2qmodel)[0].float()
    # select an arbitrary layer.
    quantized_layer = quant_model.model.layers[0].self_attn.q_proj

    # replicate the routine in finetune_e2e_llama.py #L95:107 with --ft_train_lut flag 
    quantized_layer.SU = torch.nn.Parameter(quantized_layer.SU.float(), requires_grad=True)
    quantized_layer.SV = torch.nn.Parameter(quantized_layer.SV.float(), requires_grad=True)
    quantized_layer.mode = "train-recons"
    quantized_layer.tlut.requires_grad = True

    return quantized_layer

def test_backward():
    # load quantized layer
    quantized_layer = get_quantized_layer()

    # initialize random input to the layer
    ft_bs, ctx_size, in_features = 4, 4096, 4096
    input = torch.randn(ft_bs, ctx_size, in_features).to('cuda').to(torch.float16)
    input.requires_grad = True

    print("=== Before backward ===")
    print("input", input.grad)
    print("SU", quantized_layer.SU.grad)
    print("SV", quantized_layer.SV.grad)
    print("tlut", quantized_layer.tlut.grad)

    # forward pass
    output = quantized_layer(input)

    # backward pass
    loss = output.sum()
    loss.backward()

    print("=== After backward ===")
    print("input", input.grad)
    print("SU", quantized_layer.SU.grad)
    print("SV", quantized_layer.SV.grad)
    print("tlut", quantized_layer.tlut.grad)

if __name__ == "__main__":
    test_backward()
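
For readers without access to a quantized checkpoint, the mechanism can be illustrated with a toy example. This is only a sketch under assumptions: `LookupMatmulNoLutGrad`, `LookupMatmulWithLutGrad`, and `run_backward` are hypothetical names, the "codebook" is a plain 1-D lookup table indexed by an integer tensor, and none of this is the repo's BitShiftLinearKernelAG or its trellis decoding. It only shows that when a custom `torch.autograd.Function` returns `None` for an input in `backward`, that input's `.grad` stays `None` even with `requires_grad=True`, and what a backward that does return a table gradient could look like.

```python
import torch


def _forward(ctx, x, lut, idx):
    # Decode weights from the lookup table, then matmul (toy "decompress and matmul").
    w = lut[idx]                          # shape (out_features, in_features)
    ctx.save_for_backward(x, lut, idx)
    return x @ w.t()                      # shape (batch, out_features)


class LookupMatmulNoLutGrad(torch.autograd.Function):
    """Hypothetical stand-in whose backward returns no gradient for `lut`."""

    @staticmethod
    def forward(ctx, x, lut, idx):
        return _forward(ctx, x, lut, idx)

    @staticmethod
    def backward(ctx, grad_out):
        x, lut, idx = ctx.saved_tensors
        grad_x = grad_out @ lut[idx]
        return grad_x, None, None         # lut.grad is never populated


class LookupMatmulWithLutGrad(torch.autograd.Function):
    """Same forward, but backward scatters the weight gradient into the table."""

    @staticmethod
    def forward(ctx, x, lut, idx):
        return _forward(ctx, x, lut, idx)

    @staticmethod
    def backward(ctx, grad_out):
        x, lut, idx = ctx.saved_tensors
        grad_x = grad_out @ lut[idx]
        grad_w = grad_out.t() @ x         # gradient w.r.t. the decoded weights
        # Each table entry accumulates the gradients of every weight that uses it.
        grad_lut = torch.zeros_like(lut)
        grad_lut.index_add_(0, idx.reshape(-1), grad_w.reshape(-1))
        return grad_x, grad_lut, None


def run_backward(fn):
    """Run one forward/backward pass on fixed toy data; return (x.grad, lut.grad)."""
    torch.manual_seed(0)
    lut = torch.randn(16, requires_grad=True)
    idx = torch.randint(0, 16, (8, 4))
    x = torch.randn(2, 4, requires_grad=True)
    fn.apply(x, lut, idx).sum().backward()
    return x.grad, lut.grad


if __name__ == "__main__":
    print("no lut grad  :", run_backward(LookupMatmulNoLutGrad)[1])
    print("with lut grad:", run_backward(LookupMatmulWithLutGrad)[1])
```

In this toy setting, the first variant leaves the table's `.grad` as `None` while the input still receives a gradient, which matches the symptom described above for tlut.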
