File size: 5,076 Bytes
1efea19
6d3a2f2
1efea19
2732ef7
1efea19
 
d8f007f
 
1efea19
 
 
 
 
2732ef7
1efea19
 
 
 
 
d8f007f
1efea19
 
 
 
 
8020e39
ea616ac
fa42556
4c090d8
ceb2dbe
9b7d4f4
1efea19
 
8020e39
 
 
aeeecba
63ee5d3
3fa478b
1eb2d22
1efea19
8020e39
1efea19
 
 
ceb2dbe
d8f007f
1efea19
 
 
 
 
 
 
 
 
 
d8f007f
1efea19
d8f007f
 
1efea19
d8f007f
 
1efea19
d8f007f
 
1efea19
 
 
 
 
 
 
8e1d1e9
 
 
 
1efea19
d8f007f
1efea19
 
d8f007f
1efea19
 
 
 
 
 
8020e39
 
4df0539
1db99e2
ab74be1
79357a0
 
 
 
 
1efea19
 
8e1d1e9
 
 
 
1efea19
d8f007f
1efea19
 
 
 
 
2732ef7
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config
import torch
import time
import base64

# Page chrome: browser-tab title, wide layout, and a centred page heading.
st.set_page_config(layout="wide", page_title="LIA Demo")
st.markdown(
    body="<h1 style='text-align: center;'>Ask LeoNardo!</h1>",
    unsafe_allow_html=True,
)

# Helper for embedding a local GIF inline; not called by the live code in this file.
def load_gif_base64(path):
    """Read the file at *path* and return its contents base64-encoded as text.

    The returned string is the payload for a ``data:image/gif;base64,...`` URI.
    """
    with open(path, "rb") as gif_file:
        raw = gif_file.read()
    return base64.b64encode(raw).decode("utf-8")
    
# Reserve two slots near the top of the page: one for the mascot GIF and one
# for the status caption beneath it. Later code swaps their contents in place.
gif_html, caption = st.empty(), st.empty()

# Idle mascot GIF rendered on page load.
gif_html.markdown(
    "<div style='text-align:center;'><img src='https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTRxYzI2bXJmY3N2bXBtMHJtOGV3NW9vZ3l3M3czbGYybGpkeWQ1YSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/3uPWb5EYVvxdfoREQm/giphy.gif' width='300'></div>",
    unsafe_allow_html=True,
)

@st.cache_resource
def load_model(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", trust_remote_code=False):
    """Load the tokenizer and causal language model for *model_id* and cache them.

    ``st.cache_resource`` memoizes the return value across Streamlit reruns,
    so the weights are not reloaded on every widget interaction.

    Args:
        model_id: Hugging Face Hub repository id. Smaller checkpoints tried
            earlier in this demo include
            ``"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"`` and
            ``"deepseek-ai/deepseek-llm-7b-chat"``.
        trust_remote_code: Forwarded to ``from_pretrained``. When True, Python
            code shipped inside the model repository is executed locally.
            Defaults to False because the default checkpoint uses the built-in
            Llama architecture and needs no custom code; enable it only for
            repositories you trust that require it.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # Let transformers/accelerate place the weights on the available
        # device(s) in half precision. For a CPU-only host, use
        # device_map=None with torch_dtype=torch.float32 instead.
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=trust_remote_code,
    )
    return tokenizer, model

# Fetch the cached tokenizer/model pair, then collect the user's prompt.
tokenizer, model = load_model()
prompt = st.text_area("Enter your prompt:", value="What company is Leonardo S.p.A.?")

# NOTE: an example-prompt selector (summary / code / poem / explain / fact)
# and sampling sliders (temperature, max tokens, top-p) were drafted here but
# left disabled; restore them from version control if needed.

# Animated mascot GIF shown while the model is generating.
_THINKING_GIF_HTML = (
    "<div style='text-align:center;'><img src='https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExMXViMm02MnR6bGJ4c2h3ajYzdWNtNXNtYnNic3lnN2xyZzlzbm9seSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/k32ddF9WVs44OUaZAm/giphy.gif' width='300'></div>"
)
# Idle mascot GIF restored once generation ends (same image shown on page load).
_IDLE_GIF_HTML = (
    "<div style='text-align:center;'><img src='https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTRxYzI2bXJmY3N2bXBtMHJtOGV3NW9vZ3l3M3czbGYybGpkeWQ1YSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/3uPWb5EYVvxdfoREQm/giphy.gif' width='300'></div>"
)


def _encode_prompt(user_prompt):
    """Tokenize *user_prompt* onto the model's device.

    When the tokenizer defines a chat template, the prompt is wrapped as a
    single user turn with the assistant generation prompt appended, which is
    the input format chat-tuned checkpoints expect. Otherwise the raw text is
    tokenized as before.
    """
    if getattr(tokenizer, "chat_template", None):
        messages = [{"role": "user", "content": user_prompt}]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered template already carries its own special tokens (e.g. BOS),
        # so don't let the tokenizer prepend them a second time.
        return tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)
    return tokenizer(user_prompt, return_tensors="pt").to(model.device)


def _generate_reply(user_prompt):
    """Sample a continuation for *user_prompt* and return only the new text."""
    inputs = _encode_prompt(user_prompt)
    pad_id = tokenizer.pad_token_id
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            # NOTE(review): 0.6 / 0.95 appear to follow DeepSeek's published
            # sampling advice for R1 models — confirm before changing.
            temperature=0.6,
            top_p=0.95,
            top_k=50,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            # Explicit pad id avoids generate() falling back (with a warning)
            # when the tokenizer defines none.
            pad_token_id=pad_id if pad_id is not None else tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation for a decoder-only model; slice
    # off the prompt tokens so the displayed answer doesn't echo the input.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


def _stream_to_page(text):
    """Reveal *text* word by word behind a typing cursor, then show it in full."""
    output_placeholder = st.empty()
    streamed_text = ""
    for word in text.split(" "):
        streamed_text += word + " "
        output_placeholder.markdown("### ✨ Output:\n\n" + streamed_text + "▌")
        # Slight delay so the text appears to be typed out.
        time.sleep(0.03)
    # Final render without the cursor.
    output_placeholder.markdown("### ✨ Output:\n\n" + streamed_text)


if st.button("Generate"):
    if not prompt.strip():
        st.warning("Please enter a prompt before generating.")
    else:
        gif_html.markdown(_THINKING_GIF_HTML, unsafe_allow_html=True)
        caption.markdown(
            "<p style='text-align: center; margin-top: 20px;'>LeoNardo is thinking... 🌀</p>",
            unsafe_allow_html=True,
        )
        try:
            reply = _generate_reply(prompt)
        finally:
            # Always return the mascot to its idle state, even if generation fails.
            gif_html.markdown(_IDLE_GIF_HTML, unsafe_allow_html=True)
            caption.empty()
        _stream_to_page(reply)