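# NOTE: the commented-out blocks below are earlier iterations of this app
# (OpenCV DNN + MobileNetV2, then MobileNetV2 + Gemini); the active implementation follows them.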
# import cv2
# import numpy as np
# import gradio as gr
# from PIL import Image
# import base64
# import io
# # Load class labels (one per line).
# with open("synset_words.txt", "r") as f:
# classes = [line.strip() for line in f.readlines()]
# # Load the prebuilt MobileNetV2 model in ONNX format.
# net = cv2.dnn.readNetFromONNX("mobilenetv2-7.onnx")
# if net.empty():
# raise ValueError("Could not load the ONNX model. Check your 'mobilenetv2-7.onnx' file.")
# def classify_image(image):
# """
# Processes an input image using MobileNetV2 via OpenCV DNN,
# and returns an explanation string.
# The input can be:
# - A PIL Image (when uploaded from the web)
# - A dictionary with keys "data", "name", and "mime_type" (from the PyQt client)
# """
# # If the input is a dictionary, decode the base64-encoded image.
# if isinstance(image, dict):
# try:
# img_bytes = base64.b64decode(image.get("data", ""))
# pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
# image = pil_img
# except Exception as e:
# return f"Error decoding image: {e}"
# # Otherwise, assume image is already a PIL Image.
# if not isinstance(image, Image.Image):
# return "Invalid image input."
# # Convert the PIL image to a NumPy array.
# image_np = np.array(image)
# if image_np is None or image_np.size == 0:
# return "Invalid image input."
# # Create a blob from the image.
# blob = cv2.dnn.blobFromImage(
# image_np,
# scalefactor=1.0/255,
# size=(224, 224),
# mean=(0.485, 0.456, 0.406),
# swapRB=True,
# crop=False
# )
# blob = blob.astype(np.float32)
# mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
# std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
# blob = (blob - mean) / std
# net.setInput(blob)
# preds = net.forward().flatten()
# top_idx = int(np.argmax(preds))
# prob = preds[top_idx]
# label = classes[top_idx] if top_idx < len(classes) else "Unknown"
# explanation = f"This image is predicted as '{label}' with a confidence of {prob:.2f}."
# return explanation
# iface = gr.Interface(
# fn=classify_image,
# inputs=gr.Image(type="pil", label="Upload Image"), # Using PIL as input
# outputs=gr.Textbox(label="Prediction"),
# title="Image Explanation using MobileNetV2 (ONNX)",
# description=(
# "This API uses OpenCV's DNN module with a prebuilt MobileNetV2 model (ONNX format) "
# "to classify an image and explain what it is about. Upload an image and see the prediction."
# )
# )
# if __name__ == "__main__":
# iface.launch(show_error=True, show_api=True)
# import cv2
# import numpy as np
# import gradio as gr
# from PIL import Image
# import base64
# import io
# import re
# import google.generativeai as genai
# # --- Gemini API functions ---
# GEMINI_API_KEY = "YOUR_GEMINI_API_KEY"  # placeholder: never commit a real key; load it from the environment instead
# def clean_text(text):
# """Remove '**' and special symbols; keep alphanumerics, whitespace, and basic punctuation."""
# text = text.replace("**", "")
# text = re.sub(r'[^\w\s.,?!]', '', text)
# return text
# def generate_answer(question, max_length=100):
# """Generate answer via Gemini API, clean it, and limit its length."""
# genai.configure(api_key=GEMINI_API_KEY)
# model = genai.GenerativeModel('gemini-2.0-flash')
# response = model.generate_content(
# f"{question} make just be straight forward to answer no much explanation unless where needed to learn indepth from the user"
# )
# cleaned_response = clean_text(response.text)
# words = cleaned_response.split()
# if len(words) > max_length:
# cleaned_response = ' '.join(words[:max_length])
# return cleaned_response
# # Load class labels (one per line).
# with open("synset_words.txt", "r") as f:
# classes = [line.strip() for line in f.readlines()]
# # Load the prebuilt MobileNetV2 model in ONNX format.
# net = cv2.dnn.readNetFromONNX("mobilenetv2-7.onnx")
# if net.empty():
# raise ValueError("Could not load the ONNX model. Check your 'mobilenetv2-7.onnx' file.")
# def classify_image(image):
# """
# Processes an input image using MobileNetV2 via OpenCV DNN,
# and returns an explanation string.
# The input can be:
# - A PIL Image (when uploaded from the web)
# - A dictionary with keys "data", "name", and "mime_type" (from the PyQt client)
# """
# # If the input is a dictionary, decode the base64-encoded image.
# if isinstance(image, dict):
# try:
# img_bytes = base64.b64decode(image.get("data", ""))
# pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
# image = pil_img
# except Exception as e:
# return f"Error decoding image: {e}"
# # Otherwise, assume image is already a PIL Image.
# if not isinstance(image, Image.Image):
# return "Invalid image input."
# # Convert the PIL image to a NumPy array.
# image_np = np.array(image)
# if image_np is None or image_np.size == 0:
# return "Invalid image input."
# # Create a blob from the image.
# blob = cv2.dnn.blobFromImage(
# image_np,
# scalefactor=1.0/255,
# size=(224, 224),
# mean=(0.485, 0.456, 0.406),
# swapRB=True,
# crop=False
# )
# blob = blob.astype(np.float32)
# mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
# std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
# blob = (blob - mean) / std
# net.setInput(blob)
# preds = net.forward().flatten()
# top_idx = int(np.argmax(preds))
# prob = preds[top_idx]
# label = classes[top_idx] if top_idx < len(classes) else "Unknown"
# question = f" just tell me about this {label}'"
# return str(generate_answer(question=question))
# iface = gr.Interface(
# fn=classify_image,
# inputs=gr.Image(type="pil", label="Upload Image"), # Using PIL as input
# outputs=gr.Textbox(label="Prediction"),
# title="Image Explanation using MobileNetV2 (ONNX)",
# description=(
# "This API uses OpenCV's DNN module with a prebuilt MobileNetV2 model (ONNX format) "
# "to classify an image and explain what it is about. Upload an image and see the prediction."
# )
# )
# if __name__ == "__main__":
# iface.launch(show_error=True, show_api=True)
# # --- MobileNetV2 classification ---
# # Load class labels (one per line).
# with open("synset_words.txt", "r") as f:
# classes = [line.strip() for line in f.readlines()]
# # Load the prebuilt MobileNetV2 model in ONNX format.
# net = cv2.dnn.readNetFromONNX("mobilenetv2-7.onnx")
# if net.empty():
# raise ValueError("Could not load the ONNX model. Check your 'mobilenetv2-7.onnx' file.")
# def classify_image(pil_img):
# """
# Processes an input image (PIL Image) using MobileNetV2 via OpenCV DNN,
# then refines the explanation using the Gemini API.
# Accepts:
# - A PIL Image (if uploaded via the web)
# - A dictionary (if sent from the PyQt client) with keys "data", "name", and "mime_type".
# """
# # If input is a dictionary, decode it.
# if isinstance(pil_img, dict):
# try:
# img_bytes = base64.b64decode(pil_img.get("data", ""))
# pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
# except Exception as e:
# return f"Error decoding image: {e}"
# # Check that we have a PIL Image.
# if not isinstance(pil_img, Image.Image):
# return "Invalid image input."
# # Convert PIL image to NumPy array.
# image_np = np.array(pil_img)
# if image_np is None or image_np.size == 0:
# return "Invalid image input."
# # Create blob from image.
# blob = cv2.dnn.blobFromImage(
# image_np,
# scalefactor=1.0/255,
# size=(224, 224),
# mean=(0.485, 0.456, 0.406),
# swapRB=True,
# crop=False
# )
# blob = blob.astype(np.float32)
# mean_arr = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
# std_arr = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
# blob = (blob - mean_arr) / std_arr
# net.setInput(blob)
# preds = net.forward().flatten()
# top_idx = int(np.argmax(preds))
# prob = preds[top_idx]
# label = classes[top_idx] if top_idx < len(classes) else "Unknown"
# # Initial explanation from MobileNetV2.
# explanation = f"please explain this image with name {label} "
# # Refine explanation using Gemini API.
# refined_explanation = generate_answer(explanation, max_length=100)
# return refined_explanation
# iface = gr.Interface(
# fn=classify_image,
# inputs=gr.Image(type="pil", label="Upload Image"), # Expects a PIL image
# outputs=gr.Textbox(label="Prediction"),
# title="Image Explanation using MobileNetV2 (ONNX) & Gemini",
# description=(
# "This API uses OpenCV's DNN module with a prebuilt MobileNetV2 model (ONNX format) "
# "to classify an image. The prediction is refined using Gemini API to provide a straightforward answer. "
# "Upload an image and see the result."
# )
# )
# if __name__ == "__main__":
# iface.launch(show_error=True, show_api=True)
"""imports"""
import os
from dotenv import load_dotenv
load_dotenv()
import gradio as gr
from PIL import Image
import re
import base64
import io
import google.generativeai as genai
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
# --- Load tokens from environment ---
HF_TOKEN = os.getenv("ACCESS_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
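# Illustrative .env layout read by load_dotenv() above (placeholder values, not real credentials):
#   ACCESS_TOKEN=hf_xxxxxxxxxxxxxxxx      # Hugging Face token used to download the BLIP-2 weights
#   GEMINI_API_KEY=your-gemini-api-key    # key passed to google.generativeai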
# --- Gemini API functions ---
def clean_text(text):
    """Remove unwanted symbols; keep alphanumerics, whitespace, and basic punctuation."""
    text = text.replace("**", "")
    text = re.sub(r'[^\w\s.,?!]', '', text)
    return text
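# Worked example of clean_text (traced by hand; the input is made up for illustration):
#   clean_text("**Hello** <world>!") -> "Hello world!"
#   ("**" is stripped first, then "<" and ">" are dropped by the character whitelist.)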
def generate_answer(question, max_length=100):
    """Generate an answer via the Gemini API, clean it, and cap its length at max_length words."""
    genai.configure(api_key=GEMINI_API_KEY)
    model = genai.GenerativeModel('gemini-2.0-flash')
    response = model.generate_content(
        f"{question} be concise and direct."
    )
    cleaned_response = clean_text(response.text)
    words = cleaned_response.split()
    if len(words) > max_length:
        cleaned_response = ' '.join(words[:max_length])
    return cleaned_response
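# Illustrative usage (requires a valid GEMINI_API_KEY; the reply text is model-dependent):
#   answer = generate_answer("What is image captioning?")
#   # -> a cleaned, plain-text reply truncated to at most 100 words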
# --- BLIP-2 setup for image captioning ---
print("Loading BLIP-2 model and processor...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl", token=HF_TOKEN)
model_blip2 = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl", token=HF_TOKEN)
print("BLIP-2 loaded.")
def caption_image(pil_img):
    """
    Generate a caption for the image using BLIP-2.
    Expects a PIL image.
    """
    # Optionally resize the image to a lower resolution for faster inference.
    pil_img = pil_img.resize((480, 480))
    inputs = processor(pil_img, return_tensors="pt")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_blip2.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    output_ids = model_blip2.generate(**inputs)
    caption = processor.decode(output_ids[0], skip_special_tokens=True)
    return caption
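# Illustrative usage (the file name is hypothetical and the caption depends on BLIP-2's output):
#   caption = caption_image(Image.open("photo.jpg").convert("RGB"))
#   # -> e.g. a short caption such as "a dog lying on the grass"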
def process_image(image):
    """
    Processes an input image (PIL image) using BLIP-2 to generate a caption,
    cleans the caption to remove internal identifiers, and returns the result.
    (The Gemini refinement step is currently disabled; see the commented lines below.)
    """
    if not isinstance(image, Image.Image):
        return "Invalid image input."
    caption = caption_image(image)
    # Optionally remove internal identifiers (e.g., ImageNet synset codes like "n04285008").
    cleaned_caption = re.sub(r"n\d+", "", caption).strip()
    # question = f"Explain this image based on the caption: {cleaned_caption}."
    # refined_explanation = generate_answer(question=question, max_length=100)
    return cleaned_caption
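# Worked example of the identifier cleanup above (assumed caption, traced by hand):
#   re.sub(r"n\d+", "", "n02099712 golden retriever").strip() -> "golden retriever"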
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Explanation"),
    title="Image Explanation using BLIP-2 & Gemini",
    description=(
        "This API uses BLIP-2 image captioning to describe an uploaded image and return a concise "
        "explanation of its content. It is open source, so developers are welcome to test it and "
        "share feedback to help it integrate smoothly into other systems."
    )
)
if __name__ == "__main__":
    iface.launch(show_error=True, show_api=True)
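# Minimal client-side sketch (assumptions: the Space is already deployed, "<username>/<space-name>"
# is its id, and a recent `gradio_client` is installed; confirm the exact api_name on the Space's
# "Use via API" page):
#
#   from gradio_client import Client, handle_file
#   client = Client("<username>/<space-name>")
#   result = client.predict(handle_file("path/to/image.jpg"), api_name="/predict")
#   print(result)  # the Explanation textbox value, i.e. the cleaned BLIP-2 caption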