import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
import torch
import os
from threading import Thread

hf_token = os.getenv("HF_TOKEN")  # Set HF_TOKEN in your environment (e.g. as a Space secret)
# Load model and tokenizer
print("Loading model and tokenizer...")
model_path = "microsoft/Phi-4-mini-instruct"  # Can be changed to a local path, e.g. "./Phi-4-Mini-Instruct"
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side="left",
    token=hf_token,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
    token=hf_token,
    trust_remote_code=True
)
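# Note: attn_implementation="flash_attention_2" assumes the `flash-attn` package is
# installed and a supported CUDA GPU is available. If either is missing, drop the
# argument (or use attn_implementation="sdpa") to fall back to PyTorch's built-in attention.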
# Create pipeline for easier inference
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
print("Model and tokenizer loaded successfully!")
# Convert the Gradio chat history into the messages format expected by the chat template
def format_chat_history(message, history):
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."}
    ]
    # Add prior turns from the chat history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    # Add the current user message
    messages.append({"role": "user", "content": message})
    return messages
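# For example (illustrative values), format_chat_history("What can you do?", [["Hi", "Hello!"]])
# returns:
#   [{"role": "system", "content": "You are a helpful AI assistant."},
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "What can you do?"}]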
# Streaming response generator
def predict(message, history):
    messages = format_chat_history(message, history)
    # TextIteratorStreamer lets us consume decoded tokens as they are generated
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_args = {
        "max_new_tokens": 1024,
        "return_full_text": False,
        "temperature": 0.001,  # near-greedy decoding
        "top_p": 1.0,
        "do_sample": True,
        "streamer": streamer,
    }
    # Run generation in a background thread so this generator can stream the output
    thread = Thread(target=pipe, args=(messages,), kwargs=generation_args)
    thread.start()
    # Stream the partial response, updating the latest chat turn as tokens arrive
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield history + [[message, partial_message]]
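# pipe(...) blocks until generation completes, which is why it runs on a separate
# thread: the streamer then yields decoded text chunks in this thread, and Gradio
# re-renders the chatbot with each yielded history update.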
# Create the Gradio interface
css = """
.chatbot-container {max-width: 800px; margin: auto;}
.chat-header {text-align: center; margin-bottom: 20px;}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML("<div class='chat-header'><h1>Phi-4 Mini Chatbot</h1></div>")
    with gr.Column(elem_classes="chatbot-container"):
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(placeholder="Type your message here...", label="Input")
        clear = gr.Button("Clear Conversation")
    # Submit on Enter, then clear the input box
    msg.submit(predict, [msg, chatbot], [chatbot], queue=True, api_name="chat").then(
        lambda: "", None, [msg]
    )
    clear.click(lambda: None, None, chatbot, queue=False)
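# Note: this uses gr.Chatbot's default pair-based history format (a list of
# [user, assistant] pairs), which matches what predict() yields. Newer Gradio
# releases may emit a deprecation warning suggesting type="messages" instead.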
""") | |
# Launch the app
demo.launch(share=True)  # Set share=False if you don't want a public link
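# When hosted on Hugging Face Spaces the app is already publicly served, so
# share=True has no effect there (Gradio skips the share tunnel on Spaces);
# it mainly matters when running locally.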