"""Gradio demo comparing image captions from two GIT-base checkpoints.

Each uploaded image is captioned by the pre-trained ``microsoft/git-base``
model and by the fine-tuned ``wangjin2000/git-base-finetune`` model; every
English caption is also passed through an English-to-Chinese translation
pipeline.
"""
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: both checkpoints are loaded anonymously. An earlier revision passed
# `use_auth_token=access_token` to `from_pretrained`; restore an equivalent
# argument if a checkpoint becomes private.

# Inference with the pre-trained model
checkpoint1 = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint1)
# The model must live on the same device as the tensors img2cap_com feeds it;
# without `.to(device)` a CUDA host would mix GPU inputs with CPU weights.
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)

# Inference with the fine-tuned model
checkpoint2 = "wangjin2000/git-base-finetune"
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)

en_zh_translator = pipeline("translation", src_lang="en", tgt_lang="zh")


def _generate_caption(model, pixel_values):
    """Return the first decoded caption `model` generates for `pixel_values`."""
    generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def img2cap_com(image):
    """Caption `image` with both models and translate each caption.

    Args:
        image: The image handed over by the Gradio input component
            (configured below with ``type="pil"``).

    Returns:
        A 2-tuple ``(pretrained_result, finetuned_result)``. Each result is a
        two-element list ``[english_caption, translator_output]`` where
        ``translator_output`` is whatever ``en_zh_translator`` returns for
        that caption, passed through unchanged.
    """
    # Both models share the same processor, so the image is preprocessed once
    # and the resulting tensor is reused for both generations.
    pixel_values = processor(images=image, return_tensors="pt").to(device).pixel_values

    results = []
    for model in (model1, model2):
        caption = _generate_caption(model, pixel_values)
        results.append([caption, en_zh_translator(caption)])
    return tuple(results)


# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces are legacy Gradio
# API -- confirm they exist in the Gradio version pinned for this app.
inputs = [
    gr.inputs.Image(type="pil", label="Original Image"),
]

outputs = [
    gr.outputs.Textbox(label="Caption from pre-trained model"),
    gr.outputs.Textbox(label="Caption from fine-tuned model"),
]

title = "Image Captioning using Pre-trained and Fine-tuned Model"
description = "GIT-base is used to generate Image Caption for the uploaded image."
# Example rows handed to the interface; each inner list supplies the value
# for the single image input.
examples = [
    ["Image1.png"],
    ["Image2.png"],
    ["Image3.png"],
    ["Image4.png"],
    ["Image5.png"],
    ["Image6.png"],
]

# Build the interface around img2cap_com, then start serving it.
demo = gr.Interface(
    fn=img2cap_com,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
    theme="huggingface",
)
demo.launch()