Add multinode finetuning section into README.

This commit is contained in:
苏阳
2023-12-27 11:11:46 +08:00
committed by Ren Xuancheng
parent 174cbde243
commit 35023b6f2a
2 changed files with 118 additions and 34 deletions

View File

@@ -762,9 +762,18 @@ response, history = model.chat(tokenizer, "你好", history=None)
print(response)
```
### 多机微调
我们提供的脚本支持多机微调,可以参考[脚本](./finetune/finetune_lora_ds.sh)中的注释,在每个节点上正确设置相应的参数并启动训练脚本。关于多机分布式训练的更多信息,请参考[torchrun](https://pytorch.org/docs/stable/elastic/run.html)。
注意: DeepSpeed ZeRO 3 对节点间通信速率的要求远大于 ZeRO 2在多机微调的情况下会大幅降低训练速度。因此我们不建议在多机微调的情况下使用 DeepSpeed ZeRO 3 配置。
### 显存占用及训练速度
下面记录7B和14B模型在单GPU使用LoRALoRA (emb)指的是embedding和输出层参与训练而LoRA则不优化这部分参数和QLoRA时处理不同长度输入的显存占用和训练速度的情况。本次评测运行于单张A100-SXM4-80G GPU使用CUDA 11.8和Pytorch 2.0并使用了flash attention 2。我们统一使用batch size为1gradient accumulation为8的训练配置记录输入长度分别为256、512、1024、2048、4096和8192的显存占用GB和训练速度s/iter。我们还使用2张A100测了Qwen-7B的全参数微调。受限于显存大小我们仅测试了256、512和1024token的性能。
对于 Qwen-7B我们额外测试了多机微调的性能。我们在两台服务器上运行评测每台服务器包含两张A100-SXM4-80G GPU其余配置与Qwen-7B的其他评测相同。多机微调的结果在表中以 LoRA (multinode) 标示。
对于 Qwen-72B我们测试了两种方案1使用4个 A100-SXM4-80G GPUs通过 Lora + DeepSpeed ZeRO 3 微调和2使用单张A100-SXM4-80G GPU通过 QLora (int4) 微调。请注意,使用 LoRA (emb) 微调和不带 DeepSpeed ZeRO 3 的 LoRA 微调在4个A100-SXM4-80G GPUs 上都会出现OOM你可以通过将`--deepspeed finetune/ds_config_zero3.json`参数传给[`finetune/finetune_lora_ds.sh`](finetune/finetune_lora_ds.sh)来打开 DeepSpeed ZeRO 3 配置)。
具体数值如下所示:
@@ -772,51 +781,83 @@ print(response)
<table>
<tr>
<th rowspan="2">Model Size</th><th rowspan="2">Method</th><th colspan="6" align="center">Sequence Length</th>
<th rowspan="2">Model Size</th><th rowspan="2">Method</th><th rowspan="2">#Nodes</th><th rowspan="2">#GPUs per node</th><th colspan="6" align="center">Sequence Length</th>
</tr>
<tr>
<th align="center">256</th><th align="center">512</th><th align="center">1024</th><th align="center">2048</th><th align="center">4096</th><th align="center">8192</th>
</tr>
</tr>
<tr>
<th rowspan="4">1.8B</th><td>LoRA</td>
<td>1</td><td>1</td>
<td align="center">6.7G / 1.0s/it</td><td align="center">7.4G / 1.0s/it</td><td align="center">8.4G / 1.1s/it</td><td align="center">11.0G / 1.7s/it</td><td align="center">16.2G / 3.3s/it</td><td align="center">21.8G / 6.8s/it</td>
</tr>
<tr>
<th rowspan="4">1.8B</th><td>LoRA</td><td align="center">6.7G / 1.0s/it</td><td align="center">7.4G / 1.0s/it</td><td align="center">8.4G / 1.1s/it</td><td align="center">11.0G / 1.7s/it</td><td align="center">16.2G / 3.3s/it</td><td align="center">21.8G / 6.8s/it</td>
<td>LoRA (emb)</td>
<td>1</td><td>1</td>
<td align="center">13.7G / 1.0s/it</td><td align="center">14.0G / 1.0s/it</td><td align="center">14.0G / 1.1s/it</td><td align="center">15.1G / 1.8s/it</td><td align="center">19.7G / 3.4s/it</td><td align="center">27.7G / 7.0s/it</td>
</tr>
<tr>
<td>LoRA (emb)</td><td align="center">13.7G / 1.0s/it</td><td align="center">14.0G / 1.0s/it</td><td align="center">14.0G / 1.1s/it</td><td align="center">15.1G / 1.8s/it</td><td align="center">19.7G / 3.4s/it</td><td align="center">27.7G / 7.0s/it</td>
<td>Q-LoRA</td>
<td>1</td><td>1</td>
<td align="center">5.8G / 1.4s/it</td><td align="center">6.0G / 1.4s/it</td><td align="center">6.6G / 1.4s/it</td><td align="center">7.8G / 2.0s/it</td><td align="center">10.2G / 3.4s/it</td><td align="center">15.8G / 6.5s/it</td>
</tr>
<tr>
<td>Q-LoRA</td><td align="center">5.8G / 1.4s/it</td><td align="center">6.0G / 1.4s/it</td><td align="center">6.6G / 1.4s/it</td><td align="center">7.8G / 2.0s/it</td><td align="center">10.2G / 3.4s/it</td><td align="center">15.8G / 6.5s/it</td>
<td>Full-parameter</td>
<td>1</td><td>1</td>
<td align="center">43.5G / 2.1s/it</td><td align="center">43.5G / 2.2s/it</td><td align="center">43.5G / 2.2s/it</td><td align="center">43.5G / 2.3s/it</td><td align="center">47.1G / 2.8s/it</td><td align="center">48.3G / 5.6s/it</td>
</tr>
<tr>
<td>Full-parameter</td><td align="center">43.5G / 2.1s/it</td><td align="center">43.5G / 2.2s/it</td><td align="center">43.5G / 2.2s/it</td><td align="center">43.5G / 2.3s/it</td><td align="center">47.1G / 2.8s/it</td><td align="center">48.3G / 5.6s/it</td>
<th rowspan="5">7B</th>
<td>LoRA</td>
<td>1</td><td>1</td>
<td align="center">20.1G / 1.2s/it</td><td align="center">20.4G / 1.5s/it</td><td align="center">21.5G / 2.8s/it</td><td align="center">23.8G / 5.2s/it</td><td align="center">29.7G / 10.1s/it</td><td align="center">36.6G / 21.3s/it</td>
</tr>
<tr>
<th rowspan="4">7B</th><td>LoRA</td><td align="center">20.1G / 1.2s/it</td><td align="center">20.4G / 1.5s/it</td><td align="center">21.5G / 2.8s/it</td><td align="center">23.8G / 5.2s/it</td><td align="center">29.7G / 10.1s/it</td><td align="center">36.6G / 21.3s/it</td>
<td>LoRA (emb)</td>
<td>1</td><td>1</td>
<td align="center">33.7G / 1.4s/it</td><td align="center">34.1G / 1.6s/it</td><td align="center">35.2G / 2.9s/it</td><td align="center">35.1G / 5.3s/it</td><td align="center">39.2G / 10.3s/it</td><td align="center">48.5G / 21.7s/it</td>
</tr>
<tr>
<td>LoRA (emb)</td><td align="center">33.7G / 1.4s/it</td><td align="center">34.1G / 1.6s/it</td><td align="center">35.2G / 2.9s/it</td><td align="center">35.1G / 5.3s/it</td><td align="center">39.2G / 10.3s/it</td><td align="center">48.5G / 21.7s/it</td>
<td>Q-LoRA</td>
<td>1</td><td>1</td>
<td align="center">11.5G / 3.0s/it</td><td align="center">11.5G / 3.0s/it</td><td align="center">12.3G / 3.5s/it</td><td align="center">13.9G / 7.0s/it</td><td align="center">16.9G / 11.6s/it</td><td align="center">23.5G / 22.3s/it</td>
</tr>
<tr>
<td>Q-LoRA</td><td align="center">11.5G / 3.0s/it</td><td align="center">11.5G / 3.0s/it</td><td align="center">12.3G / 3.5s/it</td><td align="center">13.9G / 7.0s/it</td><td align="center">16.9G / 11.6s/it</td><td align="center">23.5G / 22.3s/it</td>
<td>Full-parameter</td>
<td>1</td><td>2</td>
<td align="center">139.2G / 4.0s/it</td><td align="center">148.0G / 4.0s/it</td><td align="center">162.0G / 4.5s/it</td><td align="center">-</td><td align="center">-</td><td align="center">-</td>
</tr>
<tr>
<td>Full-parameter</td><td align="center">139.2G / 4.0s/it</td><td align="center">148.0G / 4.0s/it</td><td align="center">162.0G / 4.5s/it</td><td align="center">-</td><td align="center">-</td><td align="center">-</td>
<td>LoRA (multinode)</td>
<td>2</td><td>2</td>
<td align="center">74.7G / 2.09s/it</td><td align="center">77.6G / 3.16s/it</td><td align="center">84.9G / 5.17s/it</td><td align="center">95.1G / 9.25s/it</td><td align="center">121.1G / 18.1s/it</td><td align="center">155.5G / 37.4s/it</td>
</tr>
<tr>
<th rowspan="3">14B</th><td>LoRA</td><td align="center">34.6G / 1.6s/it</td><td align="center">35.1G / 2.4s/it</td><td align="center">35.3G / 4.4s/it</td><td align="center">37.4G / 8.4s/it</td><td align="center">42.5G / 17.0s/it</td><td align="center">55.2G / 36.0s/it</td>
<th rowspan="3">14B</th>
<td>LoRA</td>
<td>1</td><td>1</td>
<td align="center">34.6G / 1.6s/it</td><td align="center">35.1G / 2.4s/it</td><td align="center">35.3G / 4.4s/it</td><td align="center">37.4G / 8.4s/it</td><td align="center">42.5G / 17.0s/it</td><td align="center">55.2G / 36.0s/it</td>
</tr>
<tr>
<td>LoRA (emb)</td><td align="center">51.2 / 1.7s/it</td><td align="center">51.1G / 2.6s/it</td><td align="center">51.5G / 4.6s/it</td><td align="center">54.1G / 8.6s/it</td><td align="center">56.8G / 17.2s/it</td><td align="center">67.7G / 36.3s/it</td>
<td>LoRA (emb)</td>
<td>1</td><td>1</td>
<td align="center">51.2 / 1.7s/it</td><td align="center">51.1G / 2.6s/it</td><td align="center">51.5G / 4.6s/it</td><td align="center">54.1G / 8.6s/it</td><td align="center">56.8G / 17.2s/it</td><td align="center">67.7G / 36.3s/it</td>
</tr>
<tr>
<td>Q-LoRA</td><td align="center">18.7G / 5.3s/it</td><td align="center">18.4G / 6.3s/it</td><td align="center">18.9G / 8.2s/it</td><td align="center">19.9G / 11.8s/it</td><td align="center">23.0G / 20.1s/it</td><td align="center">27.9G / 38.3s/it</td>
<td>Q-LoRA</td>
<td>1</td><td>1</td>
<td align="center">18.7G / 5.3s/it</td><td align="center">18.4G / 6.3s/it</td><td align="center">18.9G / 8.2s/it</td><td align="center">19.9G / 11.8s/it</td><td align="center">23.0G / 20.1s/it</td><td align="center">27.9G / 38.3s/it</td>
</tr>
<tr>
<th rowspan="2">72B</th><td>LoRA + Deepspeed Zero3</td><td align="center">215.4G / 17.6s/it</td><td align="center">217.7G / 20.5s/it</td><td align="center">222.6G / 29.4s/it</td><td align="center">228.8G / 45.7s/it</td><td align="center">249.0G / 83.4s/it</td><td align="center">289.2G / 161.5s/it</td>
<th rowspan="2">72B</th>
<td>LoRA + Deepspeed Zero3</td>
<td>1</td><td>4</td>
<td align="center">215.4G / 17.6s/it</td><td align="center">217.7G / 20.5s/it</td><td align="center">222.6G / 29.4s/it</td><td align="center">228.8G / 45.7s/it</td><td align="center">249.0G / 83.4s/it</td><td align="center">289.2G / 161.5s/it</td>
</tr>
<tr>
<td>Q-LoRA</td><td align="center">61.4G / 27.4s/it</td><td align="center">61.4G / 31.5s/it</td><td align="center">62.9G / 41.4s/it</td><td align="center">64.1G / 59.5s/it</td><td align="center">68.0G / 97.7s/it</td><td align="center">75.6G / 179.8s/it</td>
<td>Q-LoRA</td>
<td>1</td><td>1</td>
<td align="center">61.4G / 27.4s/it</td><td align="center">61.4G / 31.5s/it</td><td align="center">62.9G / 41.4s/it</td><td align="center">64.1G / 59.5s/it</td><td align="center">68.0G / 97.7s/it</td><td align="center">75.6G / 179.8s/it</td>
</tr>
</table>