import torch
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    StoppingCriteria,
)
import gradio as gr
import argparse
import os
from queue import Queue
from threading import Thread
import traceback
import gc

parser = argparse.ArgumentParser()
parser.add_argument(
    '--base_model',
    default=None,
    type=str,
    required=True,
    help='Base model path')
parser.add_argument('--lora_model', default=None, type=str,
                    help="If None, perform inference on the base model")
parser.add_argument(
    '--tokenizer_path',
    default=None,
    type=str,
    help='If None, the lora model path or base model path will be used')
parser.add_argument(
    '--gpus',
    default="0",
    type=str,
    help='If None, cuda:0 will be used. For multi-GPU inference: --gpus=0,1,...')
parser.add_argument('--share', default=True, help='Share the gradio domain name')
parser.add_argument('--port', default=19327, type=int, help='Port of the gradio demo')
parser.add_argument(
    '--max_memory',
    default=256,
    type=int,
    help='Maximum input prompt length; if exceeded, the model receives prompt[-max_memory:]')
parser.add_argument(
    '--load_in_8bit',
    action='store_true',
    help='Use the 8-bit quantized model')
parser.add_argument(
    '--only_cpu',
    action='store_true',
    help='Only use CPU for inference')
parser.add_argument(
    '--alpha',
    type=str,
    default="1.0",
    help="The scaling factor of the NTK method; can be a float or 'auto'.")
args = parser.parse_args()
if args.only_cpu:
    args.gpus = ""
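
# Example launch command (the script name and model paths here are placeholders):
#   python gradio_demo.py --base_model /path/to/llama-2-hf \
#       --lora_model /path/to/lora --gpus 0 --load_in_8bit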

# Limit the visible GPUs before any model loading happens.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

from peft import PeftModel


def setup():
    global tokenizer, model, device, share, port, max_memory
    max_memory = args.max_memory
    port = args.port
    share = args.share
    load_in_8bit = args.load_in_8bit
    load_type = torch.float16
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')
    if args.tokenizer_path is None:
        # Fall back to the LoRA path, then to the base model path.
        args.tokenizer_path = args.lora_model
        if args.lora_model is None:
            args.tokenizer_path = args.base_model
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=False)

    base_model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        load_in_8bit=load_in_8bit,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
        device_map='auto',
    )

    model_vocab_size = base_model.get_input_embeddings().weight.size(0)
    tokenizer_vocab_size = len(tokenizer)
    print(f"Vocab of the base model: {model_vocab_size}")
    print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
    if model_vocab_size != tokenizer_vocab_size:
        assert tokenizer_vocab_size > model_vocab_size
        print("Resize model embeddings to fit tokenizer")
        base_model.resize_token_embeddings(tokenizer_vocab_size)
    if args.lora_model is not None:
        print("loading peft model")
        model = PeftModel.from_pretrained(
            base_model,
            args.lora_model,
            torch_dtype=load_type,
            device_map='auto',
        )
    else:
        model = base_model

    if device == torch.device('cpu'):
        model.float()

    model.eval()


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return []


def generate_prompt(instruction):
    return f"""
Below is an instruction that describes a task. Write a response that appropriately completes the request.
{instruction}
"""
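
# Illustration (not executed): for a single-turn history [["Hello", ""]],
# predict() below assembles roughly the following prompt:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#   ### Instruction:
#   Hello
#
#   ### Response: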


def user(user_message, history):
    return gr.update(value="", interactive=False), history + \
        [[user_message, None]]
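
# Streaming works in two pieces: Stream is a StoppingCriteria that never stops
# generation but invokes a callback with the tokens produced so far at every
# step, and Iteratorize runs model.generate() on a background thread and turns
# those callbacks into a plain Python iterator.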


class Stream(StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).

    Adapted from: https://stackoverflow.com/a/9969000
    """

    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs or {}
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            ret = None
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass
            except Exception:
                traceback.print_exc()

            clear_torch_cache()
            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        clear_torch_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Signal the worker thread to stop at its next callback.
        self.stop_now = True
        clear_torch_cache()
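
# Minimal usage sketch for Iteratorize (illustrative only, not part of the app):
#
#   def produce(callback=None):
#       for i in range(3):
#           callback(i)
#
#   with Iteratorize(produce) as it:
#       for value in it:
#           print(value)   # 0, 1, 2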


def clear_torch_cache():
    gc.collect()
    if torch.cuda.device_count() > 0:
        torch.cuda.empty_cache()


@torch.no_grad()
def predict(
    history,
    max_new_tokens=128,
    top_p=0.75,
    temperature=0.1,
    top_k=40,
    do_sample=True,
    repetition_penalty=1.0
):
    history[-1][1] = ""
    if len(history) != 0:
        # Flatten the whole conversation into instruction/response pairs.
        input = "".join(["### Instruction:\n" +
                         i[0] +
                         "\n\n" +
                         "### Response: " +
                         i[1] +
                         ("\n\n" if i[1] != "" else "") for i in history])
        if len(input) > max_memory:
            input = input[-max_memory:]
    prompt = generate_prompt(input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    generate_params = {
        'input_ids': input_ids,
        'max_new_tokens': max_new_tokens,
        'top_p': top_p,
        'temperature': temperature,
        'top_k': top_k,
        'do_sample': do_sample,
        'repetition_penalty': repetition_penalty,
    }

    def generate_with_callback(callback=None, **kwargs):
        if 'stopping_criteria' in kwargs:
            kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        else:
            kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
        clear_torch_cache()
        with torch.no_grad():
            model.generate(**kwargs, pad_token_id=tokenizer.eos_token_id)

    def generate_with_streaming(**kwargs):
        return Iteratorize(generate_with_callback, kwargs, callback=None)

    with generate_with_streaming(**generate_params) as generator:
        for output in generator:
            next_token_ids = output[len(input_ids[0]):]
            if next_token_ids[0] == tokenizer.eos_token_id:
                break
            new_tokens = tokenizer.decode(
                next_token_ids, skip_special_tokens=True)
            if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
                # SentencePiece marks a leading space with '▁'; restore it after decoding.
                if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
                    new_tokens = ' ' + new_tokens

            history[-1][1] = new_tokens
            yield history
            if len(next_token_ids) >= max_new_tokens:
                break


setup()

with gr.Blocks() as demo:
    gr.Markdown("<center><div style='margin-top: 21px; margin-bottom: 8px; font-size: 44px; font-weight: bold;'>Llama2</div></center>")
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=5):
            with gr.Column(scale=12):
                user_input = gr.Textbox(
                    show_label=False,
                    placeholder="Shift + Enter to send a message...",
                    lines=10,
                    container=False)
            with gr.Column(min_width=32, scale=1):
                with gr.Row():
                    with gr.Column(scale=5):
                        submitBtn = gr.Button("Submit", variant="primary")
                    with gr.Column(scale=1):
                        emptyBtn = gr.Button("Clear History")

    max_new_token = gr.Slider(
        0,
        4096,
        value=2048,
        step=1.0,
        label="Maximum New Token Length",
        interactive=True,
        visible=False)
    top_p = gr.Slider(0, 1, value=0.9, step=0.01,
                      label="Top P", interactive=True, visible=False)
    temperature = gr.Slider(
        0,
        1,
        value=0.5,
        step=0.01,
        label="Temperature",
        interactive=True,
        visible=False)
    top_k = gr.Slider(1, 40, value=40, step=1,
                      label="Top K", interactive=True, visible=False)
    do_sample = gr.Checkbox(
        value=True,
        label="Do Sample",
        info="use random sample strategy",
        interactive=True,
        visible=False)
    repetition_penalty = gr.Slider(
        1.0,
        3.0,
        value=1.1,
        step=0.1,
        label="Repetition Penalty",
        interactive=True,
        visible=False)

    params = [user_input, chatbot]
    predict_params = [
        chatbot,
        max_new_token,
        top_p,
        temperature,
        top_k,
        do_sample,
        repetition_penalty
    ]
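
    # Event chain for both the button and Shift+Enter submit: user() records
    # the message and disables the textbox, predict() streams the reply into
    # the chatbot, and the final lambda re-enables the textbox.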

    submitBtn.click(
        user,
        params,
        params,
        queue=False).then(
        predict,
        predict_params,
        chatbot).then(
        lambda: gr.update(interactive=True),
        None,
        [user_input],
        queue=False)

    user_input.submit(
        user,
        params,
        params,
        queue=False).then(
        predict,
        predict_params,
        chatbot).then(
        lambda: gr.update(interactive=True),
        None,
        [user_input],
        queue=False)

    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)


demo.queue().launch(
    share=share,
    inbrowser=True,
    server_name='0.0.0.0',
    server_port=port)