"""Gradio demo: visual question answering with VLE, refined by an LLM.

For an uploaded image and a question the UI shows three answers:
  * "VQA "    - the top answer from the VLE VQA pipeline,
  * "VQA+LLM" - a GPT-3 answer prompted with the BLIP caption, the question
                and the VLE candidate answers,
  * "CAP+LLM" - a GPT-3 answer prompted with only the BLIP caption and the
                question.
"""

import os
import string  # noqa: F401  (unused; kept from the original file)

import gradio as gr
import openai
import requests  # noqa: F401  (unused; kept from the original file)
import torch
from PIL import Image  # noqa: F401  (unused; kept from the original file)
from transformers import (
    BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    BlipProcessor,
)

from models.VLE import VLEForVQA, VLEForVQAPipeline, VLEProcessor

# --- VLE visual question answering (pipeline pinned to CPU) ---
model_name = "hfl/vle-base-for-vqa"
model = VLEForVQA.from_pretrained(model_name)
vle_processor = VLEProcessor.from_pretrained(model_name)
vqa_pipeline = VLEForVQAPipeline(model=model, device='cpu', vle_processor=vle_processor)

# --- BLIP VQA model ---
# NOTE(review): `processor` and `model_vqa` are not used by the current
# inference path; they are still loaded so these module-level names exist.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model_vqa = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large").to(device)

# --- BLIP image captioning ---
cap_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
cap_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")


def caption(input_image):
    """Return a BLIP caption for *input_image*.

    Generation is requested with a single beam and a single returned
    sequence; decoded sequences are joined with newlines.
    """
    inputs = cap_processor(input_image, return_tensors="pt")
    inputs["num_beams"] = 1
    inputs['num_return_sequences'] = 1
    out = cap_model.generate(**inputs)
    return "\n".join(cap_processor.batch_decode(out, skip_special_tokens=True))


# The OpenAI key is read from the `openai_appkey` environment variable.
openai.api_key = os.getenv('openai_appkey')


def gpt3(question, vqa_answer, caption):
    """Ask GPT-3 (text-davinci-003) to answer *question* about an image.

    Args:
        question: the user's question.
        vqa_answer: newline-separated candidate answers from the VQA model,
            or '' to answer from the caption alone.
        caption: text describing the image. (This parameter shadows the
            module-level ``caption`` function inside this body; the name is
            kept for interface compatibility.)

    Returns:
        The whitespace-stripped text of the single completion requested
        (limited to 10 tokens).
    """
    prompt = caption + "\n" + question + "\n" + vqa_answer + "\n Tell me the right answer."
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=10,
        n=1,
        stop=None,
        temperature=0.7,
    )
    return response.choices[0].text.strip()


def vle(input_image, input_text):
    """Return the raw VLE pipeline result for the top 4 candidate answers."""
    # Bug fix: the question was previously passed as `input_image`.
    vqa_answers = vqa_pipeline(image=input_image, question=input_text, top_k=4)
    return vqa_answers


def _answer_text(item):
    """Return the answer string from one VQA pipeline result item."""
    # NOTE(review): the item format is not visible here. Hugging Face-style
    # pipelines commonly return dicts with an "answer" key, so extract it
    # when present and fall back to str() otherwise.
    if isinstance(item, dict) and "answer" in item:
        return str(item["answer"])
    return str(item)


def inference_chat(input_image, input_text):
    """Answer *input_text* about *input_image* in three ways for the UI.

    Returns:
        A 3-tuple ``(vqa_top_answer, vqa_plus_llm, caption_plus_llm)`` that
        matches the outputs wired to the submit handlers below. When no
        image or no question is provided, three empty strings are returned
        without calling any model.
    """
    if input_image is None or not (input_text or "").strip():
        return "", "", ""
    cap = caption(input_image)
    answers = [_answer_text(a) for a in vle(input_image, input_text)]
    vqa = "\n".join(answers)
    vqa_llm = gpt3(input_text, vqa, cap)
    cap_llm = gpt3(input_text, '', cap)
    top_answer = answers[0] if answers else ""
    return top_answer, vqa_llm, cap_llm


with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="VQA Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Question Input")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="VQA", interactive=True, variant="primary"
                        )
        with gr.Column():
            caption_output_v1 = gr.Textbox(lines=0, label="CAP+LLM")
            caption_output = gr.Textbox(lines=0, label="VQA ")
            gpt3_output_v1 = gr.Textbox(lines=0, label="VQA+LLM")

    # Outputs of inference_chat, in the order of its returned tuple.
    vqa_outputs = [caption_output, gpt3_output_v1, caption_output_v1]

    # A new image invalidates previous answers. The question box is left
    # untouched so a question typed (or loaded from an example) is kept.
    # The outputs list previously named caption_output twice.
    image_input.change(
        lambda: ([], "", "", ""),
        [],
        [state, caption_output, gpt3_output_v1, caption_output_v1],
        queue=False,
    )
    # Pressing Enter behaves like the VQA button. inference_chat returns three
    # values, so all three outputs must be wired (previously only one was).
    chat_input.submit(inference_chat, [image_input, chat_input], vqa_outputs)
    clear_button.click(
        lambda: ("", [], "", "", ""),
        [],
        [chat_input, state, caption_output, gpt3_output_v1, caption_output_v1],
        queue=False,
    )
    submit_button.click(inference_chat, [image_input, chat_input], vqa_outputs)

    example_rows = [['bird.jpeg', "How many birds are there in the tree?"]]
    examples = gr.Examples(
        examples=example_rows,
        inputs=[image_input, chat_input],
    )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch(enable_queue=True)