# %%
import os, json, itertools, bisect, gc
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import transformers
import torch
from accelerate import Accelerator
import accelerate
import time
import gradio as gr
import requests
import random
import googletrans

translator = googletrans.Translator()

model = None
tokenizer = None
generator = None

# Hide all GPUs from PyTorch so the app runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
def load_model(model_name, eight_bit=0, device_map="auto"):
    # NOTE: eight_bit and device_map are accepted for compatibility but not used in this CPU setup.
    global model, tokenizer, generator

    print("Loading " + model_name + "...")

    if device_map == "zero":
        device_map = "balanced_low_0"

    # device / dtype configuration
    gpu_count = torch.cuda.device_count()
    print('gpu_count', gpu_count)

    # Half precision on GPU, full precision on CPU.
    if torch.cuda.is_available():
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    print(model_name)
    tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_name,
        #device_map=device_map,
        #max_memory = {0: "14GB", 1: "14GB", 2: "14GB", 3: "14GB", 4: "14GB", 5: "14GB", 6: "14GB", 7: "14GB"},
        #load_in_8bit=eight_bit,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        load_in_8bit=False,
        cache_dir="cache",
    )

    if torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.cpu()

    generator = model.generate
# chat doctor
def chatdoctor(input, state):
    """Build the conversation prompt from the history in `state` and generate a reply."""
    print('state', state)

    invitation = "ChatDoctor: "
    human_invitation = "Patient: "

    fulltext = "If you are a doctor, please answer the medical questions based on the patient's description. \n\n"

    # `state` holds the English-language history as alternating [patient, doctor, patient, doctor, ...] turns.
    for i, message in enumerate(state):
        if i % 2 == 0:
            fulltext += human_invitation + message + "\n\n"
        else:
            fulltext += invitation + message + "\n\n"

    fulltext += human_invitation + input + "\n\n"
    fulltext += invitation
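    # For illustration, with one prior exchange in `state` the assembled prompt looks
    # roughly like this (the patient/doctor sentences below are hypothetical examples,
    # not actual model output):
    #
    #   If you are a doctor, please answer the medical questions based on the patient's description.
    #
    #   Patient: I have had a dull headache for three days.
    #
    #   ChatDoctor: How would you rate the pain, and do you have any other symptoms?
    #
    #   Patient: <current input>
    #
    #   ChatDoctor: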
    print('fulltext: ', fulltext)

    generated_text = ""
    gen_in = tokenizer(fulltext, return_tensors="pt").input_ids
    if torch.cuda.is_available():
        gen_in = gen_in.cuda()
    else:
        gen_in = gen_in.cpu()

    # input_ids has shape (batch, sequence_length); count prompt tokens from dim 1.
    in_tokens = gen_in.shape[1]
    print('len token', in_tokens)

    with torch.no_grad():
        generated_ids = generator(
            gen_in,
            max_new_tokens=200,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,  # 1.0 means 'off'; too strong a penalty suppresses the "ChatDoctor:" prefix
            temperature=0.5,         # default: 1.0
            top_k=50,                # default: 50
            top_p=1.0,               # default: 1.0
            early_stopping=True,
        )
        generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]  # batch_decode returns a list with one element

    # Drop the prompt and keep only the doctor's turn: cut at the next "Patient: " the model may have generated.
    response = generated_text[len(fulltext):]
    response = response.split(human_invitation)[0]
    response = response.strip()

    print(invitation + response)
    print("")

    return response
def predict(input, chatbot, state):
    print('predict state: ', state)

    # If the input is detected as Korean, translate it to English; otherwise pass it through.
    is_kor = True
    if translator.detect(input).lang == 'ko':
        en_input = translator.translate(input, src='ko', dest='en').text
    else:
        en_input = input
        is_kor = False

    response = chatdoctor(en_input, state)

    # Translate the answer back to Korean only if the question was Korean.
    if is_kor:
        ko_response = translator.translate(response, src='en', dest='ko').text
    else:
        ko_response = response

    # Store both turns (in English) so chatdoctor() can rebuild the prompt next time.
    state.append(en_input)
    state.append(response)
    chatbot.append((input, ko_response))

    return chatbot, state
load_model("mnc-ai/chatdoctor") | |
with gr.Blocks() as demo: | |
gr.Markdown("""<h1><center>์ฑ ๋ฅํฐ์ ๋๋ค. ์ด๋๊ฐ ๋ถํธํ์ ๊ฐ์?</center></h1> | |
""") | |
chatbot = gr.Chatbot() | |
state = gr.State([]) | |
with gr.Row(): | |
txt = gr.Textbox(show_label=False, placeholder="์ฌ๊ธฐ์ ์ง๋ฌธ์ ์ฐ๊ณ ์ํฐ").style(container=False) | |
clear = gr.Button("์๋ด ์๋ก ์์") | |
txt.submit(predict, inputs=[txt, chatbot, state], outputs=[chatbot, state], queue=False ) | |
txt.submit(lambda x: "", txt, txt) | |
clear.click(lambda: None, None, chatbot, queue=False) | |
clear.click(lambda x: "", txt, txt) | |
demo.launch() | |
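# Assumed runtime dependencies (the Space's original pinned requirements are not recorded here):
# torch, a Llama-capable transformers release (>= 4.28), gradio, sentencepiece,
# and googletrans (e.g. the 4.0.0rc1 pre-release for the Translator API used above).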