import gradio as grad
import torch
import spaces
from unsloth import FastLanguageModel
from transformers import AutoTokenizer
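
# Hugging Face ZeroGPU Spaces expect at least one function decorated with
# @spaces.GPU; this placeholder is assumed to exist only to satisfy that check.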
@spaces.GPU
def dummy():  # just a dummy
    pass

# Load your fine-tuned Phi-3 model from Hugging Face
MODEL_NAME = "jcrissa/phi3-new-t2i"
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cuda"

def load_phi3_model():
    try:
        # Load the fine-tuned Phi-3 model and tokenizer with Unsloth
        model, tokenizer = FastLanguageModel.from_pretrained(
            MODEL_NAME,
            max_seq_length=4096,  # Must match the sequence length used during fine-tuning
            dtype=torch.float16 if device == "cuda" else torch.float32,  # float16 on GPU, float32 on CPU
        )
        model.to(device)
        # Switch the model into Unsloth's optimized inference mode
        model = FastLanguageModel.for_inference(model)
        # Configure the tokenizer for left-padded causal generation
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None

# Load the model and tokenizer once at startup; fail fast if loading did not succeed
phi3_model, phi3_tokenizer = load_phi3_model()
if phi3_model is None or phi3_tokenizer is None:
    raise RuntimeError(
        "Model and tokenizer could not be loaded. "
        "Please check the Hugging Face model path or network connection."
    )

# Generate text with the fine-tuned Phi-3 model
def generate(plain_text):
    try:
        # Tokenize the input text and move it to the target device
        input_ids = phi3_tokenizer(plain_text.strip(), return_tensors="pt").input_ids.to(device)
        eos_id = phi3_tokenizer.eos_token_id
        # Generate the output using sampling instead of beam search
        outputs = phi3_model.generate(
            input_ids,
            do_sample=True,
            max_new_tokens=75,
            num_return_sequences=1,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            length_penalty=-1.0,  # Only affects beam-based generation; ignored when sampling
        )
        # Decode and return the generated text
        output_text = phi3_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return output_text.strip()
    except Exception as e:
        return f"Error during text generation: {e}"

# Set up the Gradio interface
txt = grad.Textbox(lines=1, label="Input Text", placeholder="Enter your prompt")
out = grad.Textbox(lines=1, label="Generated Text")

grad.Interface(
    fn=generate,
    inputs=txt,
    outputs=out,
    title="Fine-Tuned Phi-3 Model",
    description="This demo uses a fine-tuned Phi-3 model to optimize text prompts.",
    flagging_mode="never",  # `flagging_mode` replaces the older `allow_flagging` argument
    cache_examples=False,
    theme="default",
).launch(share=True)