import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
class _MLPVectorProjector(nn.Module):
    """Projects encoder embeddings into the LM hidden space via `width`
    parallel MLPs whose outputs are concatenated along the sequence axis."""

    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            # First layer maps the encoder dim to the LM dim; each extra layer
            # adds a GELU followed by an LM-dim-to-LM-dim linear.
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        # Concatenate the per-MLP projections along the sequence dimension.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)
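# Hedged usage sketch (the projector is defined but not wired up below): it is
# presumably meant to map vision-encoder features (e.g. 768-d CLIP embeddings)
# into phi-2's 2560-d hidden space; the shapes here are illustrative assumptions.
# _projector = _MLPVectorProjector(input_hidden_size=768, lm_hidden_size=2560,
#                                  num_layers=2, width=4)
# image_embeds = torch.randn(1, 49, 768)  # (batch, patches, encoder_dim)
# lm_embeds = _projector(image_embeds)    # -> (1, 4 * 49, 2560)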
## Text model
model_name = "microsoft/phi-2"
phi2_text = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer_text = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
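# Note (an assumption, not part of the original app): phi-2 has ~2.7B parameters,
# so on a GPU host it would typically be loaded in half precision, e.g.:
# phi2_text = AutoModelForCausalLM.from_pretrained(
#     model_name, torch_dtype=torch.float16, trust_remote_code=True
# ).to("cuda")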
## Audio model
processor_audio = WhisperProcessor.from_pretrained("openai/whisper-small")
model_audio = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# Clear forced decoder ids so Whisper predicts language/task tokens itself.
model_audio.config.forced_decoder_ids = None
## Image model (pipeline not integrated yet; imageMode below is a placeholder)
def example_inference(input_text, count):  # image/audio example inputs not wired up yet
    pred_text = textMode(input_text, count)
    return pred_text
def textMode(text, count):
    count = int(count)
    inputs = tokenizer_text(text, return_tensors="pt", return_attention_mask=False)
    prediction = tokenizer_text.batch_decode(
        phi2_text.generate(
            **inputs,
            max_new_tokens=count,
            bos_token_id=tokenizer_text.bos_token_id,
            eos_token_id=tokenizer_text.eos_token_id,
            pad_token_id=tokenizer_text.pad_token_id,
        )
    )
    # str.rstrip('<|endoftext|>') would strip any of those characters from the
    # end, not the literal token, so remove the suffix explicitly instead.
    return prediction[0].removesuffix('<|endoftext|>').rstrip("\n")
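# Example (illustrative): textMode("What is a large language model?", "50")
# returns up to 50 newly generated tokens of phi-2 continuation as a plain string.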
def imageMode(image, question):
    # Placeholder until the image pipeline (projector above) is integrated.
    return "In progress"
def audioMode(audio):
    # gr.Audio (numpy mode) yields a (sampling_rate, waveform) tuple.
    sampling_rate, audio_array = audio
    # Whisper's feature extractor expects mono float32 audio; gradio records int16 PCM.
    if audio_array.dtype != np.float32:
        audio_array = audio_array.astype(np.float32) / 32768.0
    # Whisper requires 16 kHz input; this assumes the recording is already 16 kHz.
    input_features = processor_audio(
        audio_array, sampling_rate=16000, return_tensors="pt"
    ).input_features
    predicted_ids = model_audio.generate(input_features)
    transcription = processor_audio.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcription[0]
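# Hedged sketch: browsers often record at 44.1/48 kHz, so if sampling_rate is not
# 16000 the waveform should be resampled before the processor call above, e.g.
# with torchaudio (an assumption; torchaudio is not a dependency of this app):
# import torchaudio.functional as AF
# waveform = torch.from_numpy(audio_array)
# audio_array = AF.resample(waveform, orig_freq=sampling_rate, new_freq=16000).numpy()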
interface_title = "TSAI-ERA-V1 - Capstone - Multimodal GPT Demo"
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(f"## **{interface_title}**")
    gr.Markdown("Choose text mode/image mode/audio mode for generation")

    with gr.Tab("Text mode"):
        text_input = gr.Textbox(placeholder="Enter a prompt", label="Input")
        text_input_count = gr.Textbox(
            placeholder="Enter the number of tokens you want to generate", label="Count"
        )
        text_button = gr.Button("Submit")
        text_output = gr.Textbox(label="ChatGPT-like text")

    with gr.Tab("Image mode"):
        with gr.Row():
            image_input = gr.Image()
            image_text_input = gr.Textbox(
                placeholder="Enter a question/prompt about the image",
                label="Question/Prompt",
            )
        image_button = gr.Button("Submit")
        image_text_output = gr.Textbox(label="Answer")

    with gr.Tab("Audio mode"):
        audio_input = gr.Audio()
        audio_button = gr.Button("Submit")
        audio_text_output = gr.Textbox(label="ChatGPT-like text")

    text_button.click(textMode, inputs=[text_input, text_input_count], outputs=text_output)
    image_button.click(imageMode, inputs=[image_input, image_text_input], outputs=image_text_output)
    audio_button.click(audioMode, inputs=audio_input, outputs=audio_text_output)
    gr.Examples(
        examples=[
            ["What is a large language model?", "50"],
        ],
        inputs=[text_input, text_input_count],
        outputs=[text_output],
        fn=example_inference,
    )
demo.launch()