# Source snapshot: file size 1,980 bytes, commit ec4a322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from secrets_key import OPENAI_KEY, RANDOM_SEED
from openai import OpenAI
import json
import pandas as pd
from pprint import pprint


# Single OpenAI client shared by every request in this script; the API key is
# imported from the local secrets_key module.
client = OpenAI(api_key=OPENAI_KEY)


# Prompt template sent as the text part of each multimodal request. `{story}` is
# filled with str.format, so any literal braces added to this text must be
# doubled ("{{" / "}}"). The wording refers to images 1, 2 and 3, which must
# match the number of images attached per request.
# NOTE(review): "mentioned on the story" looks like a typo for "mentioned in the
# story"; left untouched so prompts stay identical to earlier runs.
# NOTE(review): the text asks to identify "a person/object" (singular) but then
# asks for "each person/object" as a list of dicts; confirm which is intended.
prompt = """

You are given a story and 3 images related to the story. Identify a person/object that can be visually identified in the images but not directly mentioned on the story. Use as few words as possible to describe each person/object. Also, mention the image number (1, 2 or 3) where the person/object can be found.
Output in a python list of dictionaries. Each dictionary should have the following keys: 'image_number', 'person/object'.


Story: {story}
"""


def get_entity_gpt4V(row, *, model="gpt-4-vision-preview", max_tokens=256):
    """Query the vision model for entities shown in a story's images.

    Fills the module-level ``prompt`` template with ``row['Input.story']``.
    It then sends one user message containing that text followed by the
    image URLs from ``row['Input.image1']``, ``row['Input.image2']`` and
    ``row['Input.image3']`` (in that order) through the module-level
    ``client``. Afterwards it prints the HIT id, the filled prompt, the
    image URLs and the model's reply.

    Args:
        row: One results record supporting ``row[key]`` lookups, e.g. the
            Series yielded by ``DataFrame.iterrows``. It must provide
            ``'Input.story'``, ``'Input.image1'``..``'Input.image3'`` and
            ``'HITId'``.
        model: Chat-completions model name. The default is the value this
            script originally hard-coded.
            NOTE(review): the preview model may no longer be served; pass a
            current vision-capable model if requests fail.
        max_tokens: Upper bound on the length of the reply.

    Returns:
        The reply text (``response.choices[0].message.content``), so callers
        can store or parse it. It may be ``None`` if the response carried no
        text content.
    """
    story = row['Input.story']
    now_prompt = prompt.format(story=story)

    # The prompt talks about images 1, 2 and 3, so exactly these three
    # columns are attached, in that order, after the text part.
    images = [row[f'Input.image{i}'] for i in range(1, 4)]
    content = [{"type": "text", "text": now_prompt}]
    content.extend(
        {"type": "image_url", "image_url": {"url": image_url}}
        for image_url in images
    )

    response = client.chat.completions.create(
        model=model,
        # Fixed seed so repeated runs request the same sampling.
        seed=RANDOM_SEED,
        messages=[{"role": "user", "content": content}],
        temperature=1,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    out = response.choices[0].message.content

    print(row['HITId'])
    print(now_prompt)
    pprint(images)
    print("OUTPUT:", out)
    print("====================================")
    print()
    return out

if __name__ == '__main__':
    # Preview run: query the model once per distinct item, in file order,
    # stopping after a fixed number of items. Later rows that repeat an
    # already-seen 'Input.item_id' are skipped.
    results = pd.read_csv('./results.csv')

    max_items = 10
    seen_item_ids = set()
    for _, hit_row in results.iterrows():
        item_key = hit_row['Input.item_id']
        if item_key not in seen_item_ids:
            seen_item_ids.add(item_key)
            get_entity_gpt4V(hit_row)
            # One call per newly seen id, so the set size is the call count.
            if len(seen_item_ids) == max_items:
                break