UPDATE ARENA
app.py
CHANGED
@@ -1,19 +1,45 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
+from collections import defaultdict, Counter
+import random
+import threading
+import time
+import os
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+# Define the list of LLMs to compare
+llm_list = [
+    "HuggingFaceH4/zephyr-7b-beta",
+    "AXCXEPT/EZO-Common-9B-gemma-2-it",
+    # "model name 1",
+    # "model name 2",
+]
 
+# Track how many times each LLM has been selected
+llm_counts = defaultdict(int)
 
-def respond(
+# Create an InferenceClient for each LLM
+clients = {llm: InferenceClient(llm) for llm in llm_list}
+
+# Select two LLMs at random, keeping exposure roughly equal
+def select_llms():
+    min_count = min(llm_counts.values()) if llm_counts else 0
+    candidates = [llm for llm in llm_list if llm_counts[llm] == min_count]
+    if len(candidates) < 2:
+        candidates = llm_list
+    selected_llms = random.sample(candidates, 2)
+    for llm in selected_llms:
+        llm_counts[llm] += 1
+    return selected_llms
+
+# Generate a streaming response from one LLM client
+def respond_llm(
     message,
     history: list[tuple[str, str]],
    system_message,
     max_tokens,
     temperature,
     top_p,
+    llm_client,
 ):
     messages = [{"role": "system", "content": system_message}]
 
@@ -27,38 +53,110 @@ def respond(
 
     response = ""
 
-    for message in client.chat_completion(
+    for message in llm_client.chat_completion(
         messages,
         max_tokens=max_tokens,
         stream=True,
         temperature=temperature,
         top_p=top_p,
     ):
-        token = message.choices[0].delta.content
-
+        token = message.choices[0].delta.get("content", "")
         response += token
         yield response
 
+# Path of the file where votes are stored
+VOTE_FILE = "votes.txt"
+
+# Save a vote to the vote file
+def save_vote(selected_llm):
+    # Append the vote to the file
+    with open(VOTE_FILE, "a") as f:
+        f.write(f"{selected_llm}\n")
+    return gr.update(visible=True, value="Thank you for voting!")
+
+# Rebuild the leaderboard from the saved votes
+def update_leaderboard():
+    try:
+        with open(VOTE_FILE, "r") as f:
+            votes = f.readlines()
+        vote_counts = Counter(vote.strip() for vote in votes)
+        leaderboard = sorted(vote_counts.items(), key=lambda x: x[1], reverse=True)
+        leaderboard_text = "## Leaderboard\n\n"
+        for llm, count in leaderboard:
+            leaderboard_text += f"- {llm}: {count} votes\n"
+    except FileNotFoundError:
+        leaderboard_text = "No votes yet."
+    return leaderboard_text
+
+# Build the Gradio interface
+def chat_interface():
+    llm1, llm2 = select_llms()
+    client1 = clients[llm1]
+    client2 = clients[llm2]
+
+    with gr.Blocks() as demo:
+        gr.Markdown("## LLM Comparison Arena")
+
+        with gr.Row():
+            gr.Markdown(f"### LLM1: {llm1}")
+            gr.Markdown(f"### LLM2: {llm2}")
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
+        with gr.Row():
+            with gr.Column():
+                chat1 = gr.ChatInterface(
+                    lambda message, history, system_message, max_tokens, temperature, top_p:
+                        respond_llm(message, history, system_message, max_tokens, temperature, top_p, client1),
+                    additional_inputs=[
+                        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
+                        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
+                        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+                        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
+                    ],
+                )
+            with gr.Column():
+                chat2 = gr.ChatInterface(
+                    lambda message, history, system_message, max_tokens, temperature, top_p:
+                        respond_llm(message, history, system_message, max_tokens, temperature, top_p, client2),
+                    additional_inputs=[
+                        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
+                        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
+                        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+                        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
+                    ],
+                )
 
+        # Voting section
+        with gr.Row():
+            vote = gr.Radio([llm1, llm2], label="Which response was better?")
+            submit = gr.Button("Vote")
+            result = gr.Textbox(label="", visible=False)
+
+        submit.click(save_vote, inputs=vote, outputs=result)
+
+        # Leaderboard display
+        leaderboard = gr.Markdown(update_leaderboard())
+
+    return demo
+
+# Periodically refresh the leaderboard in a background thread
+def refresh_leaderboard(leaderboard_component):
+    while True:
+        leaderboard_text = update_leaderboard()
+        leaderboard_component.value = leaderboard_text
+        time.sleep(60)  # refresh every 60 seconds
 
 if __name__ == "__main__":
+    demo = chat_interface()
+
+    # Find the leaderboard Markdown component
+    leaderboard_component = None
+    for component in demo.blocks.values():
+        if isinstance(component, gr.Markdown) and "Leaderboard" in component.value:
+            leaderboard_component = component
+            break
+
+    # Start the leaderboard refresh thread
+    if leaderboard_component:
+        threading.Thread(target=refresh_leaderboard, args=(leaderboard_component,), daemon=True).start()
+
     demo.launch()
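A note on the streaming loop: what chat_completion(..., stream=True) yields has changed across huggingface_hub releases. Older versions produce dict-style deltas, where the commit's message.choices[0].delta.get("content", "") works; newer versions yield dataclass objects, where .get() raises AttributeError and the equivalent guard is delta.content or "". A minimal version-tolerant sketch (the extract_token helper is hypothetical, not part of the commit):

def extract_token(chunk):
    # Hypothetical helper: read one streamed token whether the delta
    # is a dict (older huggingface_hub) or a dataclass (newer).
    delta = chunk.choices[0].delta
    if isinstance(delta, dict):
        return delta.get("content") or ""
    return getattr(delta, "content", None) or ""

Inside the loop, token = extract_token(message) would then replace the direct attribute access.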
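The selection logic in select_llms prefers the least-shown models, but whenever fewer than two models are tied at the minimum count it falls back to sampling from the full list, so exposure is balanced only approximately, not strictly. A stand-alone sketch of the same logic with placeholder model names, runnable without any API access:

import random
from collections import defaultdict

llm_list = ["model-a", "model-b", "model-c"]  # placeholder names
llm_counts = defaultdict(int)

def select_llms():
    # Same logic as the commit's select_llms()
    min_count = min(llm_counts.values()) if llm_counts else 0
    candidates = [llm for llm in llm_list if llm_counts[llm] == min_count]
    if len(candidates) < 2:
        candidates = llm_list
    selected = random.sample(candidates, 2)
    for llm in selected:
        llm_counts[llm] += 1
    return selected

for _ in range(3000):
    select_llms()
print(dict(llm_counts))  # roughly even, e.g. {'model-a': 2004, 'model-b': 1993, 'model-c': 2003}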
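The vote store is a plain append-only text file with one model id per line; every leaderboard rebuild re-reads the file and tallies it with Counter. A stand-alone round-trip of the same helpers against a temporary file (placeholder votes):

import os
import tempfile
from collections import Counter

VOTE_FILE = os.path.join(tempfile.mkdtemp(), "votes.txt")

def save_vote(selected_llm):
    # Append one vote per line, as in the commit
    with open(VOTE_FILE, "a") as f:
        f.write(f"{selected_llm}\n")

def update_leaderboard():
    try:
        with open(VOTE_FILE, "r") as f:
            votes = f.readlines()
        vote_counts = Counter(vote.strip() for vote in votes)
        ranked = sorted(vote_counts.items(), key=lambda x: x[1], reverse=True)
        return "\n".join(f"- {llm}: {count} votes" for llm, count in ranked)
    except FileNotFoundError:
        return "No votes yet."

for v in ["model-a", "model-b", "model-a"]:
    save_vote(v)
print(update_leaderboard())
# - model-a: 2 votes
# - model-b: 1 votes

Single-line appends are close to atomic on most platforms, but wrapping save_vote in a threading.Lock would make concurrent votes from multiple sessions strictly safe.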
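One caveat with the refresh thread: assigning to leaderboard_component.value changes the component's server-side default (what a newly opened page sees), but it does not push an update to browsers that are already connected. A sketch of a polling alternative, assuming Gradio 3.x/4.x where event listeners accept an every= interval in seconds (Gradio 5 replaces this with gr.Timer); update_leaderboard here stands in for the commit's function:

import gradio as gr

def update_leaderboard():
    # Stand-in for the commit's update_leaderboard()
    return "## Leaderboard\n\n- model-a: 2 votes"

with gr.Blocks() as demo:
    leaderboard = gr.Markdown(update_leaderboard())
    # Re-run update_leaderboard every 60 s and push the result
    # to connected clients.
    demo.load(update_leaderboard, None, leaderboard, every=60)

if __name__ == "__main__":
    demo.queue()  # `every=` requires the queue in older Gradio versions
    demo.launch()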