# ChatWithADoc / app.py
import os
import gradio as gr
import transformers
from torch import bfloat16
from threading import Thread
from gradio.themes.utils.colors import Color
# Download model and tokenizer files
os.system('bash download_model.sh')
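# Local path populated by download_model.sh above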
model_id = "/app/medllama2_7b"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model_config = transformers.AutoConfig.from_pretrained(model_id)
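# Quantize to 4-bit NF4 with double quantization so the 7B model fits in limited GPU memory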
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)
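# Load the causal LM with the quantization config; device_map='auto' places layers on the available GPU(s)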
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto'
)
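# System prompts offered in the dropdown (users may also type their own)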
prompts = ["You are a helpful AI Doctor."]
def prompt_build(system_prompt, user_inp, hist):
    """Assemble the full prompt from the system prompt, prior chat turns, and the new user input."""
    prompt = f"""### System:\n{system_prompt}\n\n"""
    for pair in hist:
        prompt += f"""### User:\n{pair[0]}\n\n### Assistant:\n{pair[1]}\n\n"""
    prompt += f"""### User:\n{user_inp}\n\n### Assistant:"""
    return prompt
def chat(user_input, history, system_prompt):
    """Stream a model response for the Gradio ChatInterface."""
    prompt = prompt_build(system_prompt, user_input, history)
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    # Stream tokens back as they are generated instead of waiting for the full reply
    streamer = transformers.TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_length=2048,
        do_sample=True,
        top_p=0.95,
        temperature=0.8,
        top_k=50
    )
    # Run generation in a background thread so the streamer can be consumed here
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output
if __name__ == "__main__":
    with gr.Blocks() as demo:
        dropdown = gr.Dropdown(choices=prompts, label="Type your own or select a system prompt", value="You are a helpful AI Doctor.", allow_custom_value=True)
        chatbot = gr.ChatInterface(fn=chat, additional_inputs=[dropdown])
    demo.queue(api_open=False).launch(show_api=False, share=True)