doubledsbv's picture
Update app.py
6457556 verified
raw
history blame contribute delete
No virus
2.97 kB
import gradio as gr
import runpod
import os
from dotenv import load_dotenv
# Populate os.environ from a local .env file (when one exists), then hand the
# RunPod API key to the SDK. If the variable is unset the key stays None.
load_dotenv()
runpod.api_key = os.environ.get("RUNPOD_API_KEY")
def format_prompt(
    message,
    history,
    system_prompt="Du bist ein freundlicher, hilfsbereiter Chatbot, der gerne Fragen korrekt und präzise beantwortet.",
):
    """Build the chat message list sent to the RunPod endpoint.

    Args:
        message: The new user message.
        history: Previous turns as Gradio supplies them. Each entry is either
            a ``(user, assistant)`` pair (Gradio's "tuples" history format) or
            a ``{"role": ..., "content": ...}`` dict ("messages" format).
            ``None`` is treated as an empty history.
        system_prompt: Content of the leading system message. It defaults to
            the German assistant persona this app has always used.

    Returns:
        A list of ``{"role", "content"}`` dicts in this order: the system
        message, the history turns, and then the new user message.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-format history is already role/content shaped. Unpacking
            # such a dict as a pair would silently yield its two key names.
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_prompt, bot_response = turn
        messages.append({"role": "user", "content": user_prompt})
        # A turn whose reply is missing (None) has no assistant text to send.
        # Forwarding null content is presumably rejected by the backend.
        if bot_response is not None:
            messages.append({"role": "assistant", "content": bot_response})
    messages.append({"role": "user", "content": message})
    return messages
def generate(
    message: str,
    history=None,
    temperature=0.7,
    max_tokens=1024,
    top_p=0.95,
    top_k=50,
):
    """Stream a chat completion from the RunPod serverless endpoint.

    Gradio's ChatInterface calls this with the new message, the chat history
    and the values of ``additional_inputs`` in order (temperature, max new
    tokens, top-p). Every yield is the full text generated so far, which
    Gradio renders as a growing reply.

    Args:
        message: The new user message.
        history: Previous turns, forwarded to ``format_prompt``. ``None``
            means no history. This avoids a shared mutable default.
        temperature: Sampling temperature. Values below 0.01 are raised to 0.01.
        max_tokens: Maximum number of new tokens. At least 1 is requested.
        top_p: Nucleus-sampling mass, clamped into [0.01, 1.0].
        top_k: Top-k cutoff, forwarded as an int. The default of 50 matches
            the previously hard-coded value.

    Yields:
        The accumulated response text after each streamed token.

    Raises:
        gr.Error: If creating the endpoint, submitting the job, or reading the
            stream fails. The message is then shown in the UI instead of
            being only printed to the server log.
    """
    if history is None:
        history = []
    formatted_messages = format_prompt(message, history)

    # The UI sliders can reach 0 for top-p and max tokens. Presumably the
    # endpoint runs vLLM, whose SamplingParams require 0 < top_p <= 1 and
    # max_tokens >= 1, so clamp into a valid range.
    temperature = max(float(temperature), 1e-2)
    top_p = min(max(float(top_p), 1e-2), 1.0)
    max_tokens = max(int(max_tokens), 1)
    top_k = int(top_k)

    input_payload = {
        "input": {
            "messages": formatted_messages,
            "sampling_params": {
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "max_tokens": max_tokens,
            },
        }
    }

    try:
        endpoint = runpod.Endpoint(os.getenv("ENDPOINT_ID"))
        run_request = endpoint.run(input_payload)
        stream_output = ""
        for output in run_request.stream():
            for token in output["choices"][0]["tokens"]:
                # Streamed tokens are text fragments whose whitespace is
                # significant (presumably e.g. " world"). Stripping each one
                # glued words together, so append them verbatim.
                stream_output += token
                yield stream_output
    except Exception as e:
        print(f"An error occurred: {e}")
        # Surface the failure in the chat UI rather than ending silently.
        raise gr.Error(f"An error occurred: {e}") from e
# Extra sampling controls shown under the chat box. ChatInterface passes their
# values to ``generate`` positionally after (message, history), so this order
# must match generate's (temperature, max_tokens, top_p) parameters.
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.7,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        # Requesting 0 new tokens is not a valid generation, so the range
        # starts at one step.
        minimum=64,
        maximum=2048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.95,
        # top_p must be strictly positive; 0 would leave no tokens to sample.
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
]
# Custom CSS handed to gr.Blocks(css=css) when the page is built.
# NOTE(review): no component in this file sets elem_id="mkd", so the #mkd rule
#   appears unused here. Confirm before removing.
# NOTE(review): #component-6 targets a Gradio auto-generated element id, which
#   presumably depends on component creation order and the Gradio version.
#   Consider giving the intended component an explicit elem_id instead.
css = """
#mkd {
height: 750px;
overflow: auto;
border: 1px solid #ccc;
}
#component-6 {
height: 450px !important;
}
"""
# Page layout: one ChatInterface backed by the streaming ``generate`` function,
# the sampling sliders from ``additional_inputs`` and two German example prompts.
with gr.Blocks(css=css) as demo:
    gr.ChatInterface(
        generate,
        title="Demo Kafka-7B-DARE-TIES-QLoRa-LaserRMT-DPO",
        examples=[
            ["Was ist der Sinn des Lebens?"],
            ["Was ist der Unterschied zwischen Quantenmechanik und Relativitätstheorie?"],
        ],
        additional_inputs=additional_inputs,
    )

# Enable Gradio's request queue on the page, then start the server.
demo.queue()
demo.launch(debug=True, max_threads=100)