#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gc  # needed for gc.collect() in predict()
import gradio as gr
#from transformers import pipeline
import torch
from utils import *
from presets import *
from huggingface_hub import login
from transformers import LlamaForCausalLM, LlamaTokenizer

#antwort=""
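# utils and presets (imported with *) are expected to provide the helpers used below, e.g.
# load_tokenizer_and_model, generate_prompt_with_history, greedy_search, is_stop_word_or_prefix,
# convert_to_markdown, transfer_input, reset_textbox, reset_state and shared_state, as well as
# the title/description texts and small_and_beautiful_theme for the UI.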
######################################################################
# Models and tokenizers
# Option: use Hugging Chat
# Create a chatbot connection
#chatbot = hugchat.ChatBot(cookie_path="cookies.json")
# Alternatively, with any of these models:
#base_model = "project-baize/baize-v2-7b"  #load_8bit = False (in load_tokenizer_and_model)
base_model = "meta-llama/Llama-2-13b"
#base_model = "MAGAer13/mPLUG-Owl"  #load_8bit = False (in load_tokenizer_and_model)
#base_model = "alexkueck/li-tis-tuned-2"  #load_8bit = False (in load_tokenizer_and_model)
#base_model = "TheBloke/airoboros-13B-HF"  #load_8bit = False (in load_tokenizer_and_model)
#base_model = "EleutherAI/gpt-neo-1.3B"  #load_8bit = False (in load_tokenizer_and_model)
#base_model = "TheBloke/airoboros-13B-HF"  #load_8bit = True
#base_model = "TheBloke/vicuna-13B-1.1-HF"  #load_8bit = ?
# The following run only with a GPU upgrade:
#base_model = "TheBloke/airoboros-65B-gpt4-1.3-GPTQ"  #model_basename = "airoboros-65b-gpt4-1.3-GPTQ-4bit--1g.act.order"
#base_model = "lmsys/vicuna-13b-v1.3"
#base_model = "gpt2-xl"  # options: ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
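# Note: meta-llama/Llama-2-13b is a gated checkpoint on the Hub, so the Space needs Hugging Face
# credentials (hence the `login` import above) - presumably supplied via a token/secret or handled
# inside utils; this comment only documents that assumption.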
####################################
# Load model and tokenizer
tokenizer, model, device = load_tokenizer_and_model(base_model, False)
################################
# Alternative: model and tokenizer for GPT-2
#tokenizer,model,device = load_tokenizer_and_model_gpt2(base_model,False)
# Alternative: TheBloke GPTQ models - only with a GPU upgrade
#tokenizer,model,device = load_tokenizer_and_model_bloke_gpt(base_model, "airoboros-65b-gpt4-1.3-GPTQ-4bit--1g.act.order")
# Alternative: load model and tokenizer for Baize
#tokenizer,model,device = load_tokenizer_and_model_Baize(base_model,False)
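# load_tokenizer_and_model comes from utils; judging by the comments above, the second argument
# is the load_8bit flag, and the helper returns the tokenizer, the model and the torch device
# the model was placed on.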
########################################################################
# Use the chat model to generate text...
def predict(text,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens):
    if text == "":
        yield chatbotGr, history, "Empty context."
        return
    try:
        model
    except NameError:
        yield [[text, "No Model Found"]], [], "No Model Found"
        return

    inputs = generate_prompt_with_history(text, history, tokenizer, max_length=max_context_length_tokens)
    if inputs is None:
        yield chatbotGr, history, "Input too long."
        return
    else:
        prompt, inputs = inputs
        begin_length = len(prompt)

    input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
    torch.cuda.empty_cache()

    # torch.no_grad() means no gradients are tracked for these tensors during backpropagation.
    # The network is not supposed to change here anyway (no backprop needed), since we only run
    # inference prompts.
    with torch.no_grad():
        # All previous turns are stored in history as pairs tagged 'Human' and 'AI' - those tags
        # therefore also serve as the stop words that mark the start of the next turn.
        for x in greedy_search(input_ids, model, tokenizer, stop_words=["[|Human|]", "[|AI|]"], max_length=max_length_tokens, temperature=temperature, top_p=top_p):
            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
                if "[|Human|]" in x:
                    x = x[:x.index("[|Human|]")].strip()
                if "[|AI|]" in x:
                    x = x[:x.index("[|AI|]")].strip()
                x = x.strip()
                a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [[text, convert_to_markdown(x)]], history + [[text, x]]
                yield a, b, "Generating..."
            if shared_state.interrupted:
                shared_state.recover()
                try:
                    yield a, b, "Stop: Success"
                    return
                except:
                    pass

    del input_ids
    gc.collect()
    torch.cuda.empty_cache()

    try:
        yield a, b, "Generate: Success"
    except:
        pass
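
# Minimal sanity check of the streaming generator outside of Gradio (commented out on purpose;
# the parameter values are only illustrative, not tuned):
#for chat_pairs, hist, status in predict("Hallo", [], [], 0.95, 1.0, 512, 2048):
#    print(status)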
def reset_chat():
    #id_new = chatbot.new_conversation()
    #chatbot.change_conversation(id_new)
    reset_textbox()


# When the 'Stop' button is clicked: report it and clear the input box
def cancel_outputing():
    reset_textbox()
    return "Stop Done"


##########################################################
# Use a translation model
def translate():
    return "Kommt noch!"


# Code-generation model
def coding():
    return "Kommt noch!"
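# translate() and coding() are placeholders for the "Übersetzungen" and "Code-Generierungen"
# tabs below; they are not yet wired to any Gradio event.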
#######################################################################
# The UI, built with Gradio
with open("custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()

with gr.Blocks(theme=small_and_beautiful_theme) as demo:
    history = gr.State([])
    user_question = gr.State("")
    gr.Markdown("KIs am LI - wähle aus, was du bzgl. KI-Bots ausprobieren möchtest!")
    with gr.Tabs():
        with gr.TabItem("LI-Chat"):
            with gr.Row():
                gr.HTML(title)
                status_display = gr.Markdown("Erfolg", elem_id="status_display")
            gr.Markdown(description_top)
            with gr.Row(scale=1).style(equal_height=True):
                with gr.Column(scale=5):
                    with gr.Row(scale=1):
                        chatbotGr = gr.Chatbot(elem_id="LI_chatbot").style(height="100%")
                    with gr.Row(scale=1):
                        with gr.Column(scale=12):
                            user_input = gr.Textbox(
                                show_label=False, placeholder="Gib deinen Text / Frage ein."
                            ).style(container=False)
                        with gr.Column(min_width=100, scale=1):
                            submitBtn = gr.Button("Absenden")
                        with gr.Column(min_width=100, scale=1):
                            cancelBtn = gr.Button("Stoppen")
                    with gr.Row(scale=1):
                        emptyBtn = gr.Button(
                            "🧹 Neuer Chat",
                        )
                with gr.Column():
                    with gr.Column(min_width=50, scale=1):
                        with gr.Tab(label="Nur zum Testen:"):
                            gr.Markdown("# Parameter")
                            top_p = gr.Slider(
                                minimum=0,
                                maximum=1.0,
                                value=0.95,
                                step=0.05,
                                interactive=True,
                                label="Top-p",
                            )
                            temperature = gr.Slider(
                                minimum=0.1,
                                maximum=2.0,
                                value=1,
                                step=0.1,
                                interactive=True,
                                label="Temperature",
                            )
                            max_length_tokens = gr.Slider(
                                minimum=0,
                                maximum=512,
                                value=512,
                                step=8,
                                interactive=True,
                                label="Max Generation Tokens",
                            )
                            max_context_length_tokens = gr.Slider(
                                minimum=0,
                                maximum=4096,
                                value=2048,
                                step=128,
                                interactive=True,
                                label="Max History Tokens",
                            )
            gr.Markdown(description)

        with gr.TabItem("Übersetzungen"):
            with gr.Row():
                gr.Textbox(
                    show_label=False, placeholder="Ist noch in Arbeit..."
                ).style(container=False)
        with gr.TabItem("Code-Generierungen"):
            with gr.Row():
                gr.Textbox(
                    show_label=False, placeholder="Ist noch in Arbeit..."
                ).style(container=False)
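    # Event wiring: transfer_input copies the textbox content into the user_question state
    # (and, judging by its outputs, resets the textbox and submit button) before predict
    # streams its updates into the chatbot.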
    predict_args = dict(
        fn=predict,
        inputs=[
            user_question,
            chatbotGr,
            history,
            top_p,
            temperature,  # variation of the answers - default 1.0
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )

    # New chat
    reset_args = dict(
        #fn=reset_chat, inputs=[], outputs=[user_input, status_display]
        fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
    )

    # Chatbot
    transfer_input_args = dict(
        fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
    )

    # Listeners: start on a click of the submit button or on pressing return
    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
    predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)

    # Listener for reset...
    emptyBtn.click(
        reset_state,
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)

    # Stop the running generation/output (it can be restarted afterwards)
    cancelBtn.click(cancel_outputing, [], [status_display], cancels=[predict_event1, predict_event2])
    #cancelBtn.click(lambda: None, None, chatbotGr, queue=False)

demo.title = "LI Chat"
#demo.queue(concurrency_count=1).launch(share=True)
demo.queue(concurrency_count=1).launch(debug=True)
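# Note: .style(...), the Row/Column scale arguments and queue(concurrency_count=...) follow the
# Gradio 3.x API; on Gradio 4+ these calls would need to be adapted.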