# Spaces: Sleeping
import base64

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Page-level settings: browser tab title and a wide layout.
st.set_page_config(page_title="Gemma Demo", layout="wide")

# Model selection (STUBBED behavior).
# NOTE(review): `model_option` is not read by any of the visible code, and
# load_model() loads a fixed checkpoint regardless of the choice made here.
# The help text below also names gemma-2b-it even though a different model is
# currently loaded — confirm the intended wording.
model_option = st.selectbox(
    "Choose a Gemma to reveal hidden truths:",
    ["gemma-2b-it (Instruct)", "gemma-2b", "gemma-7b", "gemma-7b-it"],
    index=0,
    # The separator was a mis-encoded em dash ("β"); restored to "—".
    help="Stubbed selection — only gemma-2b-it will load for now.",
)

st.markdown("<h1 style='text-align: center;'>Portal to Gemma</h1>", unsafe_allow_html=True)
| # Load both GIFs in base64 format | |
| def load_gif_base64(path): | |
| with open(path, "rb") as f: | |
| return base64.b64encode(f.read()).decode("utf-8") | |
| # still_gem_b64 = load_gif_base64("assets/stillGem.gif") | |
| # rotating_gem_b64 = load_gif_base64("assets/rotatingGem.gif") | |
| # Placeholder for GIF HTML | |
| gif_html = st.empty() | |
| caption = st.empty() | |
| # Initially show still gem | |
| # gif_html.markdown( | |
| # f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>", | |
| # unsafe_allow_html=True, | |
| # ) | |
| gif_html.markdown( | |
| f"<div style='text-align:center;'><img src='https://media3.giphy.com/media/v1.Y2lkPTc5MGI3NjExMG00dmlwbjZsemZ5Mnh2eTIwOGNyYncwbGNqd3U3aHhiNGYxYjgwbCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/WmJtDY3xgYVgXgQZYc/giphy.gif' width='300'></div>", | |
| unsafe_allow_html=True, | |
| ) | |
| def load_model(): | |
| # As Gemma is gated, we will show functionality of the demo using DeepSeek-R1-Distill-Qwen-1.5B model | |
| # model_id = "google/gemma-2b-it" | |
| # tokenizer = AutoTokenizer.from_pretrained(model_id, token=True) | |
| model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_id, | |
| device_map=None, | |
| torch_dtype=torch.float32 | |
| ) | |
| model.to("cpu") | |
| return tokenizer, model | |
| tokenizer, model = load_model() | |
| prompt = st.text_area("Enter your prompt:", "What is Gemma?") | |
| # # Example prompt selector | |
| # examples = { | |
| # "π§ Summary": "Summarize the history of AI in 5 bullet points.", | |
| # "π» Code": "Write a Python function to sort a list using bubble sort.", | |
| # "π Poem": "Write a haiku about large language models.", | |
| # "π€ Explain": "Explain what a transformer is in simple terms.", | |
| # "π Fact": "Who won the FIFA World Cup in 2022?" | |
| # } | |
| # selected_example = st.selectbox("Choose a Gemma to consult:", list(examples.keys()) + ["βοΈ Custom input"]) | |
| # Add before generation | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| temperature = st.slider("Temperature", 0.1, 1.5, 1.0) | |
| with col2: | |
| max_tokens = st.slider("Max tokens", 50, 500, 100) | |
| with col3: | |
| top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.95) | |
| # if selected_example != "βοΈ Custom input": | |
| # prompt = examples[selected_example] | |
| # else: | |
| # prompt = st.text_area("Enter your prompt:") | |
| if st.button("Generate"): | |
| # Swap to rotating GIF | |
| # gif_html.markdown( | |
| # f"<div style='text-align:center;'><img src='data:image/gif;base64,{rotating_gem_b64}' width='300'></div>", | |
| # unsafe_allow_html=True, | |
| # ) | |
| gif_html.markdown( | |
| f"<div style='text-align:center;'><img src='https://media4.giphy.com/media/v1.Y2lkPTc5MGI3NjExaXB0ZTEycW1yYWhvZWExdHFyNzBnemdtdm80NzY0MGg1ZnkyNTRqbiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/FIMRlbkylLyniVP7WG/giphy.gif' width='300'></div>", | |
| unsafe_allow_html=True, | |
| ) | |
| caption.markdown("<p style='text-align: center;'>Gemma is thinking... π</p>", unsafe_allow_html=True) | |
| # Generate text | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
| with torch.no_grad(): | |
| outputs = model.generate(**inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p) | |
| # Back to still | |
| # gif_html.markdown( | |
| # f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>", | |
| # unsafe_allow_html=True, | |
| # ) | |
| gif_html.markdown( | |
| f"<div style='text-align:center;'><img src='https://media3.giphy.com/media/v1.Y2lkPTc5MGI3NjExMG00dmlwbjZsemZ5Mnh2eTIwOGNyYncwbGNqd3U3aHhiNGYxYjgwbCZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/WmJtDY3xgYVgXgQZYc/giphy.gif' width='300'></div>", | |
| unsafe_allow_html=True, | |
| ) | |
| caption.empty() | |
| result = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| st.markdown("### β¨ Output:") | |
| st.write(result) |