Edit model card

基于Llama2_7B的藏语词汇表扩充,继续预训练的Yak模型

一、CPT 阶段,本文采取两阶段方式

  • 1.1 第一阶段,固定模型Transformer 部分的参数,仅训练Embedding,在尽量不干扰原模型的情况下适配新增的藏文词向量;
  • 1.2 第二阶段:为模型添加LoRA+ 权重,训练Embedding 的同时也更新LoRA+ 参数。

两阶段的训练方式虽然效率较低,然而有效缓解了由于藏文数据与Llama 2 模型预训练时使用的数据分布存在差距而在CPT 过程中出现分布偏移的问题。

二、本文的训练流程主要包含

  • 2.1 对Llama 2 进行藏文词表扩充,词表由32000 扩展至56724,提高模型在藏文的编解码效率。
  • 2.2 在TibetanGeneralCorpus 上使用Sentencepiece 工具训练基于Unigram 策略的藏文分词器。生成的词表与原版Llama 2 的32K 词表进行合并,排除重复的词元后,得到扩充后词表规模为56724。用15G 的TibetanGeneralCorpus 和20G 的英、中混合文本进行CPT,采用自回归任务。

三、加载模型并启动服务

# -*- coding: UTF-8 -*-
#
"""
功能为:主要用于调用shajiu/Yak_Llama2_7B

@File:  llama2-7b-server.py
@Software:  PyCharm
"""
import json
import logging
logging.basicConfig(
    level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

from flask import Flask
from flask import Response
from flask import request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer

app = Flask(__name__)
CORS(app)
app.logger.setLevel(logging.INFO)



def load_model(model_name):
    # 加载模型和分词器
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model

def generate_response(model, tokenizer, text):
    # 对输入的文本进行编码
    inputs = tokenizer.encode(text, return_tensors='pt')

    # 使用模型生成响应
    output = model.generate(inputs, max_length=50, num_return_sequences=1)

    # 对生成的输出进行解码,获取生成的文本
    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
    return decoded_output



@app.route('/api/chat', methods=['POST'])
def qtpdnn_v0():
    """Description"""
    inputs = request.get_json()
    response = generate_response(model, tokenizer, inputs.get("query"))
    print("输出",response)
    output=inputs
    output.update({"output":response})
    return Response(json.dumps(output, ensure_ascii=False), mimetype='application/json')


if __name__ == "__main__":
    # 模型名称
    model_name = 'shajiu/Yak_Llama2_7B'
    # 加载模型
    tokenizer, model = load_model(model_name)
    app.run(host='0.0.0.0', port=8718, debug=False, threaded=False, processes=1)
Downloads last month
19
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Model tree for shajiu/Yak_Llama2_7B

Quantizations
1 model