# chatlawv1 / generate.py
# Source: Hugging Face repository page by teachyourselfcoding
# (raw / history / blame views), commit 817f3cf "Update generate.py", 6.37 kB.
import sys
import torch
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
import transformers
import gradio as gr
import argparse
import warnings
import os
from utils import StreamPeftGenerationMixin,StreamLlamaForCausalLM
# assert (
# "LlamaTokenizer" in transformers._import_structure["models.llama"]
# ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
# Run-time configuration. These settings used to be hard-coded module
# constants behind an argparse parser that defined no options at all; they
# are now command-line options whose defaults are exactly the old values, so
# launching the script without arguments behaves as before.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_path",
    type=str,
    default="ziqingyang/chinese-llama-2-13b",
    help="base model: local directory or Hugging Face Hub id",
)
parser.add_argument(
    "--lora_path",
    type=str,
    default="teachyourselfcoding/llama-2-13b-22sep",
    help="LoRA adapter: local directory or Hugging Face Hub id",
)
parser.add_argument(
    "--use_local",
    type=int,
    default=1,
    help="1: use local model, 0: use huggingface model",
)
parser.add_argument(
    "--use_typewriter",
    type=int,
    default=1,
    help="1: stream the output token by token, 0: return it at once",
)
TOT_CUDA = "0"  # Upgrade bitsandbytes to the latest version to enable balanced loading of multiple GPUs, for example: pip install bitsandbytes==0.39.0
# NOTE(review): TOT_CUDA is not referenced anywhere in the visible code.

# parse_known_args tolerates extra argv entries injected by launchers (e.g.
# notebook kernels); they are reported instead of aborting the script.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    warnings.warn(f"Ignoring unrecognized command-line arguments: {_unknown_args}")
print(args)

BASE_MODEL = args.model_path
LORA_PATH = args.lora_path
USE_LOCAL = args.use_local  # 1: use local model, 0: use huggingface model
TYPE_WRITER = args.use_typewriter  # whether output streamly
# Whether the base model is requested in 8-bit; also decides below whether
# the loaded model is cast to half precision.
LOAD_8BIT = True
# The tokenizer always comes from the base model, never from the adapter.
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
# Fix the path for a local LoRA checkpoint: the adapter weights are looked up
# as "adapter_model.bin", so a local checkpoint saved as "pytorch_model.bin"
# is renamed in place (this mutates the checkpoint directory on disk).
lora_bin_path = os.path.join(LORA_PATH, "adapter_model.bin")
print(lora_bin_path)
if not os.path.exists(lora_bin_path) and USE_LOCAL:
    pytorch_bin_path = os.path.join(LORA_PATH, "pytorch_model.bin")
    print(pytorch_bin_path)
    if os.path.exists(pytorch_bin_path):
        os.rename(pytorch_bin_path, lora_bin_path)
        warnings.warn(
            "The file name of the lora checkpoint'pytorch_model.bin' is replaced with 'adapter_model.bin'"
        )
    elif not os.path.exists(os.path.join(LORA_PATH, "adapter_model.safetensors")):
        # Report the missing checkpoint rather than aborting: LORA_PATH may be
        # a Hugging Face Hub id (as the default is), which has no local file.
        # (A bare `assert ('...')` here could never fire: a non-empty string
        # is always truthy.)
        warnings.warn(
            f"Checkpoint is not Found! No local LoRA weights under {LORA_PATH!r}; "
            "the path is used as-is (e.g. as a Hugging Face Hub id)."
        )
# Choose the inference device: CUDA when available, otherwise CPU; Apple MPS
# takes over whenever the installed torch build reports it as available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # Older torch builds have no `torch.backends.mps`; keep the device above.
    # Only this expected failure is tolerated (previously a bare `except:`
    # hid every error, including KeyboardInterrupt).
    pass
# Load the base LLaMA weights and wrap them with the LoRA adapter. Only the
# keyword arguments differ per device, so collect them first and make a
# single pair of from_pretrained calls.
if device == "cuda":
    # Optionally 8-bit, fp16 weights placed automatically across GPUs.
    base_kwargs = {
        "load_in_8bit": LOAD_8BIT,
        "torch_dtype": torch.float16,
        "device_map": "auto",  # device_map={"": 0},
    }
    adapter_kwargs = {
        "torch_dtype": torch.float16,
        "device_map": "auto",  # device_map={"": 0}
    }
elif device == "mps":
    # fp16 weights pinned to the single MPS device.
    base_kwargs = {"device_map": {"": device}, "torch_dtype": torch.float16}
    adapter_kwargs = {"device_map": {"": device}, "torch_dtype": torch.float16}
else:
    # CPU: default dtype, with low_cpu_mem_usage for the base model.
    base_kwargs = {"device_map": {"": device}, "low_cpu_mem_usage": True}
    adapter_kwargs = {"device_map": {"": device}}

model = LlamaForCausalLM.from_pretrained(BASE_MODEL, **base_kwargs)
model = StreamPeftGenerationMixin.from_pretrained(model, LORA_PATH, **adapter_kwargs)
def generate_prompt(instruction, input=None):
    """Build the instruction-following prompt fed to the model.

    Args:
        instruction: the user's question, placed under ``### Instruction:``.
        input: optional extra context; when truthy it is placed under
            ``### Input:``, otherwise that section is left out.

    Returns:
        The prompt as one newline-separated string that always ends with
        ``### Response:``; the answer is later cut out after that marker.
    """
    # System line: "You are a helpful Chinese assistant, please answer the
    # following question".
    lines = [
        "你是一个乐于助人的中文助手,请你回答一下以下问题",
        "### Instruction:",
        f"{instruction}",
    ]
    if input:
        lines.extend(("### Input:", f"{input}"))
    lines.append("### Response:")
    return "\n".join(lines)
# 8-bit weights are left alone; anything else is cast to half precision.
if not LOAD_8BIT:
    model.half()  # seems to fix bugs for some users.
model.eval()
# torch.compile is only attempted on torch 2.x and never on Windows.
if sys.platform != "win32" and torch.__version__ >= "2":
    model = torch.compile(model)
_RESPONSE_MARKER = "### Response:"


def _to_int(value):
    """Return ``value`` as an int, passing ``None`` through unchanged."""
    return None if value is None else int(value)


def _extract_response(decoded, marker_index):
    """Return the answer part of a decoded prompt+completion string.

    ``marker_index`` is the number of response markers in the prompt itself,
    so the text after the prompt's final marker is selected even when the
    user's input happens to contain the marker. Surrounding whitespace and
    U+FFFD replacement characters are removed (presumably produced while a
    multi-byte character is only partially decoded mid-stream).
    """
    return decoded.split(_RESPONSE_MARKER)[marker_index].strip().replace('�', '')


def evaluate(
    input,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=128,
    min_new_tokens=1,
    repetition_penalty=2.0,
    **kwargs,
):
    """Generate an answer for ``input`` and yield the text to display.

    Args:
        input: the user's question; wrapped by generate_prompt().
        temperature, top_p, top_k, num_beams: decoding settings forwarded to
            GenerationConfig.
        max_new_tokens, min_new_tokens: bounds on the number of generated
            tokens (excluding the prompt).
        repetition_penalty: forwarded to generation in both output modes.
        **kwargs: extra GenerationConfig fields.

    Yields:
        With TYPE_WRITER set: after every streamed step, the partial answer(s)
        followed by a " ▌" cursor (several candidates are joined with a dashed
        separator), then the final first answer without the cursor.
        Otherwise: the complete answer once (it is also printed).
    """
    prompt = generate_prompt(input)
    # Number of markers already present in the prompt; the answer follows the
    # last of them in the decoded output.
    marker_index = prompt.count(_RESPONSE_MARKER)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # UI sliders may deliver these as floats (e.g. 40.0); the generation
        # code expects integers for them.
        top_k=_to_int(top_k),
        num_beams=_to_int(num_beams),
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=0,
        max_new_tokens=_to_int(max_new_tokens),  # max_length=max_new_tokens+input_sequence
        min_new_tokens=_to_int(min_new_tokens),  # min_length=min_new_tokens+input_sequence
        **kwargs,
    )
    with torch.no_grad():
        if TYPE_WRITER:
            # Stays empty if the stream produces nothing, so the final yield
            # below is skipped instead of failing on an unbound name.
            outputs = []
            for generation_output in model.stream_generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=False,
                repetition_penalty=float(repetition_penalty),
            ):
                outputs = tokenizer.batch_decode(generation_output)
                show_text = "\n--------------------------------------------\n".join(
                    [_extract_response(output, marker_index) + " ▌" for output in outputs]
                )
                yield show_text
            if outputs:
                # Final frame: first candidate only, without the cursor.
                yield _extract_response(outputs[0], marker_index)
        else:
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=False,
                # Use the caller's setting, as the streaming branch does,
                # instead of a fixed 1.3.
                repetition_penalty=float(repetition_penalty),
            )
            output = tokenizer.decode(generation_output.sequences[0])
            output = _extract_response(output, marker_index)
            print(output)
            yield output
# Build and launch the web UI. Gradio passes the input components' values to
# evaluate() positionally, so their order must match evaluate()'s parameters:
# input, temperature, top_p, top_k, num_beams, max_new_tokens,
# min_new_tokens, repetition_penalty. queue() is needed for the generator
# (streaming) output.
gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Input", placeholder="Tell me about alpacas."
        ),
        gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
        gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
        gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
        gr.components.Slider(minimum=1, maximum=10, step=1, value=4, label="Beams Number"),
        gr.components.Slider(
            minimum=1, maximum=2000, step=1, value=256, label="Max New Tokens"
        ),
        gr.components.Slider(
            minimum=1, maximum=300, step=1, value=1, label="Min New Tokens"
        ),
        gr.components.Slider(
            minimum=0.1, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"
        ),
    ],
    outputs=[
        # Use gr.components like the inputs: the gr.inputs namespace is
        # deprecated in Gradio 3 and removed in Gradio 4.
        gr.components.Textbox(
            lines=25,
            label="Output",
        )
    ],
    title="HKLawGPT",
    description="",
).queue().launch()