ronniet committed on
Commit
ce16067
1 Parent(s): a2ceb9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -11
app.py CHANGED
@@ -1,24 +1,28 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
- # from TTS.api import TTS
4
 
5
  import librosa
6
  import numpy as np
7
  import torch
8
 
9
  from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
 
10
 
11
 
12
  checkpoint = "microsoft/speecht5_tts"
13
- processor = SpeechT5Processor.from_pretrained(checkpoint)
14
- model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
15
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
16
 
 
 
 
 
17
  def tts(text):
18
  if len(text.strip()) == 0:
19
  return (16000, np.zeros(0).astype(np.int16))
20
 
21
- inputs = processor(text=text, return_tensors="pt")
22
 
23
  # limit input length
24
  input_ids = inputs["input_ids"]
@@ -44,26 +48,36 @@ def tts(text):
44
 
45
  speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
46
 
47
- speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
48
 
49
  speech = (speech.numpy() * 32767).astype(np.int16)
50
  return (16000, speech)
51
 
52
 
53
- captioner = pipeline(model="microsoft/git-base")
54
  # tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
55
 
56
 
57
  def predict(image):
58
- text = captioner(image)[0]["generated_text"]
59
 
60
  # audio_output = "output.wav"
61
  # tts.tts_to_file(text, speaker=tts.speakers[0], language="en", file_path=audio_output)
 
 
 
 
 
 
 
 
 
 
 
62
  audio = tts(text)
63
 
64
  return text, audio
65
 
66
- # theme = gr.themes.Default(primary_hue="#002A5B")
67
 
68
  demo = gr.Interface(
69
  fn=predict,
@@ -74,6 +88,3 @@ demo = gr.Interface(
74
  )
75
 
76
  demo.launch()
77
-
78
- # gr.Interface.load("models/ronniet/git-base-env").launch()
79
- # gr.Interface.load("models/microsoft/git-base").launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
 
3
 
4
  import librosa
5
  import numpy as np
6
  import torch
7
 
8
  from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
9
+ from transformers import AutoProcessor, AutoModelForCausalLM
10
 
11
 
12
  checkpoint = "microsoft/speecht5_tts"
13
+ tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
14
+ tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
15
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
16
 
17
+
18
+ vqa_processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
19
+ vqa_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
20
+
21
  def tts(text):
22
  if len(text.strip()) == 0:
23
  return (16000, np.zeros(0).astype(np.int16))
24
 
25
+ inputs = tts_processor(text=text, return_tensors="pt")
26
 
27
  # limit input length
28
  input_ids = inputs["input_ids"]
 
48
 
49
  speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
50
 
51
+ speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
52
 
53
  speech = (speech.numpy() * 32767).astype(np.int16)
54
  return (16000, speech)
55
 
56
 
57
+ # captioner = pipeline(model="microsoft/git-base")
58
  # tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
59
 
60
 
61
def predict(image):
    """Answer a fixed VQA question about *image* and speak the answer.

    Args:
        image: input image in any format the GIT processor accepts.

    Returns:
        tuple: ``(text, audio)`` where ``text`` is the generated answer
        string and ``audio`` is the ``(sample_rate, int16 samples)`` pair
        produced by ``tts()``.
    """
    # Encode the image for the GIT VQA model.
    pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values

    # Build the question prompt: GIT expects the CLS token to be prepended
    # manually when tokenizing with add_special_tokens=False.
    prompt = "what is in the scene?"
    prompt_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
    prompt_ids = [vqa_processor.tokenizer.cls_token_id] + prompt_ids
    prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)

    text_ids = vqa_model.generate(pixel_values=pixel_values, input_ids=prompt_ids, max_length=50)

    # BUG FIX: batch_decode returns a list of strings (one per sequence).
    # The original passed that list straight to tts(), which calls
    # text.strip() and would raise AttributeError; take element [0] so a
    # plain str flows to tts() and to the Gradio text output.
    text = vqa_processor.batch_decode(text_ids, skip_special_tokens=True)[0]

    audio = tts(text)

    return text, audio
80
 
 
81
 
82
  demo = gr.Interface(
83
  fn=predict,
 
88
  )
89
 
90
  demo.launch()