# DL4NLP / app.py
# Last commit by santanus24: "Update app.py" (9966aac, verified)
import torch
from transformers import LlavaForConditionalGeneration, BitsAndBytesConfig, AutoProcessor
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import requests
from PIL import Image
import requests
import gradio as gr
# --- Translation model (Hindi <-> English) ---------------------------------
# mBART-50 many-to-many checkpoint; the language pair is chosen per call by
# translate() through tokenizer.src_lang and a forced first (BOS) token.
translate_model_name = "facebook/mbart-large-50-many-to-many-mmt"
translate_model = MBartForConditionalGeneration.from_pretrained(translate_model_name)
tokenizer = MBart50TokenizerFast.from_pretrained(translate_model_name)

# --- Vision-language model (LLaVA-1.5-7B, 4-bit) ---------------------------
# Quantize the weights to 4 bit with bitsandbytes. The compute dtype is set to
# float16 to match torch_dtype below; with the bitsandbytes default (float32)
# every 4-bit matmul runs in float32 and bitsandbytes warns about slow inference.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Fine-tuned checkpoint on the Hugging Face Hub. Per its repo name the
# fine-tuned weights are already merged into the base model, so nothing is
# merged here -- from_pretrained only loads (and quantizes) them.
model_id = 'somnathsingh31/llava-1.5-7b-hf-ft-merged_model'
merged_model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)

# Prompt tokenizer + image preprocessor are taken from the base LLaVA-1.5 repo.
# NOTE(review): presumably the fine-tuned repo ships no processor files -- confirm.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
def translate(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang`` with mBART-50.

    Args:
        text: Text to translate. ``None`` or a blank string returns ``""``
            without running the model (generating from empty input only
            produces noise).
        source_lang: mBART-50 language code of ``text`` (e.g. ``"hi_IN"``).
        target_lang: mBART-50 language code to produce (e.g. ``"en_XX"``).

    Returns:
        The first decoded translation, with special tokens removed.

    Raises:
        KeyError: if ``target_lang`` is not a token in the tokenizer vocabulary.
    """
    if text is None or (isinstance(text, str) and not text.strip()):
        return ""

    # The tokenizer uses src_lang when encoding to tag the input's language.
    # NOTE: this mutates the shared module-level tokenizer.
    tokenizer.src_lang = source_lang
    encoded_text = tokenizer(text, return_tensors="pt")

    # mBART-50 picks the output language from the first generated token.
    # convert_tokens_to_ids avoids depending on the lang_code_to_id mapping,
    # which is not available in every tokenizer implementation/version. Unknown
    # codes map to the unk id, so reject that to keep the KeyError contract.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)
    if forced_bos_token_id is None or forced_bos_token_id == tokenizer.unk_token_id:
        raise KeyError(target_lang)

    generated_tokens = translate_model.generate(
        **encoded_text, forced_bos_token_id=forced_bos_token_id
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
def _load_image(image):
    """Return a PIL image for ``image``.

    Accepts an already-decoded ``PIL.Image.Image`` (what a Gradio
    ``Image(type="pil")`` component passes), a binary file-like object, an
    http(s) URL string, or a local file path.

    Raises:
        ValueError: if ``image`` is None (e.g. nothing was uploaded).
        requests.HTTPError: if downloading a URL returns an error status.
    """
    if image is None:
        raise ValueError("No image provided.")
    if isinstance(image, Image.Image):
        return image
    if hasattr(image, "read"):
        return Image.open(image)
    if isinstance(image, str) and image.startswith(("http://", "https://")):
        # Timeout so a dead host cannot hang the request handler forever.
        response = requests.get(image, stream=True, timeout=30)
        response.raise_for_status()
        return Image.open(response.raw)
    # Anything else is treated as a local path.
    return Image.open(image)


def ask_vlm(hindi_input_text, image, max_new_tokens=250):
    """Answer a Hindi question about an image, replying in Hindi.

    The question is translated Hindi -> English, answered by the fine-tuned
    LLaVA model, and the English answer is translated back to Hindi.

    Args:
        hindi_input_text: The question, in Hindi.
        image: A PIL image, a binary file-like object, an http(s) URL, or a
            local file path.
        max_new_tokens: Upper bound on the length of the English answer.

    Returns:
        The answer in Hindi, or ``""`` if the model generated no text.

    Raises:
        ValueError: if ``image`` is None.
        requests.HTTPError: if downloading an image URL fails.
    """
    # Resolve the image first so a missing/bad image fails before any model runs.
    pil_image = _load_image(image)

    prompt_eng = translate(hindi_input_text, "hi_IN", "en_XX")
    # LLaVA-1.5 chat format; <image> marks where the image features go.
    prompt = "USER: <image>\n" + prompt_eng + " ASSISTANT:"

    # The processor returns CPU tensors, but the quantized model lives on its
    # own device (typically the GPU). BatchFeature.to() applies the float16 cast
    # only to floating-point tensors (pixel_values); input_ids stay integer.
    inputs = processor(text=prompt, images=pil_image, return_tensors="pt").to(
        merged_model.device, torch.float16
    )
    generate_ids = merged_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; decode only the new tokens so the
    # answer never depends on locating "ASSISTANT:" in the decoded text.
    new_token_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    answer = processor.batch_decode(
        new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0].strip()
    if not answer:
        return ""

    return translate(answer, "en_XX", "hi_IN")
# --- Gradio UI ---------------------------------------------------------------
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# the top-level components below work in both.
input_image = gr.Image(type="pil", label="Input Image (Upload or URL)")
input_question = gr.Textbox(lines=2, label="Question (Hindi)")
output_text = gr.Textbox(label="Response (Hindi)")

# Built at import time (so tools like `gradio app.py` can find `demo`), but only
# launched when the file is run as a script.
demo = gr.Interface(
    fn=ask_vlm,
    inputs=[input_question, input_image],
    outputs=output_text,
    title="Image and Text-based Dialogue System",
    description="Enter a question in Hindi and an image, either by uploading or providing URL, and get a response in Hindi.",
)

if __name__ == '__main__':
    # In a script, launch() blocks until the server is stopped (e.g. Ctrl+C);
    # only then does the command-line example below run.
    demo.launch()

    image_url = 'https://images.metmuseum.org/CRDImages/ad/original/138425.jpg'
    user_query = 'यह किस प्रकार की कला है? विस्तार से बताइये'
    output = ask_vlm(user_query, image_url)
    print('Output:\n', output)