Edit model card

CodeLlama 模型的中文化版本（支持多轮对话）

科普：CodeLlama 是专门面向代码助手场景的模型，与 Chinese-LLaMA 不同，它适用于回答代码类问题。
用于多轮对话的推理代码如下：
（可直接复制运行，首次运行时会自动下载该模型权重）

关联 GitHub 仓库：https://github.com/CrazyBoyM/CodeLLaMA-chat

# from Firefly
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def main():
    """Run an interactive multi-turn chat loop with the CodeLLaMA-chat model.

    Every user turn and every model reply is appended to a running token
    history, with ``tokenizer.eos_token_id`` as the separator after each
    turn. Only the most recent ``history_max_len`` tokens are fed to the
    model as context. Press Ctrl-D (EOF) or Ctrl-C at the ``User:`` prompt
    to exit.
    """
    model_name = 'shareAI/CodeLLaMA-chat-13b-Chinese'

    max_new_tokens = 500    # maximum number of tokens generated per reply
    history_max_len = 1000  # maximum number of history tokens the model "remembers"
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0

    # device_map='auto' lets accelerate place (and, if needed, offload) the
    # weights across the available devices. Calling .to(device) afterwards
    # would fight that placement (error on offloaded weights, or cramming a
    # multi-GPU split onto one device), so the model is NOT moved manually.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_fast=False
    )

    # The conversation history lives on the CPU as a (1, seq_len) LongTensor.
    history_token_ids = torch.tensor([[]], dtype=torch.long)
    # EOS token used as the separator that terminates every turn.
    eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)

    while True:
        try:
            user_input = input('User:')
        except (EOFError, KeyboardInterrupt):
            # Leave the chat cleanly instead of dumping a traceback.
            print()
            break
        if not user_input.strip():
            # Don't send an EOS-only "empty" turn to the model.
            continue

        input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
        user_input_ids = torch.cat((input_ids, eos_token_id), dim=1)
        history_token_ids = torch.cat((history_token_ids, user_input_ids), dim=1)
        # Inputs go to the device holding the first model weights.
        model_input_ids = history_token_ids[:, -history_max_len:].to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=model_input_ids,
                # Single unpadded sequence: every position is a real token.
                attention_mask=torch.ones_like(model_input_ids),
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
            )
        # generate() returns prompt + continuation; keep only the new tokens.
        response_ids = outputs[:, model_input_ids.size(1):].cpu()
        response = tokenizer.decode(response_ids[0], skip_special_tokens=True)

        # If generation was cut off by max_new_tokens, the reply lacks the EOS
        # separator; add it so the next user turn is not glued onto the reply.
        if response_ids.size(1) == 0 or response_ids[0, -1].item() != tokenizer.eos_token_id:
            response_ids = torch.cat((response_ids, eos_token_id), dim=1)
        # Only the last history_max_len tokens are ever fed to the model, so
        # keeping more than that would just grow memory without bound.
        history_token_ids = torch.cat((history_token_ids, response_ids), dim=1)[:, -history_max_len:]

        print("Bot:" + response.strip())


if __name__ == "__main__":
    # Start the interactive chat only when run as a script, not on import.
    main()
Downloads last month
859
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Datasets used to train shareAI/CodeLLaMA-chat-13b-Chinese

Spaces using shareAI/CodeLLaMA-chat-13b-Chinese 19