# test_Gradio / app.py
# Captured from the Hugging Face Hub file view. Last commit: 44353c3 "Update app.py"
# (verified, author negismohit123). File size: 2.52 kB; Hub virus scan: no virus.
import streamlit as st
import random
from huggingface_hub import InferenceClient
# Hugging Face model ids offered in the "Models" dropdown. The UI stores the
# selected position in this list, not the id itself.
models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
]

# One inference client per model id, in the same order as ``models``, so a
# single integer index selects both the name and its client.
clients = [InferenceClient(model_id) for model_id in models]
def format_prompt(message, history):
    """Build a Gemma chat-template prompt from earlier turns plus a new message.

    Every completed turn is rendered as
    ``<start_of_turn>{role}\\n{text}<end_of_turn>\\n``. The prompt then ends with
    an open ``<start_of_turn>model\\n`` so the model generates the next reply.

    Args:
        message: The new user message for this turn.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from earlier
            turns, or ``None``/empty for a fresh conversation.

    Returns:
        The formatted prompt string.
    """
    parts = []
    for user_prompt, bot_response in history or ():
        parts.append(f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n")
        # Model turns must be closed too; otherwise the next user turn
        # appears to be part of the model's reply.
        parts.append(f"<start_of_turn>model\n{bot_response}<end_of_turn>\n")
    parts.append(f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n")
    return "".join(parts)
def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p):
    """Stream a completion from the selected model and render it on the page.

    Args:
        system_prompt: Optional instructions. They are joined to the user
            prompt with ", " because the prompt template built here has no
            separate system role.
        prompt: The user's message for this turn.
        history: List of ``(user_prompt, bot_response)`` tuples from earlier
            turns, or ``None``. A list passed in, including an empty one, is
            extended in place with the new turn.
        client_choice: Index into ``clients`` selecting which model to query.
        seed: Sampling seed forwarded to ``text_generation``.
        temp: Sampling temperature.
        tokens: ``max_new_tokens`` for the generation.
        top_p: Nucleus-sampling probability mass.
        rep_p: Repetition penalty.

    Returns:
        The generated text, with special tokens removed.
    """
    client = clients[client_choice]  # client_choice is an index into clients
    # Only substitute a fresh list when no history was given. An empty list
    # from the caller must be kept so the caller sees the appended turn.
    if history is None:
        history = []
    generate_kwargs = dict(
        temperature=temp,
        max_new_tokens=tokens,
        top_p=top_p,
        repetition_penalty=rep_p,
        do_sample=True,
        seed=seed,
    )
    # Skip empty pieces so a blank system prompt does not leave a leading ", ".
    message = ", ".join(part for part in (system_prompt, prompt) if part)
    formatted_prompt = format_prompt(message, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    # A single placeholder is overwritten as tokens arrive, so the reply
    # streams into one element instead of appearing only at the end.
    placeholder = st.empty()
    text = ""
    for response in stream:
        # Skip special tokens (e.g. end-of-turn / end-of-sequence markers)
        # so they are not shown to the user or stored in the history.
        if response.token.special:
            continue
        text += response.token.text
        placeholder.write(text)
    history.append((prompt, text))
    return text
def clear_fn():
    """Give back the 'no history' value stored when the Clear button is pressed."""
    empty_history = None
    return empty_history
# Inclusive upper bound of the seed slider.
_MAX_SEED = 1111111111111111
# Streamlit re-executes the whole script on every interaction, so this is a
# fresh random seed on each rerun.
rand_val = random.randint(1, _MAX_SEED)


def check_rand(inp, val):
    """Render the "Seed" slider and return the selected seed.

    Args:
        inp: Truthy when the "Random Seed" checkbox is ticked.
        val: Default for the manual-seed slider. Only the first value seen in
            the session is used, so that the slider's default stays stable
            across reruns.

    Returns:
        The integer seed chosen on the slider.
    """
    if inp:
        # Random mode: default to this run's fresh random seed.
        return st.slider("Seed", 1, _MAX_SEED, rand_val)
    # Manual mode: Streamlit treats a widget whose parameters (including its
    # default value) change as a brand-new widget and resets it. A per-rerun
    # random default would therefore discard the user's chosen seed on every
    # interaction, so the default is pinned to the first value seen this session.
    default = st.session_state.setdefault("_manual_seed_default", int(val))
    return st.slider("Seed", 1, _MAX_SEED, default)
st.title("Google Gemma Models")
# The widget value stays the integer index that chat_inf uses to pick a
# client; format_func only changes the label shown, so users see model names.
client_choice = st.selectbox("Models", range(len(models)), format_func=lambda i: models[i])
rand = st.checkbox("Random Seed", True)
seed = check_rand(rand, rand_val)
# The generation backend rejects max_new_tokens=0, so the slider starts at
# one step (64). Every previously reachable non-zero value is still reachable.
tokens = st.slider("Max new tokens", 64, 8000, 6400, 64)
temp = st.slider("Temperature", 0.01, 1.0, 0.9, step=0.01)
top_p = st.slider("Top-P", 0.01, 1.0, 0.9, step=0.01)
rep_p = st.slider("Repetition Penalty", 0.1, 2.0, 1.0, step=0.1)
sys_inp = st.text_input("System Prompt (optional)")
inp = st.text_input("Prompt")
btn = st.button("Chat")
clear_btn = st.button("Clear")
if btn:
    # Avoid spending an API call on a blank prompt.
    if inp.strip():
        # History is passed as None, so each Chat click is a single-turn request.
        chat_inf(sys_inp, inp, None, client_choice, seed, temp, tokens, top_p, rep_p)
    else:
        st.warning("Please enter a prompt.")
if clear_btn:
    # NOTE(review): nothing in this file reads st.session_state.history yet.
    st.session_state.history = clear_fn()