IbrahimHasani's picture
Update app.py
28238e4 verified
raw
history blame
3.69 kB
import gradio as gr
from PIL import ImageFilter, Image
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
import torch
import requests
# Choose the compute device once at import time: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Zero-shot CLIP checkpoint. The processor and model are loaded once, when the
# module is imported, and reused by classify_image for every request.
checkpoint = "openai/clip-vit-large-patch14-336"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint).to(device)
def classify_image(image, candidate_labels):
    """Zero-shot score a blurred image and decide whether the target label applies.

    The first comma-separated label is the target action. The catch-all label
    "other" is always appended, exactly once, at the end. The image is
    Gaussian-blurred (radius 5) and then scored against every label by the
    module-level ``processor``/``model`` on ``device``.

    Args:
        image: PIL image delivered by the Gradio upload widget, or ``None`` when
            nothing was uploaded.
        candidate_labels: Comma-separated labels, e.g.
            ``"human with beverage,human,beverage"``.

    Returns:
        A 4-tuple ``(result, top_label, results_for_df, messages)``:

        - ``result`` is True when one of the decision rules below fires.
        - ``top_label`` is the highest-scoring label.
        - ``results_for_df`` is a list of ``[label, score]`` rows sorted by
          descending score.
        - ``messages`` is a list of strings explaining the decision.

    Raises:
        gr.Error: If no image was uploaded or no non-empty label was given.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Parse the user's labels. Drop blanks (e.g. from a trailing comma) and
    # duplicates, preserving order. Reserve "other" for the catch-all appended
    # below so it appears once, as the last label.
    parsed = (label.strip() for label in (candidate_labels or "").split(","))
    user_labels = [label for label in dict.fromkeys(parsed) if label and label != "other"]
    if not user_labels:
        raise gr.Error("Please enter at least one candidate label.")
    target_label = user_labels[0]
    candidate_labels = user_labels + ["other"]

    # Blur the image before scoring.
    image = image.filter(ImageFilter.GaussianBlur(radius=5))

    inputs = processor(images=image, text=candidate_labels, return_tensors="pt", padding=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # Row 0 holds the single image's logits over all labels; softmax turns them
    # into probabilities that sum to 1 across the candidate labels.
    probs = outputs.logits_per_image[0].softmax(dim=-1).cpu().tolist()

    # Rank labels by descending probability. The sort is stable, so ties keep
    # input order.
    ranked = sorted(zip(candidate_labels, probs), key=lambda pair: pair[1], reverse=True)
    results_for_df = [[label, score] for label, score in ranked]
    # There are always at least two entries: one or more user labels plus "other".
    (top_label, top_score), (second_label, second_score) = ranked[0], ranked[1]
    combined = top_score + second_score

    messages = [
        f"Top label: {top_label} with score: {top_score:.2f}",
        f"Second label: {second_label} with score: {second_score:.2f}",
    ]

    # Decision rules. The runner-up must be a user label, never the "other"
    # catch-all.
    if top_label == target_label and second_label != "other" and top_score >= 0.58:
        messages.append("Triggered the new 0.58 check!")
        result = True
    elif top_label == target_label and second_label != "other" and combined >= 0.90:
        messages.append("Triggered the 90% combined check!")
        result = True
    elif (
        # Only meaningful with a second user label. Without this guard the
        # "second label" would be "other", and with just two labels the
        # combined score is always 1.0, so the rule fired whenever "other" won.
        len(user_labels) > 1
        and top_label == user_labels[1]
        and second_label == target_label
        and combined >= 0.95
    ):
        messages.append("Triggered the 95% reverse order check!")
        result = True
    else:
        result = False

    return result, top_label, results_for_df, messages
# Gradio UI: an image and a comma-separated label string go in; the decision,
# the top label, the per-label score table and the explanation messages come
# out, in the order returned by classify_image.
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Candidate Labels (comma separated)"),
    ],
    outputs=[
        gr.Label(label="Result"),
        gr.Textbox(label="Top Label"),
        gr.Dataframe(headers=["Label", "Score"], label="Details"),
        gr.Textbox(label="Messages"),
    ],
    title="General Action Classifier",
    description="""
**Instructions:**
1. **Upload an Image**: Drag and drop an image or click to upload an image file.
2. **Enter Candidate Labels**:
   - Provide candidate labels separated by commas, with the action you want to detect first.
   - For example: `human with beverage,human,beverage`
   - The label "other" will automatically be added to the list of candidate labels.
3. **View Results**:
   - The result indicates whether the first candidate label (the target action) was detected, according to the decision rules described in the messages.
   - Detailed scores for each label will be displayed in a table.
   - Additional messages explaining the decision process will also be shown.
""",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    iface.launch()