# dolphin / app.py
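# Gradio chat Space that streams completions from
# cognitivecomputations/dolphin-2.8-mistral-7b-v02 (4-bit) on ZeroGPU hardware.
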
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from threading import Thread


@spaces.GPU(duration=120)  # request a ZeroGPU slot for up to 120 s per call
def predict(message, history):
    torch.set_default_device("cuda")
    tokenizer = AutoTokenizer.from_pretrained(
        "cognitivecomputations/dolphin-2.8-mistral-7b-v02",
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "cognitivecomputations/dolphin-2.8-mistral-7b-v02",
        torch_dtype="auto",
        # load_in_4bit as a bare kwarg is deprecated; pass a quantization config instead
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        trust_remote_code=True,
    )
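    # Note: reloading the tokenizer and model on every request keeps the GPU
    # worker stateless but is slow; both could be hoisted to module scope and
    # cached if per-call latency becomes a problem.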
    history_transformer_format = history + [[message, ""]]
    system_prompt = "<|im_start|>system\nYou are Dolphin, a helpful AI assistant.<|im_end|>"
    # Build a ChatML prompt: completed assistant turns are closed with <|im_end|>;
    # the final empty assistant turn is left open so the model continues from it.
    messages = system_prompt + "".join(
        f"\n<|im_start|>user\n{user_msg}<|im_end|>\n<|im_start|>assistant\n{assistant_msg}"
        + ("<|im_end|>" if assistant_msg else "")
        for user_msg, assistant_msg in history_transformer_format
    )
    # The tokenizer returns a dict holding input_ids and attention_mask.
    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    # Skip the prompt and special tokens so only newly generated text is streamed.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **model_inputs,  # input_ids and attention_mask
        streamer=streamer,
        max_new_tokens=10000,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
    )
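    # generate() blocks, so it runs on a background thread while the streamer
    # yields decoded text to the UI as it arrives.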
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        # Safety net: the streamer strips special tokens, but stop anyway if the
        # ChatML end marker leaks through as plain text.
        if "<|im_end|>" in partial_message:
            break
        yield partial_message
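
# ChatInterface passes (message, history) into predict and renders the
# streamed partial messages as the assistant reply.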
gr.ChatInterface(predict).launch()