st0bb3n committed
Commit
1559fe8
1 Parent(s): 8884faf

Update app.py


added debug points

Files changed (1): app.py +12 -3
app.py CHANGED
@@ -3,24 +3,32 @@ import gradio as gr
 from datasets import load_dataset
 import torch
 
+
 dataset = load_dataset("cifar100")
 image = dataset["train"]["fine_label"]
+print("load and train dataset \n")
 
 def classify(image):
     feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
+    print("feature extractor \n")
     model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
+    print("load model \n")
     inputs = feature_extractor(images=image, return_tensors="pt")
+    print("define input \n")
     with torch.no_grad():
         outputs = model(**inputs)
         logits = outputs.logits
     # model predicts one of the 1000 ImageNet classes
+    print("prediction \n")
     predicted_class_idx = logits.argmax(-1).item()
     return model.config.id2label[predicted_class_idx]
 
 def image2speech(image):
+    print("tts \n")
     txt = classify(image)
     return fastspeech(txt), txt
-
+
+print("load tts interface \n")
 fastspeech = gr.Interface.load("huggingface/facebook/fastspeech2-en-ljspeech")
 
 '''
@@ -35,15 +43,16 @@ app = gr.Interface(fn=image2speech,
 
 app.launch(cache_examples=True)
 '''
-
+print("sets input and outputs \n")
 camera = gr.inputs.Image(label="Image from your camera", source="webcam")
 read = gr.outputs.Textbox(type="auto", label="Text")
 speak = gr.outputs.Audio(type="auto", label="Speech")
 
+print("define interface \n")
 app = gr.Interface(fn=image2speech,
                    inputs=camera,
                    live=True,
                    description="Takes a snapshot of an object, identifies it, and then tell you what it is. \n Intended use is to help the visually impaired. Models and dataset used is listed on the linked models and dataset",
                    outputs=[speak, read])
-
+print("launch interface \n")
 app.launch()
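
Note for anyone reading this commit: even with the debug points, app.py as committed will still fail at the first `from_pretrained` call, because `ViTFeatureExtractor` and `ViTForImageClassification` are never imported. Also, `dataset["train"]["fine_label"]` selects the CIFAR-100 label column, not images; the webcam frame is what actually reaches `classify`, so those two dataset lines are dead code. Below is a minimal sketch of the classify path with the missing imports added and the models hoisted to module level so `live=True` does not reload ViT on every frame — the hoisting is an assumed restructuring, not part of this commit:

```python
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Load once at import time; loading inside classify() would reload the
# model for every webcam frame when the interface runs with live=True.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

def classify(image):
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # model predicts one of the 1000 ImageNet classes
    return model.config.id2label[logits.argmax(-1).item()]
```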
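A smaller nit on the debug points themselves: each message already ends in " \n" and `print` appends its own newline, so every trace prints with a trailing blank line. If the prints stay, dropping the embedded "\n" is enough; a sketch of the same trace points routed through the standard `logging` module, which adds timestamps for free:

```python
import logging

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(message)s")
log = logging.getLogger(__name__)

log.debug("load model")    # replaces print("load model \n")
log.debug("define input")  # replaces print("define input \n")
```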