Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| import torch | |
| from PIL import Image | |
# Hugging Face Hub id of the fine-tuned ViT-384 tongue-image classifier.
model_name = 'e1010101/vit-384-tongue-image'

# The image processor is loaded from the base google/vit-base-patch16-384
# checkpoint, not from model_name.
# NOTE(review): presumably the fine-tuned model was trained with the base
# model's preprocessing — confirm, or load from model_name if it ships a
# preprocessor config of its own.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-384")

# Classifier with a 3-way multi-label head (classify_image applies a sigmoid
# per class rather than a softmax across classes).
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    problem_type="multi_label_classification",
    num_labels=3,
    id2label=dict(enumerate(['Crack', 'Red-Dots', 'Toothmark'])),
    label2id={name: idx for idx, name in enumerate(['Crack', 'Red-Dots', 'Toothmark'])},
    # NOTE(review): with ignore_mismatched_sizes=True, a checkpoint whose
    # classifier does not have 3 outputs still loads, but the head is freshly
    # (randomly) initialized and predictions become meaningless. If the
    # checkpoint already has a 3-label head this flag is unnecessary; dropping
    # it would make a mismatch fail loudly instead.
    ignore_mismatched_sizes=True,
)
def classify_image(image, threshold=0.5):
    """Score an image against every label and keep those above a threshold.

    The model is set up for multi-label classification, so each label gets an
    independent sigmoid probability instead of a softmax share.

    Args:
        image: The uploaded picture as a ``PIL.Image.Image`` (the interface uses
            ``gr.Image(type="pil")``), or ``None`` when nothing was uploaded.
        threshold: Minimum probability (inclusive) a label needs in order to be
            reported. Driven by the "Probability Threshold" slider.

    Returns:
        dict[str, float]: Maps label name to probability for every label whose
        probability is >= ``threshold``, ordered from most to least likely.
        Empty when ``image`` is ``None`` or when no label reaches the threshold.
    """
    if image is None:
        # Submitting the form without an image yields None; there is nothing
        # to classify.
        return {}
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # ViT normalization expects 3 channels. RGBA PNGs, grayscale and
        # palette images must be converted first or preprocessing fails.
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Multi-label head: an independent sigmoid per class. Take the single
    # batch item and convert it to plain Python floats.
    probs = torch.sigmoid(logits)[0].tolist()

    # Look names up by class index rather than trusting id2label's insertion
    # order to line up with the logit columns.
    id2label = model.config.id2label
    result = {
        id2label[idx]: prob
        for idx, prob in enumerate(probs)
        if prob >= threshold
    }

    # Most likely labels first.
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
# Gradio front end. An image upload and a threshold slider are passed
# positionally to classify_image (image, threshold), and the returned
# {label: probability} dict is rendered by a Label component.
interface = gr.Interface(
    title="Multi-Label Image Classification",
    description="Upload an image to get classification results.",
    fn=classify_image,
    inputs=[
        # Delivers the upload to classify_image as a PIL image.
        gr.Image(type="pil"),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.5,
            label="Probability Threshold",
        ),
    ],
    # num_top_classes=None: do not cap how many returned classes are shown.
    outputs=gr.Label(num_top_classes=None),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()