File size: 1,755 Bytes
a3d0c64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
import cv2
import torch
import numpy as np
from transformers import CLIPProcessor, CLIPVisionModel
from PIL import Image
from torch import nn
import requests
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

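# NOTE(review): load_model(), get_attention_map() and apply_heatmap() are called
# below but are not defined in the visible file; their definitions appear to have
# been elided here and must be restored before this module can run.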

def process_image_classification(image):
    """Classify an Uno card image and return an attention-heatmap overlay.

    Args:
        image: The array supplied by the Gradio ``Image(type="numpy")`` input,
            or ``None`` when the form is submitted without an image.

    Returns:
        A ``(visualization_rgb, card_name, confidence)`` tuple:

        - ``visualization_rgb``: the ``apply_heatmap`` overlay after a
          ``cv2.COLOR_BGR2RGB`` channel swap;
        - ``card_name``: ``reverse_mapping`` looked up at the top-scoring
          class index;
        - ``confidence``: that class's softmax probability, as a float.

        When ``image`` is ``None``, returns ``(None, "No image provided", None)``
        instead of raising.
    """
    # Gradio calls the handler with None when no image was supplied; without
    # this guard Image.fromarray(None) raises inside the request.
    if image is None:
        return None, "No image provided", None

    # NOTE(review): load_model() runs on every request; if it loads weights
    # each time, consider caching the result at module level.
    model, processor, reverse_mapping, device = load_model()

    pil_image = Image.fromarray(image)

    inputs = processor(images=pil_image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    with torch.no_grad():
        # The model is called with output_attentions=True and unpacked as
        # (logits, attentions).
        logits, attentions = model(pixel_values, output_attentions=True)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Argmax over the class dimension of the single image in the batch;
        # a dim-less argmax would index the flattened tensor instead.
        prediction = torch.argmax(probs[0], dim=-1).item()
        confidence = probs[0, prediction].item()

    attention_map = get_attention_map(attentions)
    visualization = apply_heatmap(pil_image, attention_map)

    card_name = reverse_mapping[prediction]

    # Presumably apply_heatmap returns an OpenCV-style BGR array (TODO confirm);
    # swap to RGB so the Gradio Image output shows the correct colours.
    visualization_rgb = cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB)

    return visualization_rgb, card_name, confidence

def gradio_interface():
    """Build and launch the Gradio web UI for the Uno card recognizer.

    The interface feeds one numpy image into ``process_image_classification``.
    It shows three outputs: the heatmap overlay, the predicted card name and
    the confidence value.

    Uses the top-level ``gr.Image`` / ``gr.Textbox`` components. The older
    ``gr.inputs`` / ``gr.outputs`` namespaces were deprecated in Gradio 3
    and removed in Gradio 4.
    """
    gr_interface = gr.Interface(
        fn=process_image_classification,
        # NOTE(review): whether a webcam source is offered depends on the
        # installed Gradio version's default image sources; confirm it
        # matches the description below.
        inputs=gr.Image(type="numpy"),
        outputs=[
            gr.Image(label="Heatmap Plot"),
            gr.Textbox(label="Predicted Card"),
            gr.Textbox(label="Confidence"),
        ],
        title="Uno Card Recognizer",
        description="Upload an image or use your webcam to recognize an Uno card.",
    )
    # Start the local web server that serves the interface.
    gr_interface.launch()

if __name__ == "__main__":
    # Launch the UI only when this file is executed directly, not when imported.
    gradio_interface()