EVad committed on
Commit 3c3c4fa · 1 Parent(s): a21e4ca

Upload app.py

Files changed (1)
  1. app.py +17 -19
app.py CHANGED
@@ -2,7 +2,7 @@ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
 import torch
 from PIL import Image
 
-import gradio as gr
+import gradio as gr
 
 from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
 from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
@@ -12,13 +12,16 @@ feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
+model = model.to(device)
 
 models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
     "facebook/fastspeech2-en-ljspeech",
-    arg_overrides={"vocoder": "hifigan", "fp16": False}
+    arg_overrides={"vocoder": "hifigan", "fp16": True}
 )
+
 model1 = models[0]
+model1 = model1.to(device)
+
 TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
 generator = task.build_generator(models, cfg)
 
@@ -27,32 +30,27 @@ num_beams = 4
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
 
 
-def predict_step(image_paths):
-    images = []
-    text = ""
+def inference(image_paths):
 
-    for image_path in image_paths:
-        i_image = Image.fromarray(image_path)
-        if i_image.mode != "RGB":
+    #for image_path in image_paths:
+    i_image = Image.fromarray(image_paths)
+    if i_image.mode != "RGB":
         i_image = i_image.convert(mode="RGB")
-        print(image_path)
 
-        images.append(i_image)
-    print(images)
-    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
+    pixel_values = feature_extractor(images=i_image, return_tensors="pt").pixel_values
     pixel_values = pixel_values.to(device)
 
     output_ids = model.generate(pixel_values, **gen_kwargs)
 
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
-    preds = ' '.join(str(e) for e in preds)
-    text = text + preds
-    sample = TTSHubInterface.get_model_input(task, text)
+
+
+    sample = TTSHubInterface.get_model_input(task, preds)
+
     wav, rate = TTSHubInterface.get_prediction(task, model1, generator, sample)
-    return wav  #, rate, text
-    #return ipd.Audio(wav, rate=rate)
+    return wav
 
 
-interface = gr.Interface(predict_step, gr.Image(), "audio")
+interface = gr.Interface(inference, gr.Image(), "audio")
 interface.launch()
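
For reference, below is a sketch of how app.py reads once this commit is applied. The diff does not show the caption model being loaded or the value of max_length, so those lines are assumptions (the "nlpconnect/vit-gpt2-image-captioning" checkpoint is implied by the feature extractor and tokenizer loaded around it); comments flag two spots that may still warrant a follow-up.

# Sketch of app.py after this commit. Lines marked "assumed" are not visible in the diff.
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image

import gradio as gr

from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface

# Image-captioning model (assumed: the diff only shows the feature extractor and tokenizer)
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# FastSpeech 2 text-to-speech model with a HiFi-GAN vocoder (fp16 enabled by this commit)
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
    "facebook/fastspeech2-en-ljspeech",
    arg_overrides={"vocoder": "hifigan", "fp16": True}
)

model1 = models[0]
model1 = model1.to(device)

TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(models, cfg)

max_length = 16  # assumed value; only num_beams = 4 is visible in the hunk header
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def inference(image_paths):
    # gr.Image() hands the upload over as a numpy array, so only a single image is handled now
    i_image = Image.fromarray(image_paths)
    if i_image.mode != "RGB":
        i_image = i_image.convert(mode="RGB")

    # Caption the image with the ViT-GPT2 model
    pixel_values = feature_extractor(images=i_image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]

    # Synthesize speech from the caption. Note: the previous revision joined preds into a
    # single string first; TTSHubInterface.get_model_input takes plain text, so passing the
    # list directly may need a follow-up fix.
    sample = TTSHubInterface.get_model_input(task, preds)
    wav, rate = TTSHubInterface.get_prediction(task, model1, generator, sample)
    # Note: Gradio's "audio" output usually expects (sample_rate, numpy array) or a file
    # path, whereas this revision returns the raw waveform tensor only.
    return wav


interface = gr.Interface(inference, gr.Image(), "audio")
interface.launch()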