mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
add run gptq
This commit is contained in:
48
README.md
48
README.md
@@ -723,6 +723,54 @@ tokenizer.save_pretrained(new_model_directory)
|
||||
|
||||
Note: For multi-GPU training, you need to specify the proper hyperparameters for distributed training based on your machine. Besides, we advise you to specify your maximum sequence length with the argument `--model_max_length`, based on your consideration of data, memory footprint, and training speed.
|
||||
|
||||
### Quantize Fine-tuned Models
|
||||
|
||||
This section applies to full-parameter/LoRA fine-tuned models. (Note: You do not need to quantize the Q-LoRA fine-tuned model because it is already quantized.)
|
||||
If you use LoRA, please follow the above instructions to merge your model before quantization.
|
||||
|
||||
We recommend using [auto_gptq](https://github.com/PanQiWei/AutoGPTQ) to quantize the finetuned model.
|
||||
|
||||
```bash
|
||||
pip install auto-gptq optimum
|
||||
```
|
||||
|
||||
Note: Currently AutoGPTQ has a bug referred in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/370). Here is a [workaround PR](https://github.com/PanQiWei/AutoGPTQ/pull/495), and you can pull this branch and install from the source.
|
||||
|
||||
First, prepare the calibration data. You can reuse the fine-tuning data, or use other data following the same format.
|
||||
|
||||
Second, run the following script:
|
||||
|
||||
```bash
|
||||
python run_gptq.py \
|
||||
--model_name_or_path $YOUR_LORA_MODEL_PATH \
|
||||
--data_path $DATA \
|
||||
--out_path $OUTPUT_PATH \
|
||||
--bits 4 # 4 for int4; 8 for int8
|
||||
```
|
||||
|
||||
This step requires GPUs and may costs a few hours according to your data size and model size.
|
||||
|
||||
Then, copy all `*.py`, `*.cu`, `*.cpp` files and `generation_config.json` to the output path. And we recommend you to overwrite `config.json` by copying the file from the coresponding official quantized model
|
||||
(for example, if you are fine-tuning `Qwen-7B-Chat` and use `--bits 4`, you can find the `config.json` from [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4/blob/main/config.json)).
|
||||
You should also rename the ``gptq.safetensors`` into ``model.safetensors``.
|
||||
|
||||
Finally, test the model by the same method to load the official quantized model. For example,
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model", trust_remote_code=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"/path/to/your/model",
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
).eval()
|
||||
|
||||
response, history = model.chat(tokenizer, "你好", history=None)
|
||||
print(response)
|
||||
```
|
||||
|
||||
### Profiling of Memory and Speed
|
||||
We profile the GPU memory and training speed of both LoRA (LoRA (emb) refers to training the embedding and output layer, while LoRA has no trainable embedding and output layer) and Q-LoRA in the setup of single-GPU training. In this test, we experiment on a single A100-SXM4-80G GPU, and we use CUDA 11.8 and Pytorch 2.0. Flash attention 2 is applied. We uniformly use a batch size of 1 and gradient accumulation of 8. We profile the memory (GB) and speed (s/iter) of inputs of different lengths, namely 256, 512, 1024, 2048, 4096, and 8192. We also report the statistics of full-parameter finetuning with Qwen-7B on 2 A100 GPUs. We only report the statistics of 256, 512, and 1024 tokens due to the limitation of GPU memory.
|
||||
|
||||
49
README_CN.md
49
README_CN.md
@@ -713,6 +713,55 @@ tokenizer.save_pretrained(new_model_directory)
|
||||
|
||||
注意:分布式训练需要根据你的需求和机器指定正确的分布式训练超参数。此外,你需要根据你的数据、显存情况和训练速度预期,使用`--model_max_length`设定你的数据长度。
|
||||
|
||||
### 量化微调后模型
|
||||
|
||||
这一小节用于量化全参/LoRA微调后的模型。(注意:你不需要量化Q-LoRA模型因为它本身就是量化过的。)
|
||||
如果你需要量化LoRA微调后的模型,请先根据上方说明去合并你的模型权重。
|
||||
|
||||
我们推荐使用[auto_gptq](https://github.com/PanQiWei/AutoGPTQ)去量化你的模型。
|
||||
|
||||
```bash
|
||||
pip install auto-gptq optimum
|
||||
```
|
||||
|
||||
注意: 当前AutoGPTQ有个bug,可以在该[issue](https://github.com/PanQiWei/AutoGPTQ/issues/370)查看。这里有个[修改PR](https://github.com/PanQiWei/AutoGPTQ/pull/495),你可以使用该分支从代码进行安装。
|
||||
|
||||
首先,准备校准集。你可以重用微调你的数据,或者按照微调相同的方式准备其他数据。
|
||||
|
||||
第二步,运行以下命令:
|
||||
|
||||
```bash
|
||||
python run_gptq.py \
|
||||
--model_name_or_path $YOUR_LORA_MODEL_PATH \
|
||||
--data_path $DATA \
|
||||
--out_path $OUTPUT_PATH \
|
||||
--bits 4 # 4 for int4; 8 for int8
|
||||
```
|
||||
|
||||
这一步需要使用GPU,根据你的校准集大小和模型大小,可能会消耗数个小时。
|
||||
|
||||
接下来, 将原模型中所有 `*.py`, `*.cu`, `*.cpp` 文件和 `generation_config.json` 文件复制到输出模型目录下。同时,使用官方对应版本的量化模型的 `config.json` 文件覆盖输出模型目录下的文件
|
||||
(例如, 如果你微调了 `Qwen-7B-Chat`和`--bits 4`, 那么你可以从 [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4/blob/main/config.json) 仓库中找到对应的`config.json` )。
|
||||
并且,你需要将 ``gptq.safetensors`` 重命名为 ``model.safetensors``。
|
||||
|
||||
最后,像官方量化模型一样测试你的模型。例如:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model", trust_remote_code=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"/path/to/your/model",
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
).eval()
|
||||
|
||||
response, history = model.chat(tokenizer, "你好", history=None)
|
||||
print(response)
|
||||
```
|
||||
|
||||
### 显存占用及训练速度
|
||||
下面记录7B和14B模型在单GPU使用LoRA(LoRA (emb)指的是embedding和输出层参与训练,而LoRA则不优化这部分参数)和QLoRA时处理不同长度输入的显存占用和训练速度的情况。本次评测运行于单张A100-SXM4-80G GPU,使用CUDA 11.8和Pytorch 2.0,并使用了flash attention 2。我们统一使用batch size为1,gradient accumulation为8的训练配置,记录输入长度分别为256、512、1024、2048、4096和8192的显存占用(GB)和训练速度(s/iter)。我们还使用2张A100测了Qwen-7B的全参数微调。受限于显存大小,我们仅测试了256、512和1024token的性能。
|
||||
|
||||
|
||||
96
run_gptq.py
Normal file
96
run_gptq.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import argparse
|
||||
import json
|
||||
from typing import Dict
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.trainer_pt_utils import LabelSmoother
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||
|
||||
def preprocess(
|
||||
sources,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_len: int,
|
||||
system_message: str = "You are a helpful assistant."
|
||||
) -> Dict:
|
||||
roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
|
||||
|
||||
im_start = tokenizer.im_start_id
|
||||
im_end = tokenizer.im_end_id
|
||||
nl_tokens = tokenizer('\n').input_ids
|
||||
_system = tokenizer('system').input_ids + nl_tokens
|
||||
_user = tokenizer('user').input_ids + nl_tokens
|
||||
_assistant = tokenizer('assistant').input_ids + nl_tokens
|
||||
|
||||
# Apply prompt templates
|
||||
data = []
|
||||
# input_ids, targets = [], []
|
||||
for i, source in enumerate(sources):
|
||||
source = source["conversations"]
|
||||
if roles[source[0]["from"]] != roles["user"]:
|
||||
source = source[1:]
|
||||
|
||||
input_id, target = [], []
|
||||
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
|
||||
input_id += system
|
||||
target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens
|
||||
assert len(input_id) == len(target)
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
_input_id = tokenizer(role).input_ids + nl_tokens + \
|
||||
tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
|
||||
input_id += _input_id
|
||||
if role == '<|im_start|>user':
|
||||
_target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens
|
||||
elif role == '<|im_start|>assistant':
|
||||
_target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
|
||||
_input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens
|
||||
else:
|
||||
raise NotImplementedError
|
||||
target += _target
|
||||
assert len(input_id) == len(target)
|
||||
input_id = torch.tensor(input_id[:max_len], dtype=torch.int)
|
||||
target = torch.tensor(target[:max_len], dtype=torch.int)
|
||||
data.append(dict(input_ids=input_id, attention_mask=input_id.ne(tokenizer.pad_token_id)))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Model Quantization using AutoGPTQ")
|
||||
parser.add_argument("--model_name_or_path", type=str, help="model path")
|
||||
parser.add_argument("--data_path", type=str, help="calibration data path")
|
||||
parser.add_argument("--out_path", type=str, help="output path of the quantized model")
|
||||
parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")
|
||||
parser.add_argument("--bits", type=int, default=4, help="the bits of quantized model. 4 indicates int4 models.")
|
||||
parser.add_argument("--group-size", type=int, default=128, help="the group size of quantized model")
|
||||
args = parser.parse_args()
|
||||
|
||||
quantize_config = BaseQuantizeConfig(
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
damp_percent=0.01,
|
||||
desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad
|
||||
static_groups=False,
|
||||
sym=True,
|
||||
true_sequential=True,
|
||||
model_name_or_path=None,
|
||||
model_file_base_name="model"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
|
||||
tokenizer.pad_token_id = tokenizer.eod_id
|
||||
data = preprocess(json.load(open(args.data_path)), tokenizer, args.max_len)
|
||||
|
||||
model = AutoGPTQForCausalLM.from_pretrained(args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True)
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
model.quantize(data, cache_examples_on_gpu=False)
|
||||
|
||||
model.save_quantized(args.out_path, use_safetensors=True)
|
||||
tokenizer.save_pretrained(args.out_path)
|
||||
Reference in New Issue
Block a user