# thatGPT / app.py
# Author: gojiteji
# Revision: b614f80 ("Update app.py")
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
# Hugging Face checkpoint used for generation.
modelname = "gpt2-large"

# The tokenizer and the config are independent downloads; the model is then
# instantiated from that config. `tokenizer` and `model` are module-level
# globals read by botsay().
tokenizer = AutoTokenizer.from_pretrained(modelname)
config = AutoConfig.from_pretrained(modelname)
model = AutoModelForCausalLM.from_pretrained(modelname, config=config)
def _that_candidate(previous_token, generated):
    """Return the token id to force next, or -1 when nothing should be forced.

    The ids 326 and 5562 come from the original code. Judging by the branches,
    5562 is used right after a token ending in a space and 326 otherwise --
    presumably 5562 is "that" and 326 is " that" in the GPT-2 vocabulary
    (TODO confirm).
    """
    # Don't force anything while "thatGPT" is still in the recent tail, or
    # right after a " that" token was emitted.
    if "thatGPT" in generated[-12:] or previous_token == " that":
        return -1
    if previous_token == "that":
        return 326
    if previous_token.endswith(" "):
        return 5562
    return 326


def botsay(user_input, *, limit=128):
    """Generate ThatGPT's reply for a serialized conversation.

    Decoding is greedy with GPT-2, steered so the reply contains "that":
    - if the "that" token chosen by _that_candidate is among the top-k
      logits, it is forced instead of the argmax;
    - after step 10, if the last few characters of the reply already occur
      more than twice in it, the last token is swapped for a random one of
      the top-5 candidates;
    - on an end-of-text / newline token, generation stops if the reply
      contains "that"; otherwise newlines and periods are stripped, " that"
      is appended and generation continues.

    Args:
        user_input: conversation text placed after the system prompt (bot()
            passes "Human:"/"AI:" lines).
        limit: maximum number of tokens in the full model input (system
            prompt + conversation + reply so far). The whole conversation
            history counts toward it.

    Returns:
        The reply with "<br>", "AI:" and non-breaking spaces removed, or a
        warning string once the input exceeds ``limit`` tokens.
    """
    prompt = "This is a conversation between Human and AI bot. AI's name is ThatGPT."
    top_k = 7  # candidates checked for a forced "that" token
    disable_repeat_length = 5
    disable_repeat_count = 2
    # Presumably GPT-2's <|endoftext|> and "\n" ids (TODO confirm); the decoded
    # string check below also catches a literal end-of-text token.
    end_ids = (50256, 198)

    gen_tokens = ""
    new_token = ""
    step = 0
    while True:
        step += 1
        input_ids = tokenizer(
            prompt + user_input + "\nAI:" + gen_tokens, return_tensors="pt"
        ).input_ids
        # input_ids is (batch, seq_len): len() would always be 1, so the limit
        # never fired. Count tokens along the sequence axis instead.
        if input_ids.shape[-1] > limit:
            return "⚠️sorry length limit. please reload the browser."

        # Inference only: skip building an autograd graph on every step.
        with torch.no_grad():
            logits = model(input_ids=input_ids).logits[0, -1, :]
        topk = torch.topk(logits, k=top_k).indices

        that_id = _that_candidate(new_token, gen_tokens)
        if that_id in topk.tolist():
            new_token_id = that_id
        else:
            new_token_id = int(torch.argmax(logits))
        new_token = tokenizer.decode(new_token_id)

        prev_tokens = gen_tokens
        gen_tokens += new_token
        tail = gen_tokens[-disable_repeat_length:]
        if step > 10 and gen_tokens.count(tail) > disable_repeat_count:
            # Replace the repeating token with a random top-5 candidate, and
            # keep new_token_id in sync so the end check below tests the
            # token actually kept.
            new_token_id = int(topk[torch.randint(5, (1,)).item()])
            new_token = tokenizer.decode(new_token_id)
            gen_tokens = prev_tokens + new_token

        if new_token_id in end_ids or new_token == "<|endoftext|>":
            if "that" in gen_tokens:
                break
            gen_tokens = gen_tokens.replace("\n", "").replace(".", "") + " that"

    return gen_tokens.replace("<br>", "").replace("AI:", "").replace("\xa0", "")
import gradio as gr
def add_text(history, text):
    """Append the user's message as a new, not-yet-answered chat row.

    Returns a new history list (the input list is not mutated) together with
    an empty string, which clears the textbox.
    """
    pending_turn = (text, None)
    cleared_textbox = ""
    return history + [pending_turn], cleared_textbox
def bot(history):
    """Fill in the AI reply for the newest chat row.

    history: list of (human_message, ai_message) rows from the Chatbot; the
    row being answered has ai_message None. Rows are serialized as
    "Human:" / "AI:" lines up to and including the first unanswered row,
    passed to botsay(), and the reply is stored in the last row.

    Returns the updated history.
    """
    parts = []
    for row in history:
        human, ai = row[0], row[1]
        parts.append("\nHuman:" + human)
        if ai is None:
            break
        # Strip <br> tags from earlier replies (presumably introduced by the
        # Chatbot's rendering) before feeding them back to the model.
        parts.append("\nAI:" + ai.replace("<br>", ""))
    response = botsay("".join(parts))
    # Rows may be tuples (add_text builds (text, None)), which don't support
    # item assignment, so rebuild the last row instead of mutating it.
    history[-1] = [history[-1][0], response]
    return history
# Gradio UI: a chat window plus a textbox. Submitting the textbox first
# appends the message to the chat (add_text), then fills in the reply (bot).
# NOTE(review): `.style(...)` and a float Column `scale` are gradio 3.x APIs.
with gr.Blocks() as demo:
    gr.Markdown("# ThatGPT - AI always replies with \"that\" -")
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
    with gr.Row():
        with gr.Column(scale=0.85):
            txt = gr.Textbox(
                show_label=False,
                placeholder="AI always replies with \"that\". It may take more than ten seconds.",
            ).style(container=False)
    # Chained events: add_text updates the chat and clears the textbox, then
    # bot generates the reply for the newest row.
    pending = txt.submit(add_text, [chatbot, txt], [chatbot, txt])
    pending.then(bot, chatbot, chatbot)

demo.launch()