import gradio as gr
from transformers import pipeline

# pipeline, librosa and numpy are only used by the commented-out captioning/TTS
# code below; torch is used to build the prompt tensor for generation.
import librosa
import numpy as np
import torch

# from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from transformers import AutoProcessor, AutoModelForCausalLM


# checkpoint = "microsoft/speecht5_tts"
# tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
# tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
# vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")


# GIT-large checkpoint fine-tuned for visual question answering (VQA)
vqa_processor = AutoProcessor.from_pretrained("ronniet/git-large-vqa-env")
vqa_model = AutoModelForCausalLM.from_pretrained("ronniet/git-large-vqa-env")
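
# Optional sketch (not part of the original app): move the model to GPU when one
# is available; by default inference runs on CPU. If enabled, the tensors built
# in predict() below would also need to be moved to the same device.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# vqa_model = vqa_model.to(device)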

# def tts(text):
#     if len(text.strip()) == 0:
#         return (16000, np.zeros(0).astype(np.int16))

#     inputs = tts_processor(text=text, return_tensors="pt")

#     # limit input length
#     input_ids = inputs["input_ids"]
#     input_ids = input_ids[..., :tts_model.config.max_text_positions]

#     speaker_embedding = np.load("cmu_us_bdl_arctic-wav-arctic_a0009.npy")

#     speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)

#     speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)

#     speech = (speech.numpy() * 32767).astype(np.int16)
#     return (16000, speech)


# captioner = pipeline(model="microsoft/git-base")
# tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)


def predict(image, prompt):
    # Encode the image for the GIT vision encoder
    pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values

    # prompt = "what is in the scene?"
    # Tokenize the question and prepend the CLS token, which GIT uses to start generation
    prompt_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
    prompt_ids = [vqa_processor.tokenizer.cls_token_id] + prompt_ids
    prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)

    # GIT generates the question followed by the answer, so strip the echoed prompt
    # from the decoded output (assumes the decoded prefix has the same length as
    # the original prompt string)
    text_ids = vqa_model.generate(pixel_values=pixel_values, input_ids=prompt_ids, max_length=50)
    text = vqa_processor.batch_decode(text_ids, skip_special_tokens=True)[0][len(prompt):]

    # audio = tts(text)

    return text

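# Quick local sanity check (a sketch; "example.jpg" is a placeholder path, not
# shipped with this Space):
# from PIL import Image
# print(predict(Image.open("example.jpg"), "What is in the scene?"))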

# Gradio UI: an image and a question in, the generated answer out
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="Environment"), gr.Textbox(label="Prompt", value="What is in the scene?")],
    outputs=gr.Textbox(label="Caption"),
    css=".gradio-container {background-color: #002A5B}",
    theme=gr.themes.Soft()  # .set(
    #     button_primary_background_fill="#AAAAAA",
    #     button_primary_border="*button_primary_background_fill_dark"
    # )
)

demo.launch()
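# When running locally rather than on a hosted Space, demo.launch(share=True)
# can be used instead to get a temporary public URL.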