ChatBotLI2Klein / app.py
alexkueck's picture
Update app.py
f9856aa
raw
history blame
6.03 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gc

import gradio as gr
#from transformers import pipeline
import torch

from utils import *
from presets import *
# Model identifier handed to load_tokenizer_and_model. Another checkpoint can
# be configured here instead (an earlier choice was "project-baize/baize-v2-7b").
base_model = "EleutherAI/gpt-neo-1.3B"

# Loaded once at import time; predict() reads these module-level globals.
tokenizer, model, device = load_tokenizer_and_model(
    base_model,
)
def predict(text,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,):
    """Stream the model's reply to ``text``; used as a Gradio generator handler.

    Every yielded item is ``(chatbot_display, history, status)``, matching the
    outputs ``[chatbotGr, history, status_display]`` this handler is wired to.

    Args:
        text: The new user message. An empty string yields "Empty context.".
        chatbotGr: Current chatbot display value; yielded back unchanged on
            early exits and when no displayable reply was produced.
        history: Conversation so far as a list of ``[user, bot]`` pairs.
        top_p: Forwarded unchanged to ``greedy_search``.
        temperature: Forwarded unchanged to ``greedy_search``.
        max_length_tokens: Passed to ``greedy_search`` as ``max_length``.
        max_context_length_tokens: Prompt token budget; passed to
            ``generate_prompt_with_history`` and used to keep only that many
            trailing prompt tokens.

    Yields:
        Partial replies with status "Generating...", then one final item with
        "Generate: Success", or "Stop: Success" if ``shared_state`` was
        interrupted. Early exits yield "Empty context.", "No Model Found" or
        "Input too long.".
    """
    if text == "":
        yield chatbotGr, history, "Empty context."
        return
    # `model` is a module-level global created at import time; report its
    # absence instead of crashing the handler.
    try:
        model
    except NameError:
        yield [[text, "No Model Found"]], [], "No Model Found"
        return

    inputs = generate_prompt_with_history(
        text, history, tokenizer, max_length=max_context_length_tokens
    )
    if inputs is None:
        yield chatbotGr, history, "Input too long."
        return
    _prompt, inputs = inputs

    stop_words = ["[|Human|]", "[|AI|]"]
    # Keep only the trailing context window; int() guards against a slider
    # delivering a float, which cannot be used as a slice bound.
    input_ids = inputs["input_ids"][:, -int(max_context_length_tokens):].to(device)
    torch.cuda.empty_cache()

    # Latest (display, history) pair; stays None until a displayable chunk
    # has been produced.
    latest = None
    try:
        with torch.no_grad():
            for x in greedy_search(input_ids, model, tokenizer,
                                   stop_words=stop_words,
                                   max_length=max_length_tokens,
                                   temperature=temperature, top_p=top_p):
                # Chunks flagged by is_stop_word_or_prefix (presumably ones
                # ending in a role marker or its prefix) are not displayed.
                if not is_stop_word_or_prefix(x, stop_words):
                    # Cut the reply at the first role marker, if any.
                    for marker in stop_words:
                        if marker in x:
                            x = x[:x.index(marker)].strip()
                    x = x.strip()
                    latest = (
                        [[y[0], convert_to_markdown(y[1])] for y in history]
                        + [[text, convert_to_markdown(x)]],
                        history + [[text, x]],
                    )
                    yield latest[0], latest[1], "Generating..."
                if shared_state.interrupted:
                    shared_state.recover()
                    # Always stop here, even if nothing was shown yet; the old
                    # code kept generating when no output existed.
                    display, new_history = (
                        latest if latest is not None else (chatbotGr, history)
                    )
                    yield display, new_history, "Stop: Success"
                    return
    finally:
        # Runs on normal completion, on interrupt, and if the generator is
        # closed early, so the cached tensors are always released.
        del input_ids
        gc.collect()
        torch.cuda.empty_cache()

    display, new_history = latest if latest is not None else (chatbotGr, history)
    yield display, new_history, "Generate: Success"
def reset_chat():
    """Reset the textbox and status line when the user starts a new chat.

    Wired to the "new chat" button with outputs ``[user_input, status_display]``,
    so it must return one value per output. It returns whatever the
    star-imported ``reset_textbox()`` returns. The chat display and history are
    reset by the separate ``reset_state`` handler on the same button.

    The hugchat ``chatbot`` client that this function used to reset is never
    created in this file (its construction is commented out), so those calls
    only raised ``NameError`` and have been dropped.
    """
    # NOTE(review): assumes reset_textbox() returns a 2-tuple
    # (textbox update, status text) — confirm against utils.
    return reset_textbox()
with gr.Blocks(theme=small_and_beautiful_theme) as demo:
    # Per-session state: the [user, bot] history fed to predict(), and the
    # question copied out of the textbox by transfer_input before generation.
    history = gr.State([])
    user_question = gr.State("")

    # ---- Layout ----------------------------------------------------------
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Erfolg", elem_id="status_display")
    gr.Markdown(description_top)
    with gr.Row(scale=1).style(equal_height=True):
        # Left side: transcript, input row and "new chat" button.
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbotGr = gr.Chatbot(elem_id="LI_chatbot").style(height="100%")
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Gib deinen Text / Frage ein."
                    ).style(container=False)
                with gr.Column(min_width=90, scale=1):
                    submitBtn = gr.Button("Absenden")
                with gr.Column(min_width=90, scale=1):
                    cancelBtn = gr.Button("Stoppen")
            with gr.Row(scale=1):
                emptyBtn = gr.Button(
                    "🧹 Neuer Chat",
                )
        # Right side: generation parameters passed straight into predict().
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="Parameter zum Model"):
                    gr.Markdown("# Parameters")
                    top_p = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=512,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=2048,
                        step=128,
                        interactive=True,
                        label="Max History Tokens",
                    )
    gr.Markdown(description)

    # ---- Event wiring ----------------------------------------------------
    predict_args = dict(
        fn=predict,
        inputs=[
            user_question,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )
    # New chat: reset the textbox and status line.
    reset_args = dict(
        fn=reset_chat, inputs=[], outputs=[user_input, status_display]
    )
    # Copy the textbox into user_question before predict() runs.
    transfer_input_args = dict(
        fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
    )
    # Start generation on Enter in the textbox or on the submit button.
    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
    predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)
    # Stop button: cancel any in-flight generation started by either trigger.
    # queue=False so the status update is not stuck behind the single queue
    # worker that is running predict().
    cancelBtn.click(
        fn=lambda: "Stop: Success",
        inputs=None,
        outputs=[status_display],
        cancels=[predict_event1, predict_event2],
        queue=False,
    )
    # New chat: reset_state (star-imported) writes the chat display, history
    # and status; reset_args handles the textbox and status line.
    emptyBtn.click(
        reset_state,
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)

demo.title = "LI Chat"
# concurrency_count=1: the queue processes one event at a time.
# Use .launch(share=True) instead for a public share link.
demo.queue(concurrency_count=1).launch(debug=True)