IbrahimHasani's picture
Update app.py
e7acfb9 verified
raw
history blame contribute delete
No virus
4.2 kB
import gradio as gr
from PIL import ImageFilter, Image
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
import torch
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the CLIP checkpoint: the processor prepares image/text inputs and the
# zero-shot classification model is moved onto the selected device.
checkpoint = "openai/clip-vit-large-patch14-336"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint).to(device)
def classify_image(image, candidate_labels):
    """Score a blurred copy of ``image`` against comma-separated labels.

    The user labels are parsed from ``candidate_labels``, the fallback label
    "other" is appended, and the module-level ``processor``/``model`` score
    the Gaussian-blurred image against every label. A small rule set then
    decides whether the first user label (the "primary" label) is accepted.

    Decision rules (first match wins, otherwise the result is ``False``):
      1. Primary label is on top with score >= 0.58 and the runner-up is not
         "other".
      2. Primary label is on top, the runner-up is another user label, and
         their combined score is >= 0.90.
      3. The second user label is on top, the primary label is the runner-up,
         and their combined score is >= 0.95.

    Args:
        image: PIL image to classify.
        candidate_labels: Comma-separated labels. Surrounding whitespace,
            empty entries, duplicates and a user-typed "other" are ignored;
            "other" is always appended as the last label.

    Returns:
        A 4-tuple ``(result, top_label, results_for_df, messages)`` where
        ``result`` is the boolean decision, ``top_label`` is the
        highest-scoring label, ``results_for_df`` is a list of
        ``[label, score]`` rows sorted by descending score, and ``messages``
        is a list of strings describing how the decision was reached.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No image provided; upload an image before classifying.")

    fallback_label = "other"
    primary_threshold = 0.58
    combined_threshold = 0.90
    reverse_threshold = 0.95

    # Parse the user labels. De-duplicate while keeping order, because a label
    # sent twice splits the softmax mass between two identical prompts. A
    # user-typed "other" is dropped since the fallback is appended below.
    stripped = (label.strip() for label in (candidate_labels or "").split(","))
    user_labels = list(dict.fromkeys(
        label for label in stripped if label and label != fallback_label
    ))
    candidate_labels = user_labels + [fallback_label]

    messages = []
    if not user_labels:
        messages.append('No candidate labels provided; only "other" was scored.')

    # Blur the image before scoring it.
    blurred = image.filter(ImageFilter.GaussianBlur(radius=5))

    # Encode the image and labels, then move every tensor to the model device.
    inputs = processor(images=blurred, text=candidate_labels, return_tensors="pt", padding=True)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    # One probability per label for the single input image.
    probs = outputs.logits_per_image[0].softmax(dim=-1).cpu().tolist()

    # Highest score first; the sort is stable, so ties keep label order.
    ranked = sorted(zip(candidate_labels, probs), key=lambda pair: pair[1], reverse=True)
    results_for_df = [[label, score] for label, score in ranked]

    top_label, top_score = ranked[0]
    messages.append(f"Top label: {top_label} with score: {top_score:.2f}")
    if len(ranked) > 1:
        second_label, second_score = ranked[1]
        messages.append(f"Second label: {second_label} with score: {second_score:.2f}")
    else:
        # Only the fallback label was scored; there is no runner-up.
        second_label, second_score = None, 0.0

    primary = user_labels[0] if user_labels else None
    # The reverse-order rule needs a real second user label. The fallback
    # "other" must never stand in for it, or a winning "other" would be
    # reported as a positive result.
    secondary = user_labels[1] if len(user_labels) > 1 else None
    combined_score = top_score + second_score

    # NOTE(review): with a single user label the runner-up is always "other",
    # so rules 1 and 2 cannot fire and the result is always False. Confirm
    # whether a standalone threshold is wanted for that case.
    if top_label == primary and top_score >= primary_threshold and second_label != fallback_label:
        messages.append("Triggered the new 0.58 check!")
        result = True
    elif top_label == primary and second_label in user_labels and combined_score >= combined_threshold:
        messages.append("Triggered the 90% combined check!")
        result = True
    elif (
        secondary is not None
        and top_label == secondary
        and second_label == primary
        and combined_score >= reverse_threshold
    ):
        messages.append("Triggered the 95% reverse order check!")
        result = True
    else:
        result = False

    return result, top_label, results_for_df, messages
# Pre-filled example so the demo produces a result without any user input.
default_image_path = "F50xXeBbcAA0IIx.jpeg"
default_labels = "human with beverage,human,beverage"
default_image = Image.open(default_image_path)

# Input widgets: the picture to classify and the comma-separated labels.
_input_components = [
    gr.Image(type="pil", label="Upload an Image", value=default_image),
    gr.Textbox(label="Candidate Labels (comma separated)", value=default_labels),
]

# Output widgets, in the same order as the tuple returned by classify_image.
_output_components = [
    gr.Label(label="Result"),
    gr.Textbox(label="Top Label"),
    gr.Dataframe(headers=["Label", "Score"], label="Details"),
    gr.Textbox(label="Messages"),
]

# Markdown help text rendered under the title.
_instructions = """
**Instructions:**
1. **Upload an Image**: Drag and drop an image or click to upload an image file. A default image is provided.
2. **Enter Candidate Labels**:
- Provide candidate labels separated by commas.
- For example: `human with beverage,human,beverage`
- The label "other" will automatically be added to the list of candidate labels.
- You can enter just one label, and "other" will still be added automatically. Default labels are provided.
3. **View Results**:
- The result will indicate whether the specified action (top label) is present in the image.
- Detailed scores for each label will be displayed in a table.
- Additional messages explaining the decision process will also be shown.
"""

iface = gr.Interface(
    fn=classify_image,
    inputs=_input_components,
    outputs=_output_components,
    title="General Action Classifier",
    description=_instructions,
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()