import os
import subprocess
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes the
# package fetch a prebuilt wheel instead of compiling its CUDA kernels. The flag
# is merged into the current environment so pip and its tooling stay on PATH.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
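# MAX_INPUT_TOKEN_LENGTH bounds the prompt fed to the model; generate() below
# trims the oldest tokens once a conversation grows past it, and the limit can
# be overridden through the environment variable of the same name.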
# Model initialization (device placement is handled by device_map="auto" below)
model_id = "DeepMount00/Lexora-Lite-3B"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
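# flash_attention_2 assumes a CUDA GPU plus the flash-attn wheel installed at
# startup above; on CPU-only hardware this load would fail unless the
# attn_implementation argument is dropped (or switched to "sdpa").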
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
)
model.eval()
# Custom CSS
CUSTOM_CSS = """
/* Base styles */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Inter', sans-serif;
background-color: #f8fafc;
color: #1e293b;
}
/* Container styles */
.container {
max-width: 1000px !important;
margin: auto !important;
padding: 2rem !important;
}
/* Header styles */
.header-container {
background: linear-gradient(135deg, #1e3a8a 0%, #3b82f6 100%);
padding: 2.5rem;
border-radius: 1rem;
margin-bottom: 2rem;
color: white;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
}
.header-title {
font-size: 2.5rem;
font-weight: 700;
margin-bottom: 1.5rem;
text-align: center;
letter-spacing: -0.025em;
}
/* Model info styles */
.model-info {
background: white;
padding: 1.75rem;
border-radius: 0.75rem;
margin-top: 1.5rem;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.model-info h2 {
font-size: 1.5rem;
font-weight: 600;
color: #1e3a8a;
margin-bottom: 1rem;
}
.model-info p {
color: #374151;
line-height: 1.6;
font-size: 1.1rem;
}
.model-info a {
color: #2563eb;
font-weight: 600;
text-decoration: none;
transition: color 0.2s;
}
.model-info a:hover {
color: #1d4ed8;
text-decoration: underline;
}
/* Chat container styles */
.chat-container {
border: 1px solid #e5e7eb;
border-radius: 1rem;
background: white;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
margin-bottom: 2rem;
}
/* Message styles */
.message {
padding: 1.25rem;
margin: 0.75rem;
border-radius: 0.5rem;
font-size: 1.05rem;
line-height: 1.6;
}
.user-message {
background: #f3f4f6;
border-left: 4px solid #3b82f6;
}
.assistant-message {
background: #dbeafe;
border-left: 4px solid #1d4ed8;
}
/* Controls container styles */
.controls-container {
background: #f8fafc;
padding: 1.75rem;
border-radius: 0.75rem;
margin-top: 1.5rem;
border: 1px solid #e5e7eb;
}
/* Slider styles */
.slider-label {
font-weight: 600;
color: #374151;
margin-bottom: 0.5rem;
font-size: 1.05rem;
}
/* Button styles */
.duplicate-button {
background: #2563eb !important;
color: white !important;
padding: 0.875rem 1.75rem !important;
border-radius: 0.5rem !important;
font-weight: 600 !important;
font-size: 1.05rem !important;
transition: all 0.2s !important;
border: none !important;
cursor: pointer !important;
display: inline-flex !important;
align-items: center !important;
justify-content: center !important;
text-align: center !important;
box-shadow: 0 2px 4px rgba(37, 99, 235, 0.2) !important;
}
.duplicate-button:hover {
background: #1d4ed8 !important;
transform: translateY(-1px) !important;
box-shadow: 0 4px 6px rgba(37, 99, 235, 0.3) !important;
}
/* Input field styles */
.input-textarea {
border: 2px solid #e5e7eb !important;
border-radius: 0.5rem !important;
padding: 1rem !important;
font-size: 1.05rem !important;
transition: border-color 0.2s !important;
}
.input-textarea:focus {
border-color: #3b82f6 !important;
outline: none !important;
box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1) !important;
}
/* System message styles */
.system-message {
background: #f8fafc;
border: 1px solid #e5e7eb;
border-radius: 0.5rem;
padding: 1rem;
margin-bottom: 1rem;
}
"""
# HTML Description
DESCRIPTION = '''
<div class="header-container">
<h1 class="header-title">Lexora-Lite-3B</h1>
<div class="model-info">
<h2>About the Model</h2>
<p>
            Welcome to the demonstration of <a href="https://huggingface.co/DeepMount00/Lexora-Lite-3B">Lexora-Lite-3B Chat ITA</a>,
            currently the leading open-source large language model for Italian. It pairs strong Italian-language
            understanding with the efficiency of a 3-billion-parameter model.
</p>
<p style="margin-top: 1rem;">
View its performance metrics and compare it with other models on the
<a href="https://huggingface.co/spaces/FinancialSupport/open_ita_llm_leaderboard">official leaderboard</a>.
</p>
</div>
</div>
'''
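# @spaces.GPU reserves a ZeroGPU device for each call on Hugging Face Spaces;
# duration=90 tells the scheduler the call should finish within 90 seconds.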
@spaces.GPU(duration=90)
def generate(
message: str,
chat_history: list[tuple[str, str]],
system_message: str = "",
max_new_tokens: int = 2048,
temperature: float = 0.0001,
top_p: float = 1.0,
top_k: int = 50,
repetition_penalty: float = 1.0,
) -> Iterator[str]:
conversation = [{"role": "system", "content": system_message}]
for user, assistant in chat_history:
conversation.extend(
[
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
)
conversation.append({"role": "user", "content": message})
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Drop the oldest tokens so the prompt fits the input budget.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        # do_sample=True with temperature == 0 is rejected by transformers, so
        # clamp to a tiny positive value, which is effectively greedy decoding.
        do_sample=True,
        temperature=max(temperature, 1e-4),
        top_p=top_p,
        top_k=top_k,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so this generator can consume the
    # streamer and yield partial completions to the UI as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
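
# Note: if no new token arrives within the streamer's 20-second timeout, the
# loop above raises queue.Empty and the request fails; a longer timeout may be
# needed if cold starts delay the first token.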
def create_chat_interface():
theme = gr.themes.Soft(
primary_hue="blue",
secondary_hue="blue",
neutral_hue="slate",
font=gr.themes.GoogleFont("Inter"),
radius_size=gr.themes.sizes.radius_sm,
)
with gr.Blocks(css=CUSTOM_CSS, theme=theme) as demo:
with gr.Column(elem_classes="container"):
gr.Markdown(DESCRIPTION)
with gr.Column(elem_classes="chat-container"):
additional_inputs = [
gr.Textbox(
value="",
label="System Message",
elem_classes="system-message",
render=False,
),
]
            # Build the sampling controls in their own styled column; the same
            # components are passed to the chat interface as additional inputs.
controls = gr.Column(elem_classes="controls-container")
with controls:
additional_inputs.extend([
gr.Slider(
label="Maximum New Tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
elem_classes="slider-label",
),
                    # Keep the minimum above zero: transformers rejects sampling
                    # with temperature == 0 (generate() also clamps defensively).
                    gr.Slider(
                        label="Temperature",
                        minimum=0.001,
                        maximum=4.0,
                        step=0.001,
                        value=0.001,
                        elem_classes="slider-label",
                    ),
gr.Slider(
label="Top-p (Nucleus Sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=1.0,
elem_classes="slider-label",
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
elem_classes="slider-label",
),
gr.Slider(
label="Repetition Penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.0,
elem_classes="slider-label",
),
])
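            # ChatInterface maps additional_inputs positionally onto the
            # parameters of generate() after (message, chat_history), so the
            # order of the list above must mirror the function signature.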
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=additional_inputs,
examples=[
["Ciao! Come stai?"],
["Raccontami una breve storia."],
["Qual è il tuo piatto italiano preferito?"],
],
cache_examples=False,
)
gr.DuplicateButton(
value="Duplicate Space for Private Use",
elem_classes="duplicate-button",
elem_id="duplicate-button",
)
return demo
if __name__ == "__main__":
demo = create_chat_interface()
demo.queue(max_size=20).launch()