Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import ImageFilter, Image | |
| from transformers import AutoModelForZeroShotImageClassification, AutoProcessor | |
| import torch | |
# Pick the GPU when one is available, otherwise fall back to the CPU, then
# load the CLIP checkpoint ("clip-vit-large-patch14-336") and its processor
# once at import time so every request reuses them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = "openai/clip-vit-large-patch14-336"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForZeroShotImageClassification.from_pretrained(checkpoint).to(device)
def classify_image(image, candidate_labels):
    """Blur *image* and score it against comma-separated labels with CLIP.

    The first user-supplied label is the *target* action.  The catch-all
    label "other" is always appended (exactly once) so the model can reject
    every user label.

    Args:
        image: PIL image from the upload widget, or ``None`` when it is empty.
        candidate_labels: Comma-separated label string, e.g.
            ``"human with beverage,human,beverage"``.  Blank entries are
            ignored; duplicates and a user-typed "other" are dropped.

    Returns:
        ``(result, top_label, results_for_df, messages)``:

        - ``result`` is True when one of the decision rules fires for the
          target label.
        - ``top_label`` is the highest-scoring label.
        - ``results_for_df`` is a ``[[label, score], ...]`` list sorted by
          descending score.
        - ``messages`` is a list of explanation strings.

        When the image or the labels are missing, returns
        ``(False, "", [], [reason])`` without running the model.
    """
    messages = []

    # Parse labels first: it is cheap and lets us bail out before touching the
    # model.  dict.fromkeys de-duplicates while keeping first-seen order;
    # identical prompts would otherwise split probability mass between them.
    stripped = (label.strip() for label in (candidate_labels or "").split(","))
    user_labels = list(dict.fromkeys(
        label for label in stripped if label and label != "other"
    ))

    if image is None:
        messages.append("Please upload an image.")
        return False, "", [], messages
    if not user_labels:
        # Previously the list became ["other", "other"] here and the
        # combined-score rule reported a positive result for empty input.
        messages.append("Please enter at least one candidate label.")
        return False, "", [], messages

    candidate_labels = user_labels + ["other"]
    target_label = user_labels[0]

    # Blur the image (Gaussian, radius 5) before classification.
    image = image.filter(ImageFilter.GaussianBlur(radius=5))

    # Encode the image and text prompts, then score every label.
    inputs = processor(images=image, text=candidate_labels, return_tensors="pt", padding=True)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # One image in the batch, so row 0 holds one logit per candidate label.
    logits = outputs.logits_per_image[0]
    probs = logits.softmax(dim=-1).cpu().numpy()

    results = [
        {"score": float(score), "label": label}
        for score, label in sorted(zip(probs, candidate_labels), key=lambda x: -x[0])
    ]
    # Rows of [label, score] for the Dataframe output.
    results_for_df = [[res["label"], res["score"]] for res in results]

    # At least two labels are always present (a user label plus "other").
    top, second = results[0], results[1]
    top_label, second_label = top["label"], second["label"]
    combined = top["score"] + second["score"]
    messages.append(f"Top label: {top_label} with score: {top['score']:.2f}")
    messages.append(f"Second label: {second_label} with score: {second['score']:.2f}")

    if top_label == target_label and top["score"] >= 0.58 and second_label != "other":
        messages.append("Triggered the new 0.58 check!")
        result = True
    elif top_label == target_label and second_label in user_labels and combined >= 0.90:
        messages.append("Triggered the 90% combined check!")
        result = True
    elif (
        len(user_labels) > 1
        and top_label == user_labels[1]
        and second_label == target_label
        and combined >= 0.95
    ):
        # Target is runner-up to the second user label.  The length guard is
        # essential: with a single user label the second candidate is "other"
        # and the two scores always sum to 1.0.  Without it, the result would
        # be True exactly when the model chose "other".
        messages.append("Triggered the 95% reverse order check!")
        result = True
    else:
        result = False
    return result, top_label, results_for_df, messages
# Default inputs pre-filled in the UI.
default_labels = "human with beverage,human,beverage"
default_image_path = "F50xXeBbcAA0IIx.jpeg"

# Load the example image fully into memory and close the file right away.
# Image.open is lazy and would otherwise keep the handle open.  If the example
# file is not shipped with the app, start with an empty image input instead of
# crashing the whole app at import time.
try:
    with Image.open(default_image_path) as _img:
        default_image = _img.copy()
except FileNotFoundError:
    default_image = None
_DESCRIPTION = """
**Instructions:**
1. **Upload an Image**: Drag and drop an image or click to upload an image file. A default image is provided.
2. **Enter Candidate Labels**:
   - Provide candidate labels separated by commas.
   - For example: `human with beverage,human,beverage`
   - The label "other" will automatically be added to the list of candidate labels.
   - You can enter just one label, and "other" will still be added automatically. Default labels are provided.
3. **View Results**:
   - The result will indicate whether the specified action (the first label you enter) is present in the image.
   - Detailed scores for each label will be displayed in a table.
   - Additional messages explaining the decision process will also be shown.
"""


def _classify_for_display(image, candidate_labels):
    """Run classify_image and join its message list into text for a Textbox.

    classify_image returns its messages as a list.  Passing that list straight
    to ``gr.Textbox`` would display the Python list repr, so the messages are
    joined one per line here.
    """
    result, top_label, results_for_df, messages = classify_image(image, candidate_labels)
    return result, top_label, results_for_df, "\n".join(messages)


iface = gr.Interface(
    fn=_classify_for_display,
    inputs=[
        gr.Image(type="pil", label="Upload an Image", value=default_image),
        gr.Textbox(label="Candidate Labels (comma separated)", value=default_labels),
    ],
    outputs=[
        gr.Label(label="Result"),
        gr.Textbox(label="Top Label"),
        gr.Dataframe(headers=["Label", "Score"], label="Details"),
        gr.Textbox(label="Messages"),
    ],
    title="General Action Classifier",
    description=_DESCRIPTION,
)
if __name__ == "__main__":
    # Executed directly (not imported): start the Gradio web server for the
    # interface built above.
    iface.launch()