# Hugging Face Space: Streamlit chat demo for the Google Gemma models.
# (Page-scrape artifacts "Spaces: / Runtime error" removed.)
import streamlit as st
import random
from huggingface_hub import InferenceClient

# Gemma checkpoints offered in the model picker; the "-it" variants are
# instruction-tuned and expect the <start_of_turn> chat format used below.
models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
]

# One InferenceClient per model, index-aligned with `models` so a selectbox
# index can be used directly as an index into `clients`.
clients = [InferenceClient(model) for model in models]
def format_prompt(message, history):
    """Build a Gemma-format chat prompt from prior turns plus a new message.

    Args:
        message: The new user message (the caller prepends any system prompt).
        history: Iterable of (user_prompt, bot_response) pairs, or a falsy
            value for a fresh conversation.

    Returns:
        A string in the Gemma chat-turn format, ending with an open
        ``<start_of_turn>model`` marker so the model continues from there.
    """
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            # Per the Gemma prompt format, each turn marker is followed by a
            # newline and every completed turn is closed with <end_of_turn>;
            # the original left model turns unterminated and omitted newlines.
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt
def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p):
    """Stream a completion from the selected Gemma model and render it.

    Args:
        system_prompt: Optional system instruction, prepended to the prompt.
        prompt: The user's message.
        history: Mutable list of (prompt, response) pairs, or None for a
            fresh conversation. Appended to IN PLACE, so a caller that keeps
            the list (e.g. in session state) sees the new turn.
        client_choice: Index into the module-level ``clients`` list.
        seed: RNG seed for sampling.
        temp: Sampling temperature.
        tokens: ``max_new_tokens`` for generation.
        top_p: Nucleus-sampling threshold.
        rep_p: Repetition penalty.

    Returns:
        The full generated text (also written to the page via ``st.write``).
    """
    client = clients[client_choice]  # client_choice is already an index
    # Only replace a missing history; rebinding on ANY falsy value (as the
    # original did) would silently discard a caller-supplied empty list.
    if history is None:
        history = []

    generate_kwargs = dict(
        temperature=temp,
        max_new_tokens=tokens,
        top_p=top_p,
        repetition_penalty=rep_p,
        do_sample=True,
        seed=seed,
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True,
        return_full_text=False,
    )
    # Accumulate the streamed tokens, then record the turn and render once.
    output = []
    for response in stream:
        output.append(response.token.text)
    full_text = "".join(output)
    history.append((prompt, full_text))
    st.write(full_text)
    return full_text
def clear_fn():
    """Return the reset value for the chat history (``None`` = cleared)."""
    return None
# Default seed, drawn once per script run so "Random Seed" varies each rerun.
rand_val = random.randint(1, 1111111111111111)


def check_rand(inp, val):
    """Render the seed slider and return the chosen seed.

    Args:
        inp: Whether the "Random Seed" checkbox is ticked; if so the slider
            defaults to the fresh per-run random value, otherwise to ``val``.
        val: Explicit seed to fall back on when randomness is off.

    Returns:
        The seed value selected in the slider.
    """
    # `inp` is a plain bool from st.checkbox; a truth test is the idiomatic
    # form (the original used `inp is True`). Collapse the two branches into
    # a single slider call that differs only in its default.
    default = rand_val if inp else int(val)
    return st.slider("Seed", 1, 1111111111111111, default)
# ---- Streamlit UI (the script reruns top-to-bottom on every interaction) ----
st.title("Google Gemma Models")

# The selectbox value is an index into `models`/`clients`; show the model
# names as labels instead of bare integers.
client_choice = st.selectbox(
    "Models", range(len(models)), format_func=lambda i: models[i]
)
rand = st.checkbox("Random Seed", True)
seed = check_rand(rand, rand_val)
tokens = st.slider("Max new tokens", 0, 8000, 6400, 64)
temp = st.slider("Temperature", 0.01, 1.0, 0.9, step=0.01)
top_p = st.slider("Top-P", 0.01, 1.0, 0.9, step=0.01)
rep_p = st.slider("Repetition Penalty", 0.1, 2.0, 1.0, step=0.1)
sys_inp = st.text_input("System Prompt (optional)")
inp = st.text_input("Prompt")
btn = st.button("Chat")
clear_btn = st.button("Clear")

# Keep the conversation in session state so it survives reruns; the original
# passed None to chat_inf, so history (and hence the Clear button) never
# affected generation.
if "history" not in st.session_state:
    st.session_state.history = []

if btn:
    chat_inf(sys_inp, inp, st.session_state.history, client_choice, seed,
             temp, tokens, top_p, rep_p)
if clear_btn:
    st.session_state.history = clear_fn()