# import cv2
# import numpy as np
# import gradio as gr
# from PIL import Image
# import base64
# import io
# # Load class labels (one per line).
# with open("synset_words.txt", "r") as f:
#     classes = [line.strip() for line in f.readlines()]
# # Load the prebuilt MobileNetV2 model in ONNX format.
# net = cv2.dnn.readNetFromONNX("mobilenetv2-7.onnx")
# if net.empty():
#     raise ValueError("Could not load the ONNX model. Check your 'mobilenetv2-7.onnx' file.")
# def classify_image(image):
#     """
#     Processes an input image using MobileNetV2 via OpenCV DNN,
#     and returns an explanation string.
#     The input can be:
#       - A PIL Image (when uploaded from the web)
#       - A dictionary with keys "data", "name", and "mime_type" (from the PyQt client)
#     """
#     # If the input is a dictionary, decode the base64-encoded image.
#     if isinstance(image, dict):
#         try:
#             img_bytes = base64.b64decode(image.get("data", ""))
#             pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
#             image = pil_img
#         except Exception as e:
#             return f"Error decoding image: {e}"
#     # Otherwise, assume image is already a PIL Image.
#     if not isinstance(image, Image.Image):
#         return "Invalid image input."
#     # Convert the PIL image to a NumPy array.
#     image_np = np.array(image)
#     if image_np is None or image_np.size == 0:
#         return "Invalid image input."
#     # Create a blob from the image.
#     blob = cv2.dnn.blobFromImage(
#         image_np,
#         scalefactor=1.0/255,
#         size=(224, 224),
#         mean=(0, 0, 0),  # ImageNet mean/std normalization is applied manually below.
#         swapRB=True,
#         crop=False
#     )
#     blob = blob.astype(np.float32)
#     mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
#     std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
#     blob = (blob - mean) / std
#     net.setInput(blob)
#     preds = net.forward().flatten()
#     top_idx = int(np.argmax(preds))
#     prob = preds[top_idx]
#     label = classes[top_idx] if top_idx < len(classes) else "Unknown"
#     explanation = f"This image is predicted as '{label}' with a raw score of {prob:.2f}."
#     return explanation
# iface = gr.Interface(
#     fn=classify_image,
#     inputs=gr.Image(type="pil", label="Upload Image"),  # Using PIL as input
#     outputs=gr.Textbox(label="Prediction"),
#     title="Image Explanation using MobileNetV2 (ONNX)",
#     description=(
#         "This API uses OpenCV's DNN module with a prebuilt MobileNetV2 model (ONNX format) "
#         "to classify an image and explain what it is about. Upload an image and see the prediction."
#     )
# )
# if __name__ == "__main__":
#     iface.launch(show_error=True, show_api=True)
# import cv2
# import numpy as np
# import gradio as gr
# from PIL import Image
# import base64
# import io
# import re
# import google.generativeai as genai
# # --- Gemini API functions ---
# GEMINI_API_KEY = "YOUR_GEMINI_API_KEY"  # Gemini API key (do not commit a real key; load it from the environment instead).
# def clean_text(text):
#     """Remove '**' and special symbols; keep alphanumerics, whitespace, and basic punctuation."""
#     text = text.replace("**", "")
#     text = re.sub(r'[^\w\s.,?!]', '', text)
#     return text
# def generate_answer(question, max_length=100):
#     """Generate answer via Gemini API, clean it, and limit its length."""
#     genai.configure(api_key=GEMINI_API_KEY)
#     model = genai.GenerativeModel('gemini-2.0-flash')
#     response = model.generate_content(
#         f"{question} Be straightforward and concise; go in depth only where it helps the user learn."
#     )
#     cleaned_response = clean_text(response.text)
#     words = cleaned_response.split()
#     if len(words) > max_length:
#         cleaned_response = ' '.join(words[:max_length])
#     return cleaned_response
# # Load class labels (one per line).
# with open("synset_words.txt", "r") as f:
#     classes = [line.strip() for line in f.readlines()]
# # Load the prebuilt MobileNetV2 model in ONNX format.
# net = cv2.dnn.readNetFromONNX("mobilenetv2-7.onnx")
# if net.empty():
#     raise ValueError("Could not load the ONNX model. Check your 'mobilenetv2-7.onnx' file.")
# def classify_image(image):
#     """
#     Processes an input image using MobileNetV2 via OpenCV DNN,
#     and returns an explanation string.
#     The input can be:
#       - A PIL Image (when uploaded from the web)
#       - A dictionary with keys "data", "name", and "mime_type" (from the PyQt client)
#     """
#     # If the input is a dictionary, decode the base64-encoded image.
#     if isinstance(image, dict):
#         try:
#             img_bytes = base64.b64decode(image.get("data", ""))
#             pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
#             image = pil_img
#         except Exception as e:
#             return f"Error decoding image: {e}"
#     # Otherwise, assume image is already a PIL Image.
#     if not isinstance(image, Image.Image):
#         return "Invalid image input."
#     # Convert the PIL image to a NumPy array.
#     image_np = np.array(image)
#     if image_np is None or image_np.size == 0:
#         return "Invalid image input."
#     # Create a blob from the image.
#     blob = cv2.dnn.blobFromImage(
#         image_np,
#         scalefactor=1.0/255,
#         size=(224, 224),
#         mean=(0, 0, 0),  # ImageNet mean/std normalization is applied manually below.
#         swapRB=True,
#         crop=False
#     )
#     blob = blob.astype(np.float32)
#     mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
#     std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
#     blob = (blob - mean) / std
#     net.setInput(blob)
#     preds = net.forward().flatten()
#     top_idx = int(np.argmax(preds))
#     prob = preds[top_idx]
#     label = classes[top_idx] if top_idx < len(classes) else "Unknown"
#     question = f"Just tell me about this {label}."
#     return str(generate_answer(question=question))
# iface = gr.Interface(
#     fn=classify_image,
#     inputs=gr.Image(type="pil", label="Upload Image"),  # Using PIL as input
#     outputs=gr.Textbox(label="Prediction"),
#     title="Image Explanation using MobileNetV2 (ONNX)",
#     description=(
#         "This API uses OpenCV's DNN module with a prebuilt MobileNetV2 model (ONNX format) "
#         "to classify an image and explain what it is about. Upload an image and see the prediction."
#     )
# )
# if __name__ == "__main__":
#     iface.launch(show_error=True, show_api=True)
# # --- MobileNetV2 classification ---
# # Load class labels (one per line).
# with open("synset_words.txt", "r") as f:
#     classes = [line.strip() for line in f.readlines()]
# # Load the prebuilt MobileNetV2 model in ONNX format.
# net = cv2.dnn.readNetFromONNX("mobilenetv2-7.onnx")
# if net.empty():
#     raise ValueError("Could not load the ONNX model. Check your 'mobilenetv2-7.onnx' file.")
# def classify_image(pil_img):
#     """
#     Processes an input image (PIL Image) using MobileNetV2 via OpenCV DNN,
#     then refines the explanation using the Gemini API.
#     Accepts:
#       - A PIL Image (if uploaded via the web)
#       - A dictionary (if sent from the PyQt client) with keys "data", "name", and "mime_type".
#     """
#     # If input is a dictionary, decode it.
#     if isinstance(pil_img, dict):
#         try:
#             img_bytes = base64.b64decode(pil_img.get("data", ""))
#             pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
#         except Exception as e:
#             return f"Error decoding image: {e}"
#     # Check that we have a PIL Image.
#     if not isinstance(pil_img, Image.Image):
#         return "Invalid image input."
#     # Convert PIL image to NumPy array.
#     image_np = np.array(pil_img)
#     if image_np is None or image_np.size == 0:
#         return "Invalid image input."
#     # Create blob from image.
#     blob = cv2.dnn.blobFromImage(
#         image_np,
#         scalefactor=1.0/255,
#         size=(224, 224),
#         mean=(0, 0, 0),  # ImageNet mean/std normalization is applied manually below.
#         swapRB=True,
#         crop=False
#     )
#     blob = blob.astype(np.float32)
#     mean_arr = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
#     std_arr = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
#     blob = (blob - mean_arr) / std_arr
#     net.setInput(blob)
#     preds = net.forward().flatten()
#     top_idx = int(np.argmax(preds))
#     prob = preds[top_idx]
#     label = classes[top_idx] if top_idx < len(classes) else "Unknown"
#     # Initial explanation from MobileNetV2.
#     explanation = f"Please explain this image, which shows a {label}."
#     # Refine explanation using Gemini API.
#     refined_explanation = generate_answer(explanation, max_length=100)
#     return refined_explanation
# iface = gr.Interface(
#     fn=classify_image,
#     inputs=gr.Image(type="pil", label="Upload Image"),  # Expects a PIL image
#     outputs=gr.Textbox(label="Prediction"),
#     title="Image Explanation using MobileNetV2 (ONNX) & Gemini",
#     description=(
#         "This API uses OpenCV's DNN module with a prebuilt MobileNetV2 model (ONNX format) "
#         "to classify an image. The prediction is refined using the Gemini API to provide a straightforward answer. "
#         "Upload an image and see the result."
#     )
# )
# if __name__ == "__main__":
#     iface.launch(show_error=True, show_api=True)

# --- Current implementation: BLIP‑2 image captioning served with Gradio ---
"""Gradio app that generates a short explanation of an uploaded image using BLIP‑2
captioning, with helper functions for optional Gemini-based refinement."""
import os
from dotenv import load_dotenv
load_dotenv()
import gradio as gr
from PIL import Image
import re
import base64
import io
import google.generativeai as genai
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

# --- Load tokens from environment ---
HF_TOKEN = os.getenv("ACCESS_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# --- Gemini API functions ---
def clean_text(text):
    """Remove unwanted symbols; keep alphanumerics, whitespace, and basic punctuation."""
    text = text.replace("**", "")
    text = re.sub(r'[^\w\s.,?!]', '', text)
    return text
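
# Illustrative example (not from the original code) of what clean_text keeps and drops:
#
#   >>> clean_text("**Bold** text: keep words, drop symbols #1!")
#   'Bold text keep words, drop symbols 1!'
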
def generate_answer(question, max_length=100):
    """Generate answer via Gemini API, clean it, and limit its length."""
    genai.configure(api_key=GEMINI_API_KEY)
    model = genai.GenerativeModel('gemini-2.0-flash')
    response = model.generate_content(
        f"{question} be concise and direct."
    )
    cleaned_response = clean_text(response.text)
    words = cleaned_response.split()
    if len(words) > max_length:
        cleaned_response = ' '.join(words[:max_length])
    return cleaned_response
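
# A minimal usage sketch (hypothetical prompt text; requires GEMINI_API_KEY to be set in the
# environment / .env file):
#
#   answer = generate_answer("What does a golden retriever look like?", max_length=50)
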
# --- BLIP‑2 setup for image captioning ---
print("Loading BLIP‑2 model and processor...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl", token=HF_TOKEN)
model_blip2 = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl", token=HF_TOKEN)
print("BLIP‑2 loaded.")

def caption_image(pil_img):
    """
    Generate a caption for the image using BLIP‑2.
    Expects a PIL image.
    """
    # Optionally, resize the image to lower resolution for faster inference.
    pil_img = pil_img.resize((480, 480))
    inputs = processor(pil_img, return_tensors="pt")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_blip2.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    output_ids = model_blip2.generate(**inputs)
    caption = processor.decode(output_ids[0], skip_special_tokens=True)
    return caption
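
# A quick local sanity check (hypothetical file name "example.jpg"; runs one BLIP‑2 forward pass):
#
#   print(caption_image(Image.open("example.jpg").convert("RGB")))
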
def process_image(image):
    """
    Processes an input image (PIL image) using BLIP‑2 to generate a caption and
    cleans the caption to remove internal identifiers. A Gemini-based refinement
    step via generate_answer is available but currently commented out, so the
    cleaned caption is returned directly.
    """
    if not isinstance(image, Image.Image):
        return "Invalid image input."
    caption = caption_image(image)
    # Optionally, remove internal identifiers (e.g., codes like "n04285008").
    cleaned_caption = re.sub(r"n\d+", "", caption).strip()
    # question = f"Explain this image based on the caption: {cleaned_caption}."
    # refined_explanation = generate_answer(question=question, max_length=100)
    return cleaned_caption
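
# Illustrative example (hypothetical caption) of the identifier cleanup above:
#
#   >>> re.sub(r"n\d+", "", "a n04285008 sports car on the road").strip()
#   'a  sports car on the road'
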
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Explanation"),
    title="Image Explanation using BLIP‑2 & Gemini",
    description=(
        "This API uses BLIP-2 image captioning to describe what an uploaded image contains. "
        "Upload an image to receive a concise explanation of its content. The API is open source, "
        "so developers are welcome to test it, integrate it into their own systems, and share feedback."
    )
)

if __name__ == "__main__":
    iface.launch(show_error=True, show_api=True)
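
# Client-side sketch (not part of this app) for calling the running Gradio server, assuming a
# recent `gradio_client` package and the default local address; "example.jpg" is a placeholder:
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict(handle_file("example.jpg"), api_name="/predict"))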