# Author: Monimoy
# Commit: 6fceecb "Update app.py" (verified)
# app.py
import spaces
import os
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import timm
from torchvision import transforms
#from llama_cpp import Llama
from peft import PeftModel
# 1. Model Definitions (Same as in training script)
class SigLIPImageEncoder(torch.nn.Module):
    """Image encoder: a timm backbone followed by a linear projection.

    Despite the class name, the backbone is whatever ``model_name`` selects in
    timm (``'resnet50'`` by default). The backbone is built with
    ``num_classes=0`` and ``global_pool='avg'``, and its pooled features are
    projected to ``embed_dim`` dimensions.

    Args:
        model_name: timm model identifier for the backbone.
        embed_dim: output size of the projection layer.
        pretrained_path: optional path to a saved ``state_dict`` for this whole
            module (backbone + projection). It is loaded onto the CPU; move the
            module to the target device afterwards. If falsy, weights stay at
            their random initialisation.
    """

    def __init__(self, model_name='resnet50', embed_dim=512, pretrained_path=None):
        super().__init__()
        # timm's own pretrained download is disabled: weights, if any, come
        # from `pretrained_path` below.
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0, global_pool='avg')
        self.embed_dim = embed_dim
        self.projection = torch.nn.Linear(self.model.num_features, embed_dim)
        if pretrained_path:
            # weights_only=True restricts unpickling to tensors and plain
            # containers. That is all a state dict needs, and it avoids
            # executing arbitrary pickled code from the checkpoint file.
            state_dict = torch.load(
                pretrained_path,
                map_location=torch.device('cpu'),  # load to CPU first
                weights_only=True,
            )
            self.load_state_dict(state_dict)
            print(f"Loaded SigLIP image encoder from {pretrained_path}")
        else:
            print("Initialized SigLIP image encoder without pretrained weights.")

    def forward(self, image):
        """Encode a batch of preprocessed images into projected embeddings.

        Args:
            image: image tensor accepted by the timm backbone (presumably
                (batch, 3, H, W) normalised pixels; confirm against the caller).

        Returns:
            The projected embeddings, presumably of shape (batch, embed_dim).
        """
        features = self.model(image)
        embedding = self.projection(features)
        return embedding
# 2. Load Models and Tokenizer
# Artifact locations and encoder hyper-parameters.
# `phi3_model_path` is a leftover from the retired llama.cpp (GGUF) loading
# path; nothing else in this file reads it.
phi3_model_path = "QuantFactory/Phi-3-mini-4k-instruct-GGUF"
peft_model_path = "./qlora_phi3_model"  # local directory with the fine-tuned adapter
image_model_name = 'resnet50'
image_embed_dim = 512
siglip_pretrained_path = "image_encoder.pth"  # state dict for SigLIPImageEncoder

# Use the GPU when one is visible, otherwise run on the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(f"Using device: {device}")

# Tokenizer for the same Phi-3 checkpoint the base model is loaded from.
_tokenizer_repo = "microsoft/Phi-3-mini-4k-instruct"
text_tokenizer = AutoTokenizer.from_pretrained(_tokenizer_repo, trust_remote_code=True)
# Reuse EOS as the padding token (the original note flags this as important
# for training, so inference keeps the same setting).
text_tokenizer.pad_token = text_tokenizer.eos_token
# Image preprocessing pipeline: resize to a fixed 224x224, convert to a
# tensor, then normalise each channel with the standard ImageNet statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Build the image encoder from its checkpoint, move it to the selected
# device, and switch it to inference mode (BatchNorm/Dropout stop training
# behaviour).
image_encoder = SigLIPImageEncoder(
    model_name=image_model_name,
    embed_dim=image_embed_dim,
    pretrained_path=siglip_pretrained_path,
)
image_encoder.to(device)  # nn.Module.to moves parameters in place
image_encoder.eval()
# Load the Phi-3 base model through Hugging Face Transformers (the earlier
# llama.cpp / GGUF path is no longer used). device_map="auto" lets
# Transformers place the weights on the available device(s).
base_model_name = "microsoft/Phi-3-mini-4k-instruct"
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float32,
    device_map="auto",
)

# Attach the fine-tuned PEFT adapter, then fold its weights into the base
# model so inference runs on a plain (un-wrapped) causal LM.
# NOTE(review): confirm that this PEFT version honours `offload_dir`; some
# versions read the offload location from `offload_folder` instead.
model = PeftModel.from_pretrained(
    base_model,
    peft_model_path,
    offload_dir='./offload',
).merge_and_unload()
print("phi-3 model loaded sucessfully")
# 3. Inference Function
def _truncate_at_stop(text, stop_sequences):
    """Return *text* cut just before the earliest occurrence of any stop sequence.

    This mirrors the ``stop=`` behaviour of the llama.cpp call this code
    replaced: the stop sequence itself is excluded. If no stop sequence
    occurs, *text* is returned unchanged.
    """
    cut = len(text)
    for stop in stop_sequences:
        position = text.find(stop)
        if position != -1 and position < cut:
            cut = position
    return text[:cut]


def _generate_answer(prompt, max_new_tokens=128, stop_sequences=("Q:", "\n")):
    """Generate a continuation of *prompt* with the merged Phi-3 model.

    Args:
        prompt: full text prompt fed to the model.
        max_new_tokens: upper bound on generated tokens (the Transformers
            equivalent of llama.cpp's ``max_tokens``).
        stop_sequences: strings at which the decoded answer is truncated.

    Returns:
        Only the generated text (the prompt is not echoed back), truncated at
        the first stop sequence and stripped of surrounding whitespace.
    """
    inputs = text_tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=text_tokenizer.pad_token_id,
        )
    # For a decoder-only model, generate() returns prompt + continuation.
    # Dropping the prompt tokens is the equivalent of llama.cpp's echo=False.
    prompt_length = inputs["input_ids"].shape[1]
    generated = text_tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)
    # Strip leading whitespace before truncating, so an initial newline does
    # not cut the answer down to an empty string.
    return _truncate_at_stop(generated.lstrip(), stop_sequences).strip()


@spaces.GPU
def predict(image, question):
    """Answer a free-text question about an uploaded image.

    The image is encoded with ``image_encoder``. The resulting embedding is
    written into a text prompt together with the question, and the merged
    Phi-3 model generates the answer.

    Args:
        image: image array as delivered by ``gr.Image`` (numpy by default),
            or None when nothing was uploaded.
        question: question text, or None / empty.

    Returns:
        One of the following strings:
        - the model's answer;
        - a message asking for both inputs, when either is missing;
        - ``"An error occurred: ..."`` when inference fails.
    """
    # Whitespace-only questions are treated the same as missing ones.
    if image is None or not question or not question.strip():
        return "Please provide both an image and a question."
    try:
        pil_image = Image.fromarray(image).convert("RGB")
        pixel_values = image_transform(pil_image).unsqueeze(0).to(device)
        # Get image embeddings
        with torch.no_grad():
            image_embeddings = image_encoder(pixel_values)
        # The embedding is serialised as a plain Python list inside the prompt.
        embedding_values = image_embeddings.flatten().tolist()
        # NOTE(review): presumably this prompt layout matches the fine-tuning
        # data; confirm. With image_embed_dim=512 the rendered number list is
        # very long and may exceed Phi-3-mini-4k's context window.
        prompt = f"Question: {question}\nImage Embeddings: {embedding_values}\nAnswer:"
        return _generate_answer(prompt, max_new_tokens=128, stop_sequences=("Q:", "\n"))
    except Exception as e:
        # UI boundary: show the failure in the answer box instead of crashing.
        return f"An error occurred: {str(e)}"
# 4. Gradio Interface
# The answer comes from a Hugging Face Transformers Phi-3 model with a merged
# PEFT adapter. It runs on a GPU when one is available (@spaces.GPU,
# device_map="auto"), so the UI text no longer claims llama.cpp or a
# CPU-only setup.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload an Image"),
        gr.Textbox(label="Ask a Question about the Image", placeholder="What is in the image?"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Image Question Answering with Phi-3 and SigLIP",
    description="Ask questions about an image and get answers powered by Phi-3 (fine-tuned with a merged PEFT adapter) and SigLIP.",
    # Example image files are expected next to app.py.
    examples=[
        ["cat_0006.png", "Create a interesting story about this image?"],
        ["bird_0004.png", "Can you describe this image?"],
        ["truck_0003.png", "Elaborate the setting of the image"],
        ["ship_0007.png", "Explain the purpose of image"],
    ],
)
# 5. Launch the App
def _launch_app():
    """Start the Gradio server for the interface defined above."""
    iface.launch()


if __name__ == "__main__":
    _launch_app()