"""Gradio demo: ask questions in Hindi about an image with a fine-tuned LLaVA-1.5.

Pipeline: the Hindi question is translated to English with mBART-50, answered
by the vision-language model, and the English answer is translated back to
Hindi.
"""
import io

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    LlavaForConditionalGeneration,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
)

# --- Translation model (Hindi <-> English) -----------------------------------
translate_model_name = "facebook/mbart-large-50-many-to-many-mmt"
translate_model = MBartForConditionalGeneration.from_pretrained(translate_model_name)
tokenizer = MBart50TokenizerFast.from_pretrained(translate_model_name)

# --- Vision-language model ----------------------------------------------------
# Load the weights 4-bit quantized. Compute in float16 so the matmul dtype
# matches torch_dtype below (bitsandbytes otherwise defaults to float32).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Fine-tuned LLaVA-1.5-7B checkpoint on the Hugging Face Hub. Judging by its
# name the fine-tuning weights are already merged in, so it is loaded directly.
model_id = "somnathsingh31/llava-1.5-7b-hf-ft-merged_model"

merged_model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)

# The fine-tuned checkpoint reuses the base model's processor
# (text tokenizer + image preprocessing).
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Seconds to wait when downloading an image from a URL.
IMAGE_DOWNLOAD_TIMEOUT = 30


def translate(text, source_lang, target_lang):
    """Translate *text* between two mBART-50 language codes.

    Args:
        text: Text to translate. ``None`` or blank text yields ``""``.
        source_lang: mBART-50 code of the input language, e.g. ``"hi_IN"``.
        target_lang: mBART-50 code of the output language, e.g. ``"en_XX"``.

    Returns:
        The translated text.

    Raises:
        ValueError: If *target_lang* is not a token known to the tokenizer.
    """
    if not text or not text.strip():
        return ""

    # The tokenizer prefixes the input with the source-language token.
    tokenizer.src_lang = source_lang
    encoded_text = tokenizer(text, return_tensors="pt")

    # Force the decoder to start with the target-language token so mBART
    # generates in that language.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)
    if forced_bos_token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown mBART-50 language code: {target_lang!r}")

    generated_tokens = translate_model.generate(
        **encoded_text, forced_bos_token_id=forced_bos_token_id
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]


def _load_image(image):
    """Return a PIL image from a PIL image, file-like object, http(s) URL, or path."""
    if isinstance(image, Image.Image):
        # Gradio's Image component with type="pil" passes this directly.
        return image
    if hasattr(image, "read"):
        return Image.open(image)
    if isinstance(image, str) and image.startswith(("http://", "https://")):
        response = requests.get(image, timeout=IMAGE_DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        # Read the decoded body; response.raw bypasses gzip/deflate decoding.
        return Image.open(io.BytesIO(response.content))
    return Image.open(image)


def ask_vlm(hindi_input_text, image):
    """Answer a Hindi question about an image, replying in Hindi.

    Args:
        hindi_input_text: The question, in Hindi.
        image: A PIL image, a file-like object, an http(s) URL, or a local path.

    Returns:
        The model's answer translated to Hindi ("" if the model produced
        no text).

    Raises:
        gr.Error: If the image or the question is missing.
    """
    if image is None:
        raise gr.Error("Please provide an image.")
    if not hindi_input_text or not hindi_input_text.strip():
        raise gr.Error("Please enter a question.")

    prompt_eng = translate(hindi_input_text, "hi_IN", "en_XX")
    # LLaVA-1.5 chat format; <image> marks where the image features go.
    prompt = "USER: <image>\n" + prompt_eng + " ASSISTANT:"

    pil_image = _load_image(image).convert("RGB")

    # Move tensors to the model's device; only floating-point tensors
    # (pixel_values) are cast to float16, token ids keep their integer dtype.
    inputs = processor(text=prompt, images=pil_image, return_tensors="pt").to(
        merged_model.device, torch.float16
    )
    generate_ids = merged_model.generate(**inputs, max_new_tokens=250)

    # generate() returns the prompt followed by the continuation; decode only
    # the newly generated tokens so no "ASSISTANT:" search is needed.
    prompt_length = inputs["input_ids"].shape[1]
    english_answer = processor.batch_decode(
        generate_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0].strip()

    return translate(english_answer, "en_XX", "hi_IN")


# --- Gradio interface ----------------------------------------------------------
input_image = gr.Image(type="pil", label="Input Image (Upload or URL)")
input_question = gr.Textbox(lines=2, label="Question (Hindi)")
output_text = gr.Textbox(label="Response (Hindi)")

demo = gr.Interface(
    fn=ask_vlm,
    inputs=[input_question, input_image],
    outputs=output_text,
    title="Image and Text-based Dialogue System",
    description="Enter a question in Hindi and an image, either by uploading or providing URL, and get a response in Hindi.",
)


if __name__ == "__main__":
    # Launched only when run as a script, so importing this module does not
    # start (and block on) a web server. launch() blocks until shutdown.
    demo.launch()

    # Sample query; runs after the web UI has been shut down.
    image_url = "https://images.metmuseum.org/CRDImages/ad/original/138425.jpg"
    user_query = "यह किस प्रकार की कला है? विस्तार से बताइये"
    output = ask_vlm(user_query, image_url)
    print("Output:\n", output)