galileo / app.py
hugo1234's picture
Update app.py
664d7c4
raw
history blame
8.19 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gradio as gr
#from transformers import pipeline
import torch
from utils import *
from presets import *
#antwort=""
######################################################################
# Models and tokenizer
# Use Hugging Chat
# Create a chatbot connection
#chatbot = hugchat.ChatBot(cookie_path="cookies.json")
# Alternatively, use any other model:
#base_model = "project-baize/baize-v2-7b"
# Identifier of the checkpoint handed to the loader below.
base_model = "EleutherAI/gpt-neo-1.3B"
# Runs once at import time; the three results are used as module-level globals by predict().
# NOTE(review): load_tokenizer_and_model is not defined in this file - presumably it comes
# from the `utils` star import; confirm.
tokenizer,model,device = load_tokenizer_and_model(base_model)
########################################################################
# Use the chat model to generate text...
def predict(text,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,):
    """Stream a chat reply for ``text`` as a generator of UI updates.

    Every yielded item is a 3-tuple ``(chat display, history, status text)``.

    Args:
        text: The new user message. An empty string ends the call at once.
        chatbotGr: Current chat display value; yielded back unchanged
            whenever no new reply text is available.
        history: Sequence of previous ``[user, reply]`` pairs. It is not
            mutated; an extended copy is yielded instead.
        top_p: Forwarded to ``greedy_search``.
        temperature: Forwarded to ``greedy_search``.
        max_length_tokens: Forwarded to ``greedy_search`` as ``max_length``.
        max_context_length_tokens: Passed to
            ``generate_prompt_with_history`` as ``max_length`` and used to
            keep only that many trailing input ids.

    Yields:
        ``(chat display, history, status)`` after each streamed partial
        reply, then once more with a final status (``"Stop: OK"`` when the
        shared interrupt flag was set, ``"Generazione: OK"`` otherwise).
    """
    # Local import: this file only gets ``gc`` implicitly (if at all) through
    # a star import, so bring it into scope explicitly.
    import gc

    if text == "":
        yield chatbotGr, history, "Testo vuoto."
        return
    try:
        model
    except NameError:
        # Only a missing global means "no model"; anything else should surface.
        yield [[text, "Nessun modello trovato"]], [], "Nessun modello trovato"
        return

    inputs = generate_prompt_with_history(text, history, tokenizer, max_length=max_context_length_tokens)
    if inputs is None:
        yield chatbotGr, history, "Input troppo lungo."
        return
    _prompt, inputs = inputs
    input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
    torch.cuda.empty_cache()

    # The past prompts are all stored as pairs in history, ordered by 'Human'
    # and 'AI' - these are therefore also the stop words that mark the
    # respective next entry.
    stop_words = ["[|Human|]", "[|AI|]"]

    # Start from the incoming state so the final / interrupt yields are always
    # well defined, even when no partial reply was produced.
    a, b = chatbotGr, history
    try:
        # torch.no_grad() means that no gradients are computed for the tensors
        # involved. The network must not be changed here (no backprop needed),
        # since this is pure inference.
        with torch.no_grad():
            for x in greedy_search(input_ids, model, tokenizer, stop_words=stop_words, max_length=max_length_tokens, temperature=temperature, top_p=top_p):
                if is_stop_word_or_prefix(x, stop_words) is False:
                    # Cut the partial reply at the first role marker, if any.
                    for marker in stop_words:
                        if marker in x:
                            x = x[:x.index(marker)].strip()
                    x = x.strip()
                    a = [[y[0], convert_to_markdown(y[1])] for y in history] + [[text, convert_to_markdown(x)]]
                    b = history + [[text, x]]
                    yield a, b, "Sto elaborando ..."
                if shared_state.interrupted:
                    shared_state.recover()
                    yield a, b, "Stop: OK"
                    return
    finally:
        # Runs on normal completion, on interrupt and when the consumer closes
        # the generator, so the cleanup is never skipped.
        del input_ids
        gc.collect()
        torch.cuda.empty_cache()
    yield a, b, "Generazione: OK"
def reset_chat():
    """Reset the chat input by delegating to ``reset_textbox``.

    Returns:
        Whatever ``reset_textbox()`` returns, so this function can be used
        directly as an event handler in place of ``reset_textbox``. The
        previous version discarded that result and returned ``None``.
    """
    # Leftover from the HugChat-based variant, kept for reference:
    # id_new = chatbot.new_conversation()
    # chatbot.change_conversation(id_new)
    return reset_textbox()
##########################################################
def translate():
    """Placeholder for the translation feature; returns a fixed notice."""
    notice = "In costruzione"
    return notice
def coding():
    """Placeholder for the code-generation feature; returns a fixed notice."""
    notice = "In costruzione"
    return notice
#######################################################################
# Presentation with Gradio
# NOTE(review): the nesting below was reconstructed from the statement order
# (the reviewed copy had lost its indentation) - confirm the intended layout.
with open("custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()

# The stylesheet read above is now actually handed to the app; before, it was
# loaded but never used.
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    # Per-session state: the conversation so far and the question being answered.
    history = gr.State([])
    user_question = gr.State("")
    gr.Markdown("Scegli cosa vuoi provare:")
    with gr.Tabs():
        with gr.TabItem("Chat"):
            with gr.Row():
                gr.HTML(title)
                status_display = gr.Markdown("OK", elem_id="status_display")
            gr.Markdown(description_top)
            with gr.Row(scale=1).style(equal_height=True):
                # Left column: chat window, input box and buttons.
                with gr.Column(scale=5):
                    with gr.Row(scale=1):
                        chatbotGr = gr.Chatbot(elem_id="Chat").style(height="100%")
                    with gr.Row(scale=1):
                        with gr.Column(scale=12):
                            user_input = gr.Textbox(
                                show_label=False, placeholder="Inserisci il tuo testo / domanda"
                            ).style(container=False)
                        with gr.Column(min_width=100, scale=1):
                            submitBtn = gr.Button("Invia")
                        with gr.Column(min_width=100, scale=1):
                            # NOTE(review): this button is never connected to a
                            # handler in this file, so clicking it does nothing.
                            cancelBtn = gr.Button("Cancella")
                    with gr.Row(scale=1):
                        emptyBtn = gr.Button(
                            "🧹 Nuova Chat",
                        )
                # Right column: generation parameters.
                with gr.Column():
                    with gr.Column(min_width=50, scale=1):
                        with gr.Tab(label="Parametri del modello"):
                            gr.Markdown("# Parametri")
                            top_p = gr.Slider(
                                minimum=0,
                                maximum=1.0,
                                value=0.95,
                                step=0.05,
                                interactive=True,
                                label="Top-p",
                            )
                            temperature = gr.Slider(
                                minimum=0.1,
                                maximum=2.0,
                                value=1,
                                step=0.1,
                                interactive=True,
                                label="Temperatura",
                            )
                            max_length_tokens = gr.Slider(
                                minimum=0,
                                maximum=512,
                                value=512,
                                step=8,
                                interactive=True,
                                label="Numero massimo di parole",
                            )
                            max_context_length_tokens = gr.Slider(
                                minimum=0,
                                maximum=4096,
                                value=2048,
                                step=128,
                                interactive=True,
                                label="Numero massimo di parole memorizzate",
                            )
            gr.Markdown(description)
        with gr.TabItem("Traduzioni"):
            with gr.Row():
                gr.Textbox(
                    show_label=False, placeholder="In costruzione ..."
                ).style(container=False)
        with gr.TabItem("Generazione di codice"):
            with gr.Row():
                gr.Textbox(
                    show_label=False, placeholder="In costruzione ..."
                ).style(container=False)

    # Arguments for the generation step: predict() streams into the chat
    # window, the history state and the status line.
    predict_args = dict(
        fn=predict,
        inputs=[
            user_question,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )
    # New chat
    reset_args = dict(
        #fn=reset_chat, inputs=[], outputs=[user_input, status_display]
        fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
    )
    # Chatbot: first move the typed text into the user_question state,
    # then run the prediction on it.
    transfer_input_args = dict(
        fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
    )
    # Listeners: start on button click or on Return in the textbox
    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
    predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)
    # Listeners for reset: clear the chat state and the input box
    emptyBtn.click(
        reset_state,
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)
# Page title of the app.
demo.title = "Chat"
# Alternative launch with a public share link:
#demo.queue(concurrency_count=1).launch(share=True)
# Enable the request queue (concurrency_count=1) and start the server in debug mode.
demo.queue(concurrency_count=1).launch(debug=True)