File size: 3,902 Bytes
17c016d
 
 
 
 
 
 
 
a3a6018
17c016d
 
 
 
 
 
 
 
 
 
 
 
68e19a9
17c016d
a3a6018
 
 
 
 
 
17c016d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59d9429
17c016d
59d9429
17c016d
 
 
 
 
 
6abafb0
 
59d9429
 
 
6abafb0
 
 
 
59d9429
6abafb0
 
 
17c016d
 
 
6abafb0
 
 
 
 
 
 
 
 
17c016d
6abafb0
 
 
 
17c016d
 
 
 
 
 
 
a3a6018
17c016d
 
 
 
6abafb0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
from transformers import pipeline
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import requests
import os
import random


# Run CLIP on the GPU when one is available; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# CLIP is the judge: it scores how well each caption matches the image.
# Any CLIP checkpoint from the Hugging Face Hub can be substituted here.
model_id = "openai/clip-vit-base-patch16"
clipmodel = CLIPModel.from_pretrained(model_id).to(device)
clipprocessor = CLIPProcessor.from_pretrained(model_id)

# BLIP plays "Player 2" and writes the machine caption. Note that `model_id`
# is rebound here, and the BLIP model is not moved to `device`.
model_id = "Salesforce/blip-image-captioning-base"
blipprocessor = BlipProcessor.from_pretrained(model_id)
blipmodel = BlipForConditionalGeneration.from_pretrained(model_id)

# Folder of candidate target images, resolved against the working directory.
im_dir = os.path.join(os.getcwd(), "images")

def sample_image(im_dir=im_dir):
  """Pick a random image from ``im_dir`` and return refreshed UI components.

  Args:
    im_dir: Directory to sample from; defaults to the module-level image folder.

  Returns:
    A ``(gr.Image, gr.Textbox)`` pair: the image component showing the newly
    chosen file, and the hidden textbox holding its filename (which is
    flagged together with the captions).

  Raises:
    IndexError: If ``im_dir`` contains no non-hidden entries.
  """
  # Skip hidden entries (e.g. ``.DS_Store``, ``.gitkeep``) so they are never
  # handed to the image component as a target image.
  all_ims = [name for name in os.listdir(im_dir) if not name.startswith('.')]
  new_im = random.choice(all_ims)
  # Use the same height (400) as the image created when the UI is built, so
  # the layout does not jump when the player clicks "Next Image".
  return (
      gr.Image(label="Target Image", interactive=False, type="pil",
               value=os.path.join(im_dir, new_im), height=400),
      gr.Textbox(label="Image fname", value=new_im, interactive=False, visible=False),
  )


def evaluate_caption(image, caption):
    """Generate a BLIP caption for ``image`` and let CLIP judge the two captions.

    Args:
        image: The target image from the ``gr.Image(type="pil")`` component.
        caption: The human player's caption (Player 1).

    Returns:
        A ``(blip_caption, winner)`` tuple: the machine caption (Player 2) and
        the verdict string, either ``"Player 1 wins!"`` or ``"Player 2 wins!"``.

    Raises:
        gr.Error: If the player submitted an empty caption.
    """
    if not caption or not caption.strip():
        # gr.Error is shown to the user in the UI. Because the event fails,
        # the chained ``.success`` flagging step does not run, so empty rounds
        # are not saved.
        raise gr.Error("Please enter a caption before submitting.")

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Player 2: BLIP writes its own caption. Inputs go to whatever device
        # the BLIP model lives on.
        blip_input = blipprocessor(image, return_tensors="pt").to(blipmodel.device)
        out = blipmodel.generate(**blip_input, max_new_tokens=50)
        blip_caption = blipprocessor.decode(out[0], skip_special_tokens=True)

        # Judge: CLIP scores both captions against the image. The inputs must
        # be on the same device as ``clipmodel`` (moved to ``device`` at load
        # time); otherwise the forward pass fails when running on CUDA.
        inputs = clipprocessor(text=[caption, blip_caption], images=image,
                               return_tensors="pt", padding=True).to(clipmodel.device)
        logits_per_image = clipmodel(**inputs).logits_per_image

    # One image, so row 0; columns follow the text order [Player 1, Player 2].
    # Move to CPU before converting, since .numpy() rejects CUDA tensors.
    score = logits_per_image.softmax(dim=1).cpu().numpy()
    print(score)
    # Strict ">" means a tie goes to Player 2.
    winner = "Player 1 wins!" if score[0][0] > score[0][1] else "Player 2 wins!"
    return blip_caption, winner

# Saves flagged rounds to the "gradioTest" dataset on the Hugging Face Hub.
# The access token comes from the HF_TOKEN environment variable (e.g. a Space
# secret) rather than being hard-coded: a token committed to source is
# readable by anyone with repository access, so any token that was ever
# committed should be treated as leaked and revoked. If HF_TOKEN is unset,
# None is passed as the token.
callback = gr.HuggingFaceDatasetSaver(os.environ.get("HF_TOKEN"), "gradioTest")

# Build the game UI: the target image, Player 1's caption box, BLIP's caption
# and the judge's verdict. Each successful round is flagged via `callback`.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
  gr.Markdown(
    """
    # Welcome to our Human vs. AI game!

   You and an AI agent are trying to convince a third AI agent that each of you is better at describing the visual world. \n
   In order to win, describe this image in one sentence. Then the second AI agent will also generate a description and the third agent will decide a winner.
   You win if the AI says that "Player 1 wins!"
    """)
  # The first target image is picked once, when the UI is built;
  # "Next Image" draws a new one via `sample_image`.
  im_path_str = random.choice(os.listdir(im_dir))

  # Hidden field carrying the current image filename so that it is flagged
  # together with the captions and the verdict.
  im_path = gr.Textbox(label="Image fname",value=im_path_str,interactive=False, visible=False)

  with gr.Row():
    im = gr.Image(label="Target Image", interactive = False, type="pil",value =os.path.join(im_dir,im_path_str),height=400)
    with gr.Column():
      caps = gr.Textbox(label="Player 1 Caption")
      submit_btn = gr.Button("Submit!!")
      out1 = gr.Textbox(label="Player 2 (Machine) Caption",interactive=False)

  with gr.Row():
    with gr.Column():
      out2 = gr.Textbox(label="Winner",interactive=False)
      reload_btn = gr.Button("Next Image")

  # Flagged rows contain: player caption, machine caption, verdict, image filename.
  callback.setup([caps, out1, out2, im_path], "flagged_data_points")
  # Run a round; the result is flagged only if evaluate_caption completes
  # without raising (that is what `.success` guarantees).
  submit_btn.click(fn = evaluate_caption,inputs = [im,caps], outputs = [out1, out2],api_name="test").success(lambda *args: callback.flag(args), [caps, out1, out2, im_path], None, preprocess=False)
  reload_btn.click(fn = sample_image, inputs=None, outputs = [im,im_path] )

demo.launch()