在增量预训练阶段或有监督微调阶段使用参数高效微调方法（如 LoRA）时，会产生 adapter 文件，它相当于原始模型的一个“补丁”。那么，如何将这个“补丁”与原始模型合并呢？
下面将对模型合并代码进行解读。
相关代码将全部上传到 GitHub：
https://github.com/hjandlm/LLM_Train
代码解读
import argparse
from loguru import logger
import torch
from peft import PeftModel, PeftConfig
from transformers import (
AutoModel,
AutoTokenizer,
BloomForCausalLM,
BloomTokenizerFast,
AutoModelForCausalLM,
LlamaTokenizer,
LlamaForCausalLM,
AutoModelForSequenceClassification,
)
# Maps the value of --model_type to the (model class, tokenizer class) pair
# used to load the base model and its tokenizer.
MODEL_CLASSES = dict(
    bloom=(BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoModel, AutoTokenizer),
    llama=(LlamaForCausalLM, LlamaTokenizer),
    # "baichuan" and the generic "auto" entry both use the Auto* classes.
    baichuan=(AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoModelForCausalLM, AutoTokenizer),
)
MODEL_CLASSES 根据模型类型（--model_type）确定加载模型和分词器时分别使用哪个类。
# Command-line interface of the merge script.
# --model_type is restricted to the keys of MODEL_CLASSES, so an unsupported
# value is rejected by argparse with a usage message instead of the script
# failing later with a bare KeyError from the MODEL_CLASSES lookup.
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', default=None, type=str, required=True,
                    choices=list(MODEL_CLASSES))
parser.add_argument('--base_model', default=None, required=True, type=str,
                    help="Base model name or path")
parser.add_argument('--tokenizer_path', default=None, type=str,
                    help="Please specify tokenization path.")
parser.add_argument('--lora_model', default=None, required=True, type=str,
                    help="Please specify LoRA model to be merged.")
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
parser.add_argument('--output_dir', default='./merged', type=str)
args = parser.parse_args()
logger.info(f"merged_args:{args}")
命令行参数包括：模型类型、基础模型路径、分词器路径、LoRA 模型路径、是否调整模型词表大小，以及输出目录。
tokenizer_path 只在扩充词表后才需要指定；若不指定，则从基础模型路径加载原始分词器。
resize_emb 同样只在扩充词表后使用。当基础模型词嵌入矩阵的行数与分词器词表大小不一致时，它会把词嵌入调整为分词器的词表大小。
# Resolve input/output locations from the parsed command-line arguments.
base_model_path, lora_model_path, output_dir = (
    args.base_model,
    args.lora_model,
    args.output_dir,
)
# The adapter's config carries the task type, which decides the model head.
peft_config = PeftConfig.from_pretrained(lora_model_path)
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

# Model loading. Options shared by both paths: no 8-bit quantization,
# remote model code allowed, automatic device placement.
shared_load_kwargs = dict(
    load_in_8bit=False,
    trust_remote_code=True,
    device_map="auto",
)
is_seq_cls = peft_config.task_type == "SEQ_CLS"
if not is_seq_cls:
    logger.info("Loading LoRA for causal language model")
    base_model = model_class.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        **shared_load_kwargs,
    )
else:
    logger.info("Loading LoRA for sequence classification model")
    if args.model_type == "chatglm":
        raise ValueError("chatglm does not support sequence classification")
    base_model = AutoModelForSequenceClassification.from_pretrained(
        base_model_path,
        num_labels=1,
        torch_dtype=torch.float32,
        **shared_load_kwargs,
    )

# Tokenizer loading: an explicit --tokenizer_path wins, otherwise the base
# model's own tokenizer is used.
tokenizer = tokenizer_class.from_pretrained(
    args.tokenizer_path or base_model_path,
    trust_remote_code=True,
)

# Optional embedding resize so the model's vocabulary matches the tokenizer.
if args.resize_emb:
    vocab_before = base_model.get_input_embeddings().weight.size(0)
    vocab_after = len(tokenizer)
    if vocab_before != vocab_after:
        base_model.resize_token_embeddings(vocab_after)
        logger.info(f"Resize vocabulary size {vocab_before} to {vocab_after}")

# Attach the LoRA adapter, fold it into the base weights, then save the
# tokenizer and the merged model (PyTorch .bin files, shards of at most 10GB).
new_model = PeftModel.from_pretrained(
    base_model,
    lora_model_path,
    device_map="auto",
    torch_dtype=torch.float16,
)
new_base_model = new_model.merge_and_unload()
tokenizer.save_pretrained(output_dir)
new_base_model.save_pretrained(
    output_dir,
    max_shard_size='10GB',
    safe_serialization=False,
)
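safe_serialization：是否以 safetensors 格式保存模型权重。safetensors 是一种比 pickle 更安全的序列化格式。在较新版本的 transformers 中，该参数默认为 True。设为 False 时，权重以 PyTorch 的 pickle 格式（.bin 文件）保存；否则保存为 .safetensors 文件。
max_shard_size：单个权重分片文件的最大大小。模型权重超过该大小时，会被拆分为多个分片文件保存。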
完整代码
import argparse
from loguru import logger
import torch
from peft import PeftModel, PeftConfig
from transformers import (
AutoModel,
AutoTokenizer,
BloomForCausalLM,
BloomTokenizerFast,
AutoModelForCausalLM,
LlamaTokenizer,
LlamaForCausalLM,
AutoModelForSequenceClassification,
)
# Maps the value of --model_type to the (model class, tokenizer class) pair
# used to load the base model and its tokenizer.
MODEL_CLASSES = dict(
    bloom=(BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoModel, AutoTokenizer),
    llama=(LlamaForCausalLM, LlamaTokenizer),
    # "baichuan" and the generic "auto" entry both use the Auto* classes.
    baichuan=(AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoModelForCausalLM, AutoTokenizer),
)
def main():
    """Merge a LoRA adapter into its base model and save the merged model.

    All inputs come from the command line:

    * ``--model_type``: key of ``MODEL_CLASSES`` selecting the model and
      tokenizer classes (argparse restricts it to those keys).
    * ``--base_model``: base model name or path.
    * ``--tokenizer_path``: optional tokenizer location, e.g. the tokenizer
      saved after vocabulary extension; ``--base_model`` is used if omitted.
    * ``--lora_model``: directory of the LoRA adapter to merge.
    * ``--resize_emb``: resize the base model's input embeddings to the
      tokenizer's vocabulary size before the adapter is attached.
    * ``--output_dir``: directory the tokenizer and merged model are saved to.

    Raises:
        ValueError: if the adapter's task type is ``SEQ_CLS`` and
            ``--model_type`` is ``chatglm``.
    """
    parser = argparse.ArgumentParser()
    # Restricting --model_type to the MODEL_CLASSES keys makes argparse reject
    # unsupported values with a usage message, instead of the script failing
    # later with a bare KeyError from the MODEL_CLASSES lookup below.
    parser.add_argument('--model_type', default=None, type=str, required=True,
                        choices=list(MODEL_CLASSES))
    parser.add_argument('--base_model', default=None, required=True, type=str,
                        help="Base model name or path")
    parser.add_argument('--tokenizer_path', default=None, type=str,
                        help="Please specify tokenization path.")
    parser.add_argument('--lora_model', default=None, required=True, type=str,
                        help="Please specify LoRA model to be merged.")
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    parser.add_argument('--output_dir', default='./merged', type=str)
    args = parser.parse_args()
    logger.info(f"merged_args:{args}")

    base_model_path = args.base_model
    lora_model_path = args.lora_model
    output_dir = args.output_dir
    # The adapter config's task_type decides which model head is loaded.
    peft_config = PeftConfig.from_pretrained(lora_model_path)
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Load the base model (8-bit loading disabled in both branches).
    if peft_config.task_type == "SEQ_CLS":
        logger.info("Loading LoRA for sequence classification model")
        if args.model_type == "chatglm":
            raise ValueError("chatglm does not support sequence classification")
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            num_labels=1,
            load_in_8bit=False,
            torch_dtype=torch.float32,
            trust_remote_code=True,
            device_map="auto",
        )
    else:
        logger.info("Loading LoRA for causal language model")
        base_model = model_class.from_pretrained(
            base_model_path,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )

    # An explicit tokenizer path (e.g. after vocabulary extension) takes
    # precedence over the base model's own tokenizer.
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_path or base_model_path, trust_remote_code=True
    )

    # Only when requested: make the input-embedding matrix match the
    # tokenizer's vocabulary size (used after vocabulary extension).
    if args.resize_emb:
        base_model_token_size = base_model.get_input_embeddings().weight.size(0)
        if base_model_token_size != len(tokenizer):
            base_model.resize_token_embeddings(len(tokenizer))
            logger.info(f"Resize vocabulary size {base_model_token_size} to {len(tokenizer)}")

    # Attach the adapter, then fold its weights into the base model and
    # drop the PEFT wrapper.
    new_model = PeftModel.from_pretrained(
        base_model,
        lora_model_path,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    new_model.eval()
    logger.info("Merging with merge_and_unload...")
    new_base_model = new_model.merge_and_unload()

    logger.info("Saving to Hugging Face format...")
    tokenizer.save_pretrained(output_dir)
    # safe_serialization=False writes PyTorch .bin files rather than
    # .safetensors; each weight shard is capped at 10GB.
    new_base_model.save_pretrained(output_dir, safe_serialization=False, max_shard_size='10GB')
    logger.info(f"Done! model saved to {output_dir}")


if __name__ == '__main__':
    main()
运行结果
扩充词表时，运行脚本 merge_pt.sh：
# merge_pt.sh: merge the continued-pretraining LoRA adapter (pt_lora_model)
# into llama-2-7b-bin and write the result to ./pt/model.
# The tokenizer is loaded from the adapter directory (--tokenizer_path), and
# --resize_emb resizes the model's embeddings to that tokenizer's vocab size.
python merge_peft_adapter.py \
--model_type llama \
--base_model llama-2-7b-bin \
--resize_emb \
--tokenizer_path pt_lora_model \
--lora_model pt_lora_model \
--output_dir ./pt/model
不扩充词表时，运行脚本 merge_sft.sh：
# merge_sft.sh: merge an adapter without resizing embeddings; since
# --tokenizer_path is omitted, the base model's own tokenizer is saved.
# NOTE(review): --lora_model points at pt_lora_model (the same adapter as
# merge_pt.sh) while the output goes to ./sft/model. Presumably this should be
# the SFT adapter directory; confirm before running.
python merge_peft_adapter.py \
--model_type llama \
--base_model llama-2-7b-bin \
--lora_model pt_lora_model \
--output_dir ./sft/model
运行结果:
设置 safe_serialization=True 时的运行结果：
参考
[1] https://github.com/huggingface/peft/tree/b4faffea8ae031e5bd69a76b55418b3650c04c80
[2] https://github.com/shibing624/MedicalGPT/blob/main/merge_peft_adapter.py
Copyright © 2024 妖气游戏网 www.17u1u.com All Rights Reserved