import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config
import torch
import time
import base64
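# st.set_page_config must be the first Streamlit call in the script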
st.set_page_config(page_title="LIA Demo", layout="wide")
st.markdown("<h1 style='text-align: center;'>Ask LeoNardo!</h1>", unsafe_allow_html=True)
# Helper: read a local GIF and return it base64-encoded for inline HTML embedding
def load_gif_base64(path):
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")
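# NOTE: load_gif_base64 is not used in the current flow -- both GIFs below are
# hot-linked from Giphy. It would allow embedding local GIFs instead (assuming
# files such as "still.gif" / "rotating.gif" sit next to this script):
#   still_gem_b64 = load_gif_base64("still.gif")
#   rotating_gem_b64 = load_gif_base64("rotating.gif")
# These are the variables the commented-out data-URI calls further down expect.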
# Placeholder for GIF HTML
gif_html = st.empty()
caption = st.empty()
gif_html.markdown(
    "<div style='text-align:center;'><img src='https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTRxYzI2bXJmY3N2bXBtMHJtOGV3NW9vZ3l3M3czbGYybGpkeWQ1YSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/3uPWb5EYVvxdfoREQm/giphy.gif' width='300'></div>",
    unsafe_allow_html=True,
)
@st.cache_resource
def load_model():
    # Alternative checkpoints tried during development:
    # model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    # model_id = "deepseek-ai/deepseek-llm-7b-chat"
    model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # CPU fallback: device_map=None, torch_dtype=torch.float32
        device_map="auto",
        torch_dtype=torch.float16,
        # quantization_config=quantization_config,
        # attn_implementation="flash_attention_2",
        trust_remote_code=True,
    )
    # model.to("cpu")
    return tokenizer, model
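# A sketch of the FP8 path hinted at by the FineGrainedFP8Config import above
# (an assumption -- it needs a recent transformers release and FP8-capable GPU
# support, and is not exercised by this app):
#   quantization_config = FineGrainedFP8Config()
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id, device_map="auto", quantization_config=quantization_config
#   )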
tokenizer, model = load_model()
prompt = st.text_area("Enter your prompt:", "What kind of company is Leonardo S.p.A.?")
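# NOTE: the prompt is passed to the model as raw text at generation time. For a
# chat-tuned checkpoint such as DeepSeek-R1-Distill-Llama-8B, wrapping it with
# the tokenizer's chat template is usually preferable (a sketch, not wired in):
#   messages = [{"role": "user", "content": prompt}]
#   input_ids = tokenizer.apply_chat_template(
#       messages, add_generation_prompt=True, return_tensors="pt"
#   ).to(model.device)
#   outputs = model.generate(input_ids, max_new_tokens=512)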
# Example prompt selector
# examples = {
#     "🧠 Summary": "Summarize the history of AI in 5 bullet points.",
#     "💻 Code": "Write a Python function to sort a list using bubble sort.",
#     "📜 Poem": "Write a haiku about large language models.",
#     "🤖 Explain": "Explain what a transformer is in simple terms.",
#     "🔍 Fact": "Who won the FIFA World Cup in 2022?"
# }
# selected_example = st.selectbox("Choose an example prompt:", list(examples.keys()) + ["✍️ Custom input"])
# Optional sampling controls, shown before generation:
# col1, col2, col3 = st.columns(3)
# with col1:
#     temperature = st.slider("Temperature", 0.1, 1.5, 1.0)
# with col2:
#     max_tokens = st.slider("Max tokens", 50, 500, 100)
# with col3:
#     top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.95)
# if selected_example != "✍️ Custom input":
#     prompt = examples[selected_example]
# else:
#     prompt = st.text_area("Enter your prompt:")
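# NOTE: the block above is dormant. If re-enabled, the temperature / max_tokens /
# top_p sliders would presumably replace the hard-coded sampling values passed
# to model.generate() below (an assumption; they were never wired in here).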
if st.button("Generate"):
    # Swap to the rotating "thinking" GIF
    # gif_html.markdown(
    #     f"<div style='text-align:center;'><img src='data:image/gif;base64,{rotating_gem_b64}' width='300'></div>",
    #     unsafe_allow_html=True,
    # )
    gif_html.markdown(
        "<div style='text-align:center;'><img src='https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExMXViMm02MnR6bGJ4c2h3ajYzdWNtNXNtYnNic3lnN2xyZzlzbm9seSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/k32ddF9WVs44OUaZAm/giphy.gif' width='300'></div>",
        unsafe_allow_html=True,
    )
    caption.markdown("<p style='text-align: center; margin-top: 20px;'>LeoNardo is thinking... 🌀</p>", unsafe_allow_html=True)
    # Generate text
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_new_tokens=100,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            top_k=50,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
        )
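    # NOTE: model.generate() blocks until the full answer is ready; the loop at
    # the bottom only replays the finished text word by word. True token-level
    # streaming is possible with transformers' TextIteratorStreamer (a sketch,
    # not wired into this app):
    #   from threading import Thread
    #   from transformers import TextIteratorStreamer
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   Thread(target=model.generate,
    #          kwargs={**inputs, "streamer": streamer, "max_new_tokens": 512}).start()
    #   for chunk in streamer:
    #       ...  # append chunk to the placeholder as it arrives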
    # Back to the still GIF
    # gif_html.markdown(
    #     f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>",
    #     unsafe_allow_html=True,
    # )
    gif_html.markdown(
        "<div style='text-align:center;'><img src='https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTRxYzI2bXJmY3N2bXBtMHJtOGV3NW9vZ3l3M3czbGYybGpkeWQ1YSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/3uPWb5EYVvxdfoREQm/giphy.gif' width='300'></div>",
        unsafe_allow_html=True,
    )
    caption.empty()
    # Note: outputs[0] contains prompt + completion tokens, so the decoded text
    # echoes the prompt before the answer
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Set up placeholder for the word-by-word streaming effect
    output_placeholder = st.empty()
    streamed_text = ""
    for word in decoded_output.split(" "):
        streamed_text += word + " "
        output_placeholder.markdown("### ✨ Output:\n\n" + streamed_text + "▌")
        time.sleep(0.03)  # slight delay for the typing effect
    # Final cleanup (remove the blinking cursor)
    output_placeholder.markdown("### ✨ Output:\n\n" + streamed_text)