# Hugging Face Space app (the deployed Space previously failed with "Runtime error").
import requests
import torch
from PIL import Image

import gradio as gr
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    LlavaForConditionalGeneration,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
)
| # Load translation model and tokenizer | |
| translate_model_name = "facebook/mbart-large-50-many-to-many-mmt" | |
| translate_model = MBartForConditionalGeneration.from_pretrained(translate_model_name) | |
| tokenizer = MBart50TokenizerFast.from_pretrained(translate_model_name) | |
| # load the base model in 4 bit quantized | |
| quantization_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| ) | |
| # finetuned model adapter path (Hugging Face Hub) | |
| model_id = 'somnathsingh31/llava-1.5-7b-hf-ft-merged_model' | |
| # merge the models | |
| merged_model = LlavaForConditionalGeneration.from_pretrained(model_id, | |
| quantization_config=quantization_config, | |
| torch_dtype=torch.float16) | |
| # create processor from base model | |
| processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") | |
def translate(text, source_lang, target_lang):
    """Translate *text* between languages with mBART-50.

    Language codes use the mBART-50 convention (e.g. "hi_IN", "en_XX").
    Returns the translated string.
    """
    # Tell the tokenizer which language the input is in.
    tokenizer.src_lang = source_lang
    model_inputs = tokenizer(text, return_tensors="pt")
    # Force the decoder to begin with the target-language token.
    target_token_id = tokenizer.lang_code_to_id[target_lang]
    output_ids = translate_model.generate(
        **model_inputs, forced_bos_token_id=target_token_id
    )
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
def ask_vlm(hindi_input_text, image):
    """Answer a Hindi question about an image using the LLaVA model.

    Translates the Hindi question to English, runs LLaVA inference, and
    translates the answer back to Hindi.

    Parameters:
        hindi_input_text: question in Hindi.
        image: a PIL.Image (what Gradio delivers with type="pil"),
               a file-like object, or an image URL string.

    Returns the model's answer in Hindi ("" if no answer was produced).
    """
    # Translate the prompt from Hindi to English for the English-only VLM.
    prompt_eng = translate(hindi_input_text, "hi_IN", "en_XX")
    prompt = "USER: <image>\n" + prompt_eng + " ASSISTANT:"
    # BUGFIX: Gradio passes a PIL.Image, which has no .read attribute and is
    # not a URL — the original fell through to requests.get(image) and crashed.
    if isinstance(image, Image.Image):
        pass  # already a decoded image
    elif hasattr(image, 'read'):
        image = Image.open(image)
    else:
        image = Image.open(requests.get(image, stream=True).raw)
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generate_ids = merged_model.generate(**inputs, max_new_tokens=250)
    decoded_response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Keep only the text after the "ASSISTANT:" marker.
    assistant_index = decoded_response.find("ASSISTANT:")
    if assistant_index == -1:
        # BUGFIX: the original set the answer to None and then passed it to
        # translate(), which raises inside the tokenizer. Return "" instead.
        return ""
    text_after_assistant = decoded_response[assistant_index + len("ASSISTANT:"):].strip()
    # Translate the English answer back to Hindi.
    hindi_output_text = translate(text_after_assistant, "en_XX", "hi_IN")
    return hindi_output_text
# BUGFIX: the gr.inputs / gr.outputs namespaces were removed in Gradio 3/4
# (the likely cause of this Space's "Runtime error"). Use the top-level
# component classes instead; arguments are otherwise unchanged.
input_image = gr.Image(type="pil", label="Input Image (Upload or URL)")
input_question = gr.Textbox(lines=2, label="Question (Hindi)")
output_text = gr.Textbox(label="Response (Hindi)")

# Create and launch the Gradio app (launch() blocks until shutdown).
gr.Interface(
    fn=ask_vlm,
    inputs=[input_question, input_image],
    outputs=output_text,
    title="Image and Text-based Dialogue System",
    description="Enter a question in Hindi and an image, either by uploading or providing URL, and get a response in Hindi.",
).launch()
| if __name__ == '__main__': | |
| image_url = 'https://images.metmuseum.org/CRDImages/ad/original/138425.jpg' | |
| user_query = 'यह किस प्रकार की कला है? विस्तार से बताइये' | |
| output = ask_vlm(user_query, image_url) | |
| print('Output:\n', output) |