"""Streamlit chat front-end for Google Gemma models served via the Hugging Face Inference API."""

import random

import streamlit as st
from huggingface_hub import InferenceClient

# Upper bound of the seed slider and of randomly drawn seeds.
MAX_SEED = 1111111111111111

models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
]

# One inference client per entry of `models`, index-aligned with it.
clients = [InferenceClient(model) for model in models]


def format_prompt(message, history):
    """Build a Gemma-formatted prompt from past turns plus the new message.

    Args:
        message: The new user message to answer.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, or a
            falsy value (``None`` / empty) when there are no previous turns.

    Returns:
        The prompt string: every past turn wrapped in Gemma's
        ``<start_of_turn>`` / ``<end_of_turn>`` markers, followed by the new
        user turn and an open ``model`` turn for the model to complete.
    """
    parts = []
    for user_prompt, bot_response in history or ():
        parts.append(f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n")
        parts.append(f"<start_of_turn>model\n{bot_response}<end_of_turn>\n")
    # The trailing open model turn is what cues the model to start replying.
    parts.append(f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n")
    return "".join(parts)


def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p):
    """Generate a reply with the selected model and render it on the page.

    Args:
        system_prompt: Optional instruction text prepended to ``prompt``.
        prompt: The user's message.
        history: List of ``(prompt, reply)`` tuples; the new turn is appended
            to it in place. ``None`` means "no history" and uses a throwaway
            list.
        client_choice: Index into ``clients`` / ``models``.
        seed: Sampling seed.
        temp: Sampling temperature.
        tokens: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling probability mass.
        rep_p: Repetition penalty.

    Returns:
        The generated reply text (also written to the page with ``st.write``).
    """
    client = clients[client_choice]  # client_choice is used directly as an index

    # Only substitute a fresh list for None: a caller-supplied empty list must
    # be kept so that the append below is visible to the caller.
    if history is None:
        history = []

    generate_kwargs = dict(
        temperature=temp,
        max_new_tokens=tokens,
        top_p=top_p,
        repetition_penalty=rep_p,
        do_sample=True,
        seed=seed,
    )

    # Avoid a dangling ", " in front of the message when no system prompt is set.
    message = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(message, history)

    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    # Skip special tokens (e.g. end-of-turn markers) so they neither show up
    # in the reply nor get re-wrapped when the turn is replayed from history.
    reply = "".join(
        response.token.text
        for response in stream
        if not getattr(response.token, "special", False)
    )

    history.append((prompt, reply))
    st.write(reply)  # Display the accumulated output
    return reply


def clear_fn():
    """Return the value that represents a cleared chat history (``None``)."""
    return None


# Streamlit re-executes this script on every interaction, so the base seed is
# kept in session state; otherwise the "fixed" seed would change on each rerun.
if "rand_val" not in st.session_state:
    st.session_state.rand_val = random.randint(1, MAX_SEED)
rand_val = st.session_state.rand_val


def check_rand(inp, val):
    """Render the seed slider and return the selected seed.

    Args:
        inp: ``True`` to start the slider at a freshly drawn random seed on
            every run; anything else to start it at ``val``.
        val: Seed used as the slider default when ``inp`` is not ``True``.

    Returns:
        The integer value of the "Seed" slider.
    """
    default = random.randint(1, MAX_SEED) if inp is True else int(val)
    return st.slider("Seed", 1, MAX_SEED, default)


st.title("Google Gemma Models")

# Options are indices into `models`; format_func shows the model name while
# the returned choice stays an index usable with `clients`.
client_choice = st.selectbox(
    "Models",
    range(len(models)),
    format_func=lambda index: models[index],
)

rand = st.checkbox("Random Seed", True)
seed = check_rand(rand, rand_val)
# The minimum is one step rather than 0: a generation length of 0 is not valid.
tokens = st.slider("Max new tokens", 64, 8000, 6400, 64)
temp = st.slider("Temperature", 0.01, 1.0, 0.9, step=0.01)
top_p = st.slider("Top-P", 0.01, 1.0, 0.9, step=0.01)
rep_p = st.slider("Repetition Penalty", 0.1, 2.0, 1.0, step=0.1)

sys_inp = st.text_input("System Prompt (optional)")
inp = st.text_input("Prompt")

btn = st.button("Chat")
clear_btn = st.button("Clear")

if btn:
    # Keep the conversation in session state so it survives reruns; it is
    # None after "Clear" and missing on the very first run.
    if st.session_state.get("history") is None:
        st.session_state.history = []
    chat_inf(
        sys_inp,
        inp,
        st.session_state.history,
        client_choice,
        seed,
        temp,
        tokens,
        top_p,
        rep_p,
    )

if clear_btn:
    st.session_state.history = clear_fn()