# Hugging Face Spaces page header captured with this file: Space status "Runtime error".
import torch | |
from transformers import LlavaForConditionalGeneration, BitsAndBytesConfig, AutoProcessor | |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast | |
import requests | |
from PIL import Image | |
import requests | |
import gradio as gr | |
# Load the mBART-50 many-to-many translation model and its tokenizer; they are
# used to translate the user's Hindi question to English and the answer back.
translate_model_name = "facebook/mbart-large-50-many-to-many-mmt"
translate_model = MBartForConditionalGeneration.from_pretrained(translate_model_name)
tokenizer = MBart50TokenizerFast.from_pretrained(translate_model_name)

# Quantize the LLaVA weights to 4 bit when loading. The compute dtype is set to
# float16 to match torch_dtype below; BitsAndBytesConfig otherwise defaults to
# float32 compute, which is slow for float16 activations.
# NOTE(review): bitsandbytes 4-bit loading presumably needs a CUDA GPU — confirm
# the Space hardware provides one.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Fine-tuned LLaVA-1.5-7B checkpoint on the Hugging Face Hub. Judging by the repo
# name, the fine-tuned weights were already merged into the base model upstream;
# no merging happens here — it is loaded like any other checkpoint.
model_id = 'somnathsingh31/llava-1.5-7b-hf-ft-merged_model'
merged_model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)

# The processor (text tokenizer + image preprocessing) is taken from the base
# model the fine-tune was derived from.
base_model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(base_model_id)
def translate(text, source_lang, target_lang):
    """Translate *text* from one mBART-50 language code to another.

    Args:
        text: Text to translate. ``None`` or a blank string short-circuits to
            ``""`` without running the model (generating from an input that is
            only special tokens yields meaningless output).
        source_lang: mBART-50 code of the input language, e.g. ``"hi_IN"``.
        target_lang: mBART-50 code of the output language, e.g. ``"en_XX"``.

    Returns:
        The first decoded translation.

    Raises:
        KeyError: if *target_lang* is not a key of ``tokenizer.lang_code_to_id``.
    """
    if text is None or (isinstance(text, str) and not text.strip()):
        return ""
    # The module-level tokenizer derives its source-language prefix token from
    # src_lang, so it must be set before encoding. This mutates shared state.
    tokenizer.src_lang = source_lang
    encoded_text = tokenizer(text, return_tensors="pt")
    # Force the decoder's first token to be the target-language code so the
    # many-to-many model emits that language.
    forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]
    # NOTE(review): no length limit is passed, so output length comes from the
    # model's generation config; long inputs may be cut short — confirm.
    generated_tokens = translate_model.generate(**encoded_text, forced_bos_token_id=forced_bos_token_id)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
def _load_image(image):
    """Return a PIL image from a PIL image, file-like object, http(s) URL, or path."""
    from io import BytesIO

    if image is None:
        raise ValueError("An input image is required.")
    # Gradio's Image(type="pil") component already delivers a PIL image.
    if isinstance(image, Image.Image):
        return image
    # Uploaded bytes / open file objects.
    if hasattr(image, 'read'):
        return Image.open(image)
    if isinstance(image, str) and image.startswith(("http://", "https://")):
        # Bounded wait and explicit HTTP error instead of hanging or feeding an
        # error page to PIL.
        with requests.get(image, timeout=30) as response:
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    # Anything else (e.g. a local file path) is handed to PIL directly.
    return Image.open(image)


def ask_vlm(hindi_input_text, image):
    """Answer a Hindi question about an image, replying in Hindi.

    The question is translated to English, sent with the image to the LLaVA
    model in its "USER: <image>\\n... ASSISTANT:" prompt format, and the model's
    English answer is translated back to Hindi.

    Args:
        hindi_input_text: The question, in Hindi.
        image: A PIL image (what the Gradio ``Image(type="pil")`` input passes),
            a file-like object, an http(s) URL, or a local file path.

    Returns:
        The answer in Hindi, or ``""`` if the model generated no answer text.

    Raises:
        ValueError: if *image* is ``None``.
        requests.HTTPError: if downloading an image URL fails.
    """
    # translate from Hindi to English
    prompt_eng = translate(hindi_input_text, "hi_IN", "en_XX")
    prompt = "USER: <image>\n" + prompt_eng + " ASSISTANT:"
    pil_image = _load_image(image)
    # Put tensors on the model's device (the quantized model is not on CPU) and
    # cast floating-point inputs (pixel_values) to the model's float16 dtype.
    inputs = processor(text=prompt, images=pil_image, return_tensors="pt").to(merged_model.device, torch.float16)
    generate_ids = merged_model.generate(**inputs, max_new_tokens=250)
    # generate() returns the prompt followed by the continuation; decode only the
    # newly generated tokens instead of searching the text for "ASSISTANT:".
    new_token_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    answer = processor.batch_decode(new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0].strip()
    if not answer:
        return ""
    return translate(answer, "en_XX", "hi_IN")
# Define Gradio interface. The legacy `gr.inputs` / `gr.outputs` namespaces were
# removed in Gradio 4; the top-level components work in Gradio 3 and 4.
# NOTE(review): the label mentions URLs, but this component hands ask_vlm an
# image object, not a typed URL — confirm the label matches the intended UX.
input_image = gr.Image(type="pil", label="Input Image (Upload or URL)")
input_question = gr.Textbox(lines=2, label="Question (Hindi)")
output_text = gr.Textbox(label="Response (Hindi)")

# Create and launch the Gradio app at import time. In a script, launch()
# presumably blocks while the server runs, so code after it waits until shutdown.
demo = gr.Interface(
    fn=ask_vlm,
    inputs=[input_question, input_image],
    outputs=output_text,
    title="Image and Text-based Dialogue System",
    description="Enter a question in Hindi and an image, either by uploading or providing URL, and get a response in Hindi.",
)
demo.launch()
if __name__ == '__main__':
    # Manual smoke test: ask a fixed Hindi question about a sample image URL and
    # print the Hindi answer. (The question reads: "What kind of art is this?
    # Explain in detail.")
    # NOTE(review): the Gradio app is launched earlier at import time; if that
    # call blocks, this check only runs after the server stops — confirm intent.
    sample_image_url = 'https://images.metmuseum.org/CRDImages/ad/original/138425.jpg'
    sample_question = 'यह किस प्रकार की कला है? विस्तार से बताइये'
    print('Output:\n', ask_vlm(sample_question, sample_image_url))