st0bb3n commited on
Commit
198fce8
1 Parent(s): d6912cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -1,12 +1,18 @@
1
  from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  import gradio as gr
 
 
 
 
 
3
 
4
  def classify(image):
5
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
6
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
7
  inputs = feature_extractor(images=image, return_tensors="pt")
8
- outputs = model(**inputs)
9
- logits = outputs.logits
 
10
  # model predicts one of the 1000 ImageNet classes
11
  predicted_class_idx = logits.argmax(-1).item()
12
  return model.config.id2label[predicted_class_idx]
@@ -17,6 +23,12 @@ def image2speech(image):
17
 
18
  fastspeech = gr.Interface.load("huggingface/facebook/fastspeech2-en-ljspeech")
19
 
20
- app = gr.Interface(fn=image2speech, inputs="image", outputs=["audio", "text"])
 
 
 
 
 
 
21
 
22
  app.launch()
 
from transformers import ViTFeatureExtractor, ViTForImageClassification
import gradio as gr
from datasets import load_dataset
import torch

# Download CIFAR-100 at import time so it is available before the UI starts.
# NOTE(review): this fetches the whole dataset on every process start —
# consider loading lazily or restricting to the split actually needed.
dataset = load_dataset("cifar100")

# NOTE(review): "fine_label" is a column of integer class indices, not
# images, so the name `image` is misleading — the intended column is
# probably "img". Kept as-is because code elsewhere in this file may
# reference this module-level name; confirm before changing.
image = dataset["train"]["fine_label"]
9
  def classify(image):
10
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
11
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
12
  inputs = feature_extractor(images=image, return_tensors="pt")
13
+ with torch.no_grad():
14
+ outputs = model(**inputs)
15
+ logits = outputs.logits
16
  # model predicts one of the 1000 ImageNet classes
17
  predicted_class_idx = logits.argmax(-1).item()
18
  return model.config.id2label[predicted_class_idx]
 
# Text-to-speech model served via the Hugging Face inference API.
fastspeech = gr.Interface.load("huggingface/facebook/fastspeech2-en-ljspeech")

# Gradio UI: takes an image and returns a spoken and written description.
# Fix: corrected typos in the user-facing description string
# ("Classifies and image and tell you what is it").
app = gr.Interface(
    fn=image2speech,
    inputs="image",
    title="Image to speech",
    description="Classifies an image and tells you what it is",
    examples=["remotecontrol.jpg", "calculator.jpg", "cellphone.jpg"],
    allow_flagging="never",
    outputs=["audio", "text"],
)

app.launch()