import argparse
from collections import namedtuple

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoModel, AutoTokenizer,
    BloomForCausalLM, BloomTokenizerFast,
    GenerationConfig,
    LlamaForCausalLM, LlamaTokenizer,
)
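# Example invocation (sketch; the script filename "web.py" and the checkpoint
# paths are placeholders, not guaranteed to match the original repo layout):
#   python web.py --model_type llama \
#       --model_name_or_path decapoda-research/llama-7b-hf \
#       --lora_name_or_path ./saved-alpaca-belle-cot7b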
tokenizer = None
model = None
LOAD_8BIT = False
ModelClass = namedtuple("ModelClass", ('tokenizer', 'model'))
_MODEL_CLASSES = {
    "llama": ModelClass(**{
        "tokenizer": LlamaTokenizer,
        "model": LlamaForCausalLM,
    }),
    "bloom": ModelClass(**{
        "tokenizer": AutoTokenizer,
        "model": BloomForCausalLM,
    }),
}
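# Note: the --model_type flag below also accepts "chatglm", but _MODEL_CLASSES
# has no matching entry, so that choice would raise a KeyError. A hypothetical
# entry (untested sketch) could reuse the Auto classes imported above:
#   "chatglm": ModelClass(tokenizer=AutoTokenizer, model=AutoModel),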
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def get_model_class(model_type,
                    model_name_or_path,
                    lora_model_path):
    """Load the base model and its LoRA adapter into the module-level globals."""
    global model, tokenizer
    model_class = _MODEL_CLASSES[model_type]  # (tokenizer, model) pair
    model_base = model_class.model.from_pretrained(model_name_or_path,
                                                   load_in_8bit=LOAD_8BIT,
                                                   torch_dtype=torch.float16,
                                                   device_map="auto",
                                                   )
    tokenizer = model_class.tokenizer.from_pretrained(model_name_or_path)  # default add_eos_token=False
    # Wrap the base model with the LoRA weights found at lora_model_path.
    model = PeftModel.from_pretrained(
        model_base,
        lora_model_path,
        torch_dtype=torch.float16,
    )
    if not LOAD_8BIT:
        model.half()  # cast to fp16 when not loading in 8-bit
def predict(
    instruction,
    top_p=0.75,
    temperature=0.1,
    history=None,
    top_k=40,
    num_beams=4,
    max_new_tokens=512,
    **kwargs,
):
    history = history or []
    # Standard Alpaca-style prompt template.
    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{0}\n\n### Response:"
    ).format(instruction)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    print('Model output:', output)
    # Keep only the text after the "### Response:" marker as the bot reply.
    bot_response = output.split("### Response:")[1].strip()
    history.append((instruction, bot_response))
    return "", history, history
def predict_test(message, top_p, temperature, history):
    # Lightweight stub for testing the UI wiring without loading a model.
    history = history or []
    user_message = f"{message} {top_p}, {temperature}"
    print(user_message)
    history.append((message, user_message))
    return history, history
def clear_session():
    # Reset the textbox, the chatbot display and the stored state.
    return '', '', None
parser = argparse.ArgumentParser(description='Gradio web demo for an Alpaca-CoT LoRA model.')
parser.add_argument('--size', default=7, type=int, help='the size of llama model')
parser.add_argument('--data', default="", type=str, help='the data used for instructing tuning')
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument('--model_type', default="llama", choices=['llama', 'chatglm', 'bloom'])
parser.add_argument('--model_name_or_path', default="decapoda-research/llama-7b-hf", type=str)
parser.add_argument('--lora_name_or_path', default="./saved-alpaca-belle-cot7b", type=str)
args = parser.parse_args()
get_model_class(args.model_type, args.model_name_or_path, args.lora_name_or_path)
block = gr.Blocks(css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""")
with block as demo:
    # Sampling parameters exposed in the UI: top_p, temperature.
    with gr.Accordion("Parameters", open=False):
        top_p = gr.Slider(minimum=0, maximum=1.0, value=0.75, step=0.05, interactive=True, label="Top-p (nucleus sampling)")
        temperature = gr.Slider(minimum=0, maximum=5.0, value=0.1, step=0.1, interactive=True, label="Temperature")
    chatbot = gr.Chatbot(label="Alpaca-CoT")
    message = gr.Textbox()
    state = gr.State()
    message.submit(predict, inputs=[message, top_p, temperature, state], outputs=[message, chatbot, state], queue=False)
    with gr.Row():
        clear_history = gr.Button("Clear History")
        clear = gr.Button('Clear Input')
        send = gr.Button("Send")
        regenerate = gr.Button("Regenerate")
    # regenerate.click(regenerate, inputs=[message], outputs=[chatbot])
    regenerate.click(fn=clear_session, inputs=[], outputs=[message, chatbot, state], queue=False)
    send.click(predict, inputs=[message, top_p, temperature, state], outputs=[message, chatbot, state])
    clear.click(lambda: None, None, message, queue=False)
    clear_history.click(fn=clear_session, inputs=[], outputs=[message, chatbot, state], queue=False)
demo.queue(max_size=20, concurrency_count=20).launch()