LIStarCode / app.py
alexkueck's picture
Update app.py
a45a332
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gradio as gr
#from transformers import pipeline
import torch
from utils import *
from presets import *
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
######################################################################
#Modelle und Tokenizer
#Alternativ mit beliebigen Modellen:
#base_model = "project-baize/baize-v2-7b"
base_model = "bigcode/starcoder"
tokenizer,model,device = load_tokenizer_and_model(base_model)
#pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
#print( pipe("def hello():") )
########################################################################
#Chat KI nutzen, um Text zu generieren...
def predict(text,
chatbotGr,
history,
top_p,
temperature,
max_length_tokens,
max_context_length_tokens,):
global model, tokenizer, device
#wenn eingabe leer - nix tun
if text=="":
yield history,"Empty context."
return
#wenn Model nicht gefunden -> Fehler
try:
model
except:
yield [],"No Model Found"
return
#Prompt generieren -> mit Kontext bezogen auch auf vorhergehende Eingaben in dem chat
inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
if inputs is None:
yield history,"Input too long."
return
else:
prompt,inputs=inputs
begin_length = len(prompt)
#####################################################################################################
#ist glaube ich unnötig, da ich mit Pipeline arbeiten -> mal schauen, ich behalte es noch...
"""
input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
torch.cuda.empty_cache()
#torch.no_grad() bedeutet, dass für die betreffenden tensoren keine Ableitungen berechnet werden bei der backpropagation
#hier soll das NN ja auch nicht geändert werden 8backprop ist nicht nötig), da es um interference-prompts geht!
with torch.no_grad():
#die vergangenen prompts werden alle als Tupel in history abgelegt sortiert nach 'Human' und 'AI'- dass sind daher auch die stop-words, die den jeweils nächsten Eintrag kennzeichnen
for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
if "[|Human|]" in x:
x = x[:x.index("[|Human|]")].strip()
if "[|AI|]" in x:
x = x[:x.index("[|AI|]")].strip()
x = x.strip()
a, b= [[y[0],convert_to_markdown(y[1])] for y in history]+[[text, convert_to_markdown(x)]],history + [[text,x]]
yield a, b, "Generating..."
if shared_state.interrupted:
shared_state.recover()
try:
yield a, b, "Stop: Success"
return
except:
pass
del input_ids
gc.collect()
torch.cuda.empty_cache()
try:
yield a,b,"Generate: Success"
except:
pass
"""
##########################################################################
#Prompt ist erzeugt, nun mit pipeline eine Antwort von der KI bekommen!
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0, do_sample=True, top_p=0.95)
bot_message = pipe(prompt)
#chatbot - history erweitern und an chatbotGr zurückschicken
history = history.append((text, bot_message))
return history, "Erfolg!"
#neuen Chat beginnen
def reset_chat():
#id_new = chatbot.new_conversation()
#chatbot.change_conversation(id_new)
reset_textbox()
#######################################################################
#Darstellung mit Gradio
with open("custom.css", "r", encoding="utf-8") as f:
customCSS = f.read()
with gr.Blocks(theme=small_and_beautiful_theme) as demo:
history = gr.State([])
user_question = gr.State("")
gr.Markdown("KIs am LI - hier um Programmcode generieren zu lassen!")
with gr.Tabs():
with gr.TabItem("LI-Coding-KI"):
with gr.Row():
gr.HTML(title)
status_display = gr.Markdown("Erfolg", elem_id="status_display")
gr.Markdown(description_top)
with gr.Row(scale=1).style(equal_height=True):
with gr.Column(scale=5):
with gr.Row(scale=1):
chatbotGr = gr.Chatbot(elem_id="LI_chatbot").style(height="100%")
with gr.Row(scale=1):
with gr.Column(scale=12):
user_input = gr.Textbox(
show_label=False, placeholder="Gib deinen Text / Frage ein."
).style(container=False)
with gr.Column(min_width=100, scale=1):
submitBtn = gr.Button("Absenden")
with gr.Column(min_width=100, scale=1):
cancelBtn = gr.Button("Stoppen")
with gr.Row(scale=1):
emptyBtn = gr.Button(
"🧹 Neuer Chat",
)
with gr.Column():
with gr.Column(min_width=50, scale=1):
with gr.Tab(label="Parameter zum Model"):
gr.Markdown("# Parameter für Testzwecke")
top_p = gr.Slider(
minimum=-0,
maximum=1.0,
value=0.95,
step=0.05,
interactive=True,
label="Top-p",
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1,
step=0.1,
interactive=True,
label="Temperature",
)
max_length_tokens = gr.Slider(
minimum=0,
maximum=512,
value=512,
step=8,
interactive=True,
label="Max Generation Tokens",
)
max_context_length_tokens = gr.Slider(
minimum=0,
maximum=4096,
value=2048,
step=128,
interactive=True,
label="Max History Tokens",
)
gr.Markdown(description)
predict_args = dict(
fn=predict,
inputs=[
user_input,
history,
top_p,
temperature,
max_length_tokens,
max_context_length_tokens,
],
outputs=[chatbotGr, status_display],
show_progress=True,
)
#neuer Chat
reset_args = dict(
#fn=reset_chat, inputs=[], outputs=[user_input, status_display]
fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
)
# Chatbot
transfer_input_args = dict(
fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
)
#Listener auf Start-Click auf Button oder Return
predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)
#Listener, Wenn reset...
emptyBtn.click(
reset_state,
outputs=[chatbotGr, history, status_display],
show_progress=True,
)
emptyBtn.click(**reset_args)
demo.title = "LI Coding KI"
#demo.queue(concurrency_count=1).launch(share=True)
demo.queue(concurrency_count=1).launch(debug=True)