Thanks for your great work!
Following the commands below, I get poor perplexity (about 127.x) for llama2-7b on wikitext2 and c4. For llama3.1-8b-instruct and llama3.2-1b/3b-instruct, however, the perplexity is as good as the paper reports. Can you help me with this?
1. get hessian
torchrun --standalone --nproc-per-node=8 hessian_llama/get_hess_llama.py \
--save_path hessian_path \
--orig_model $BASE_MODEL \
--batch_size 32 \
--hessian_sketch B \
--power_iters 1 \
--ctx_size 2048 \
--n_seqs 65536 \
--fp64_accum
2. quantize, finetune, eval
python -m quantize_llama.quantize_finetune_llama \
--save_path $QUANTIZED_WEIGHTS_PATH \
--base_model $BASE_MODEL \
--hess_path $HESS_PATH \
--codebook bitshift \
--scale_override 0.9 \
--ft_epochs 5 \
--td_x 16 \
--td_y 16 \
--L 16 \
--K 2 \
--V 2 \
--decode_mode quantlut_sym \
--tlut_bits 9 \
>> $LOG_DIR/quantize_finetune_llama.log 2>&1
# convert the quantized model to a hf model
python -m quantize_llama.hfize_llama --quantized_path $QUANTIZED_WEIGHTS_PATH --hf_output_path $QUANTIZED_HF_WEIGHTS_PATH >> $LOG_DIR/hfize_llama.log 2>&1
# do end to end finetuning
python -m quantize_llama.finetune_e2e_llama --base_model $BASE_MODEL --hf_path $QUANTIZED_HF_WEIGHTS_PATH --devset_size 640 --ft_valid_size 128 --ft_epochs 4 --ft_update_freq 4 --ft_bs 2 --ctx_size 4096 --ft_train_lut --hf_output_path $QUANTIZED_HF_FT_WEIGHTS_PATH >> $LOG_DIR/finetune_e2e_llama.log 2>&1
# evaluate perplexity and zeroshot results of the quantized model (without end-to-end finetuning)
python -m eval.eval_ppl --hf_path $QUANTIZED_HF_WEIGHTS_PATH --tokenizer $BASE_MODEL >> $LOG_DIR/eval_ppl.log 2>&1
python -m eval.eval_zeroshot --tasks arc_challenge,arc_easy,boolq,piqa,winogrande --tokenizer $BASE_MODEL --batch_size 16 --hf_path $QUANTIZED_HF_WEIGHTS_PATH >> $LOG_DIR/eval_zeroshot.log 2>&1
# evaluate perplexity and zeroshot results of the end-to-end finetuned model
python -m eval.eval_ppl --hf_path $QUANTIZED_HF_FT_WEIGHTS_PATH --tokenizer $BASE_MODEL >> $LOG_DIR/eval_ppl_ft.log 2>&1
python -m eval.eval_zeroshot --tasks arc_challenge,arc_easy,boolq,piqa,winogrande --tokenizer $BASE_MODEL --batch_size 16 --hf_path $QUANTIZED_HF_FT_WEIGHTS_PATH >> $LOG_DIR/eval_zeroshot_ft.log 2>&1