File size: 4,380 Bytes
804a590
 
 
 
 
 
 
 
 
 
 
073bbfa
 
804a590
 
073bbfa
804a590
 
073bbfa
 
 
 
 
804a590
 
 
 
 
 
 
 
073bbfa
14fb0e4
 
 
 
 
804a590
14fb0e4
 
 
804a590
 
14fb0e4
804a590
073bbfa
804a590
 
073bbfa
804a590
 
073bbfa
 
804a590
 
 
 
 
 
 
073bbfa
804a590
 
 
 
073bbfa
 
 
14fb0e4
073bbfa
 
804a590
 
073bbfa
 
 
 
 
804a590
9c4906e
073bbfa
804a590
14fb0e4
804a590
 
 
073bbfa
804a590
073bbfa
804a590
 
 
 
073bbfa
14fb0e4
 
 
 
 
 
 
0b93aab
 
 
 
804a590
 
 
 
14fb0e4
804a590
073bbfa
 
 
804a590
073bbfa
 
 
1a0adff
073bbfa
 
 
14fb0e4
 
804a590
 
 
 
1a0adff
073bbfa
 
 
 
 
b2e3fcf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from openai import OpenAI
import gradio as gr
import requests
from PIL import Image
import numpy as np
import ipadic
import MeCab
import difflib
import io
import os

# Shared OpenAI client used by generate_image (DALL-E image generation) and
# calculate_similarity_score (text embeddings). The API key is read from the
# OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def generate_image(text):
    """Return the path of a 512x512 PNG generated by DALL-E 3 for *text*.

    The image is cached on disk as ``./{text}.png``. The OpenAI image API is
    only called when that file does not exist yet.

    Args:
        text: Prompt sent to the image model; also used verbatim as the file
            name of the cached image.

    Returns:
        The relative path ``./{text}.png``.

    Raises:
        requests.HTTPError: If downloading the generated image fails.
        requests.Timeout: If the download does not complete in time.
    """
    # NOTE(review): *text* comes straight from user input (see main) and is
    # used unmodified as a file name, so a prompt containing "/" or ".." can
    # point outside the working directory. update_question builds the same
    # "./{answer}.png" path, so any sanitising must change both together.
    image_path = f"./{text}.png"
    if not os.path.exists(image_path):
        response = client.images.generate(
            model="dall-e-3",
            prompt=text,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url
        # Bound the download and surface HTTP errors explicitly instead of
        # passing an error body to PIL, which would fail with an opaque error.
        download = requests.get(image_url, timeout=60)
        download.raise_for_status()
        # Close the decoded source image once the resized copy is written.
        with Image.open(io.BytesIO(download.content)) as img:
            img.resize((512, 512)).save(image_path)
    return image_path


def cos_sim(v1, v2):
    """Return the cosine similarity of vectors *v1* and *v2*.

    Computed as ``dot(v1, v2) / (|v1| * |v2|)``. No guard is applied, so a
    zero-length vector makes the denominator 0.
    """
    magnitude = np.linalg.norm(v1) * np.linalg.norm(v2)
    return np.dot(v1, v2) / magnitude


def similarity_to_score(similarity):
    """Convert a cosine similarity of two *different* texts to an integer score.

    The similarity is scaled to percent and rounded to the nearest integer.
    The result is capped at 99, because 100 is reserved for an exact match.
    """
    # Scale first, then round. The old ``int(round(s, 2) * 100)`` kept the
    # binary representation error of the 2-decimal value (0.57 * 100 ==
    # 56.99999999999999), and ``int`` then truncated it to 56.
    score = int(round(float(similarity) * 100))
    return min(score, 99)


def calculate_similarity_score(ori_text, text):
    """Score how close *text* is to *ori_text* on a 100-point scale.

    Identical strings score exactly 100 without any API call. Otherwise both
    strings are embedded with OpenAI ``text-embedding-3-small`` and their
    cosine similarity is converted by ``similarity_to_score``, capped at 99.

    Args:
        ori_text: The reference (correct) text.
        text: The text to compare against it.

    Returns:
        The integer score.
    """
    if ori_text == text:
        return 100
    response = client.embeddings.create(
        input=[ori_text, text], model="text-embedding-3-small"
    )
    similarity = cos_sim(response.data[0].embedding, response.data[1].embedding)
    return similarity_to_score(similarity)


def tokenize_text(text):
    """Split *text* into surface-form tokens using MeCab with IPAdic.

    Args:
        text: The string to tokenize.

    Returns:
        A list with the surface string of each token, in input order.
    """
    mecab = MeCab.Tagger(f"-Ochasen {ipadic.MECAB_ARGS}")
    # ChaSen output has one tab-separated line per token, followed by a final
    # "EOS" line, which is dropped. Split on "\t" rather than on any
    # whitespace: str.split() also treats U+3000 (full-width space) as a
    # separator. A token whose surface is that character (MeCab may emit one
    # for full-width spaces in the input) would otherwise lose its surface
    # field and yield a later column, or raise IndexError.
    lines = mecab.parse(text).splitlines()[:-1]
    return [line.split("\t", 1)[0] for line in lines]


def create_match_words(ori_text, text):
    """Return the tokens of *text* that also appear among the tokens of *ori_text*.

    Tokens keep the order, and any repetitions, they have in *text*.
    """
    reference_tokens = set(tokenize_text(ori_text))
    return [token for token in tokenize_text(text) if token in reference_tokens]


def create_hint_text(ori_text, text):
    """Render *ori_text* with the characters the guess is missing shown as "X".

    The two strings are aligned character by character with
    ``difflib.ndiff``:

    - characters matched in both strings are kept verbatim;
    - characters present only in *ori_text* become "X";
    - characters present only in *text* are dropped.

    Args:
        ori_text: The correct answer.
        text: The player's guess.

    Returns:
        A string as long as *ori_text*, built by the rules above.
    """
    pieces = []
    for line in difflib.ndiff(list(text), list(ori_text)):
        tag, char = line[:2], line[2:]
        if tag == "+ ":
            pieces.append("X")
        elif tag == "  ":
            # Take the character as-is. The old ``strip()`` erased matched
            # spaces (ASCII and full-width), which shifted the hint out of
            # alignment with the answer.
            pieces.append(char)
        # "- " lines (only in the guess) and "? " intraline-hint lines are
        # intentionally skipped.
    return "".join(pieces)


def update_question(option):
    """Return the cached image path for the question selected in the UI.

    Args:
        option: Question id (e.g. "Q1"). It is also the name of the
            environment variable that holds the question's answer prompt.

    Returns:
        ``"./{answer}.png"``, matching the path scheme of generate_image. Returns
        None when no option is selected or its environment variable is unset,
        so the image component is cleared instead of pointing at a missing file.
    """
    # os.getenv(None) raises TypeError, so bail out before the lookup.
    if not option:
        return None
    answer = os.getenv(option)
    if answer is None:
        return None
    return f"./{answer}.png"


def main(text, option):
    """Score a guessed prompt against the hidden answer of question *option*.

    Args:
        text: The player's guessed prompt.
        option: Selected question id. It is also the name of the environment
            variable holding that question's answer.

    Returns:
        A tuple ``(image_path, score_label, hint_text)`` for the three output
        components. The hint depends on the score:

        - below 80: the guess's tokens that also occur in the answer;
        - 80 to 99: the answer with the missing characters shown as "X";
        - 100: an empty hint.

        When the input is rejected, ``image_path`` is None, the score label is
        empty, and the hint carries the reason.
    """
    # Validate before any API call: os.getenv(None) raises TypeError when no
    # question is selected, and a blank guess would otherwise be sent to the
    # image API as the prompt.
    if not option:
        return None, "", "問題を選んでください!"
    if not text or not text.strip():
        return None, "", "画像にマッチするテキストを入力して!"

    ori_text = os.getenv(option)
    image_path = generate_image(text)
    score = calculate_similarity_score(ori_text, text)

    if score < 80:
        match_words = create_match_words(ori_text, text)
        hint_text = "一致している単語リスト: " + " ".join(match_words)
    elif score < 100:
        hint_text = "一致していない箇所: " + create_hint_text(ori_text, text)
    else:
        hint_text = ""
    return image_path, f"{score}点", hint_text


def auth(user_name, password):
    """Check login credentials against USER_NAME / PASSWORD environment variables.

    Args:
        user_name: User name entered on the login form.
        password: Password entered on the login form.

    Returns:
        True only when both values match. Returns False when they do not, or
        when either environment variable is unset.
    """
    import hmac

    expected_user = os.getenv("USER_NAME")
    expected_password = os.getenv("PASSWORD")
    if expected_user is None or expected_password is None:
        return False
    # Constant-time comparison avoids leaking how much of a credential matched
    # through response timing. Compare UTF-8 bytes, because compare_digest
    # rejects non-ASCII str arguments.
    user_ok = hmac.compare_digest(
        (user_name or "").encode("utf-8"), expected_user.encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        (password or "").encode("utf-8"), expected_password.encode("utf-8")
    )
    return user_ok and password_ok


# Question ids shown in the UI. Each id is also the name of the environment
# variable that holds that question's answer prompt.
questions = ["Q1", "Q2", "Q3"]

# Pre-generate (or reuse the cached file for) every question image, so that
# update_question only needs to return the cached "./{answer}.png" path.
for q in questions:
    answer = os.getenv(q)
    if answer is None:
        # Fail fast with a clear message instead of calling the image API with
        # prompt=None.
        raise RuntimeError(f"Environment variable {q} is not set")
    generate_image(answer)

with gr.Blocks() as demo:
    with gr.Row():
        # Left column: rules, question selector, question image, and input.
        with gr.Column():
            gr.Markdown(
                "# プロンプトを当てるゲーム \n これは表示されている画像のプロンプトを当てるゲームです。プロンプトを入力するとそれに対応した画像とスコアとヒントが表示されます。スコア100点を目指して頑張ってください! \n\nヒントは80点未満の場合は当たっている単語(順番は合っているとは限らない)、80点以上の場合は足りない文字を「X」で示した文字列を表示しています。",
            )
            # Reuse `questions` so the selectable ids always match the
            # pre-generated images.
            option = gr.components.Radio(questions, label="問題を選んでください!")
            output_title_image = gr.components.Image(type="filepath", label="お題")
            option.change(
                update_question, inputs=[option], outputs=[output_title_image]
            )

            input_text = gr.components.Textbox(
                lines=1, label="画像にマッチするテキストを入力して!"
            )
            submit_button = gr.Button("Submit!")

        # Right column: image generated from the guess, its score, and the hint.
        with gr.Column():
            output_image = gr.components.Image(type="filepath", label="生成画像")
            output_score = gr.components.Textbox(lines=1, label="スコア")
            output_hint_text = gr.components.Textbox(lines=1, label="ヒント")

    submit_button.click(
        main,
        inputs=[input_text, option],
        outputs=[output_image, output_score, output_hint_text],
    )
demo.launch(auth=auth)