Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,18 @@
|
|
1 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
2 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
def classify(image):
|
5 |
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
6 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
7 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
8 |
-
|
9 |
-
|
|
|
10 |
# model predicts one of the 1000 ImageNet classes
|
11 |
predicted_class_idx = logits.argmax(-1).item()
|
12 |
return model.config.id2label[predicted_class_idx]
|
@@ -17,6 +23,12 @@ def image2speech(image):
|
|
17 |
|
18 |
fastspeech = gr.Interface.load("huggingface/facebook/fastspeech2-en-ljspeech")
|
19 |
|
20 |
-
app = gr.Interface(fn=image2speech,
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
app.launch()
|
|
|
1 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
2 |
import gradio as gr
|
3 |
+
from datasets import load_dataset
|
4 |
+
import torch
|
5 |
+
|
6 |
+
dataset = load_dataset("cifar100")
|
7 |
+
image = dataset["train"]["fine_label"]
|
8 |
|
9 |
def classify(image):
|
10 |
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
11 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
12 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
13 |
+
with torch.no_grad():
|
14 |
+
outputs = model(**inputs)
|
15 |
+
logits = outputs.logits
|
16 |
# model predicts one of the 1000 ImageNet classes
|
17 |
predicted_class_idx = logits.argmax(-1).item()
|
18 |
return model.config.id2label[predicted_class_idx]
|
|
|
23 |
|
24 |
fastspeech = gr.Interface.load("huggingface/facebook/fastspeech2-en-ljspeech")
|
25 |
|
26 |
+
app = gr.Interface(fn=image2speech,
|
27 |
+
inputs="image",
|
28 |
+
title="Image to speech",
|
29 |
+
description="Classifies and image and tell you what is it",
|
30 |
+
examples=["remotecontrol.jpg", "calculator.jpg", "cellphone.jpg"],
|
31 |
+
allow_flagging="never",
|
32 |
+
outputs=["audio", "text"])
|
33 |
|
34 |
app.launch()
|