Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import LlavaForConditionalGeneration, LlavaProcessor | |
from PIL import Image | |
import torch | |
import io | |
import warnings | |
warnings.filterwarnings("ignore") | |
def safe_convert_image(img): | |
if isinstance(img, Image.Image): | |
return img | |
elif isinstance(img, bytes): | |
return Image.open(io.BytesIO(img)) | |
elif hasattr(img, "read"): | |
return Image.open(img) | |
else: | |
raise ValueError("Format d'image non pris en charge.") | |
# Chargement du modèle et du processeur | |
model_id = "llava-hf/llava-1.5-7b-hf" | |
processor = LlavaProcessor.from_pretrained(model_id) | |
model = LlavaForConditionalGeneration.from_pretrained( | |
model_id, | |
torch_dtype=torch.float16, | |
device_map="auto" | |
) | |
# Fonction de VQA | |
def vqa_llava(image, question): | |
try: | |
image = safe_convert_image(image) | |
# ✅ Prompt spécifique à LLaVA | |
prompt = f"<image>\nUSER: {question.strip()}\nASSISTANT:" | |
# Préparation des entrées | |
inputs = processor( | |
text=prompt, | |
images=image, | |
return_tensors="pt" | |
).to(model.device) | |
# Génération de la réponse | |
generate_ids = model.generate(**inputs, max_new_tokens=100) | |
response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0] | |
return response.replace(prompt, "").strip() | |
except Exception as e: | |
return f"❌ Erreur : {str(e)}" | |
# Interface Gradio | |
interface = gr.Interface( | |
fn=vqa_llava, | |
inputs=[ | |
gr.Image(type="pil", label="🖼️ Image"), | |
gr.Textbox(lines=2, label="❓ Question (en anglais)") | |
], | |
outputs=gr.Textbox(label="💬 Réponse"), | |
title="🔎 Visual Question Answering avec LLaVA", | |
description="Téléverse une image et pose une question visuelle (en anglais). Le modèle LLaVA-1.5-7B y répondra." | |
) | |
if __name__ == "__main__": | |
interface.launch() | |