Commit · 88a6c27
1 Parent(s): 678a3bb
Update preference_technique
Browse files
- app.py +2 -2
- chat_interface_preference.py +13 -16
app.py
CHANGED
@@ -70,10 +70,10 @@ def generate(
|
|
70 |
|
71 |
chat_interface = ChatInterface(
|
72 |
fn=generate,
|
73 |
-
|
74 |
min_turns=1,
|
75 |
max_turns=10,
|
76 |
-
repo_id="llm-human-feedback-collector-chat-interface-
|
77 |
chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
|
78 |
cache_examples=False,
|
79 |
additional_inputs=[
|
|
|
70 |
|
71 |
chat_interface = ChatInterface(
|
72 |
fn=generate,
|
73 |
+
prefence_technique="kto",
|
74 |
min_turns=1,
|
75 |
max_turns=10,
|
76 |
+
repo_id="llm-human-feedback-collector-chat-interface-kto",
|
77 |
chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
|
78 |
cache_examples=False,
|
79 |
additional_inputs=[
|
chat_interface_preference.py
CHANGED
@@ -11,7 +11,7 @@ import json
|
|
11 |
import random
|
12 |
import re
|
13 |
import uuid
|
14 |
-
from typing import AsyncGenerator, Callable,
|
15 |
|
16 |
import anyio
|
17 |
from gradio.blocks import Blocks
|
@@ -27,8 +27,9 @@ from gradio.components import (
|
|
27 |
get_component_instance,
|
28 |
)
|
29 |
from gradio.events import Dependency, on
|
30 |
-
from gradio.helpers import Error, Info
|
31 |
from gradio.helpers import create_examples as Examples # noqa: N812
|
|
|
32 |
from gradio.layouts import Accordion, Group, Row
|
33 |
from gradio.routes import Request
|
34 |
from gradio.themes import ThemeClass as Theme
|
@@ -65,7 +66,7 @@ class ChatInterface(Blocks):
|
|
65 |
self,
|
66 |
fn: Callable,
|
67 |
*,
|
68 |
-
|
69 |
min_turns: int = 1,
|
70 |
max_turns: int = 1,
|
71 |
repo_id: None | str,
|
@@ -126,14 +127,9 @@ class ChatInterface(Blocks):
|
|
126 |
raise ValueError("`max_turns` should be larger than `min_turns`")
|
127 |
self.max_turns = max_turns
|
128 |
self.min_turns = min_turns
|
129 |
-
if isinstance(prefence_techniques, str):
|
130 |
-
prefence_techniques = [prefence_techniques]
|
131 |
-
elif prefence_techniques is None:
|
132 |
-
prefence_techniques = ["sft"]
|
133 |
-
self.prefence_techniques = [technique.lower() for technique in prefence_techniques]
|
134 |
|
135 |
optional_techniques = ["kto", "sft", "spin", "dpo", "simpo", "rlhf", "orpo"]
|
136 |
-
if
|
137 |
raise ValueError(f"Supported techniques are {optional_techniques}")
|
138 |
submit_btn_one = "Generate"
|
139 |
submit_btn_two = None
|
@@ -145,11 +141,12 @@ class ChatInterface(Blocks):
|
|
145 |
stop_btn = "Stop"
|
146 |
undo_btn = "↩️ Undo"
|
147 |
clear_btn = "🗑️ Clear"
|
148 |
-
|
|
|
149 |
submit_btn_good = "The response 👍"
|
150 |
submit_btn_bad = "The response 👎"
|
151 |
-
if
|
152 |
-
|
153 |
submit_btn_a = "A is better than B"
|
154 |
submit_btn_b = "B is better than A"
|
155 |
submit_btn_ab = "A and B are similar"
|
@@ -367,7 +364,7 @@ class ChatInterface(Blocks):
|
|
367 |
self.saved_input = State()
|
368 |
self.chatbot_state = State(self.chatbot.value) if self.chatbot.value else State([])
|
369 |
|
370 |
-
self._setup_events()
|
371 |
self._setup_api()
|
372 |
|
373 |
def _set_conversation_id(self):
|
@@ -384,15 +381,15 @@ class ChatInterface(Blocks):
|
|
384 |
with self.data_file.open("a") as f:
|
385 |
f.write(json.dumps(feedback))
|
386 |
|
387 |
-
def _setup_events(self) -> None:
|
388 |
submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
|
389 |
-
submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=
|
390 |
submit_triggers_one = (
|
391 |
[self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
|
392 |
)
|
393 |
submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
|
394 |
if self.submit_btn_two:
|
395 |
-
submit_fn_two = functools.partial(submit_fn_one, n_generations=
|
396 |
submit_triggers_two = [self.submit_btn_two.click]
|
397 |
submit_tuples.append((submit_fn_two, submit_triggers_two))
|
398 |
for _fn, _triggers in submit_tuples:
|
|
|
11 |
import random
|
12 |
import re
|
13 |
import uuid
|
14 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
15 |
|
16 |
import anyio
|
17 |
from gradio.blocks import Blocks
|
|
|
27 |
get_component_instance,
|
28 |
)
|
29 |
from gradio.events import Dependency, on
|
30 |
+
from gradio.helpers import Error, Info
|
31 |
from gradio.helpers import create_examples as Examples # noqa: N812
|
32 |
+
from gradio.helpers import special_args
|
33 |
from gradio.layouts import Accordion, Group, Row
|
34 |
from gradio.routes import Request
|
35 |
from gradio.themes import ThemeClass as Theme
|
|
|
66 |
self,
|
67 |
fn: Callable,
|
68 |
*,
|
69 |
+
prefence_technique: str = None,
|
70 |
min_turns: int = 1,
|
71 |
max_turns: int = 1,
|
72 |
repo_id: None | str,
|
|
|
127 |
raise ValueError("`max_turns` should be larger than `min_turns`")
|
128 |
self.max_turns = max_turns
|
129 |
self.min_turns = min_turns
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
optional_techniques = ["kto", "sft", "spin", "dpo", "simpo", "rlhf", "orpo"]
|
132 |
+
if prefence_technique not in optional_techniques:
|
133 |
raise ValueError(f"Supported techniques are {optional_techniques}")
|
134 |
submit_btn_one = "Generate"
|
135 |
submit_btn_two = None
|
|
|
141 |
stop_btn = "Stop"
|
142 |
undo_btn = "↩️ Undo"
|
143 |
clear_btn = "🗑️ Clear"
|
144 |
+
n_generations = 1
|
145 |
+
if "kto" == prefence_technique:
|
146 |
submit_btn_good = "The response 👍"
|
147 |
submit_btn_bad = "The response 👎"
|
148 |
+
if prefence_technique in ["dpo", "simpo", "rlhf", "orpo"]:
|
149 |
+
n_generations = 2
|
150 |
submit_btn_a = "A is better than B"
|
151 |
submit_btn_b = "B is better than A"
|
152 |
submit_btn_ab = "A and B are similar"
|
|
|
364 |
self.saved_input = State()
|
365 |
self.chatbot_state = State(self.chatbot.value) if self.chatbot.value else State([])
|
366 |
|
367 |
+
self._setup_events(n_generations)
|
368 |
self._setup_api()
|
369 |
|
370 |
def _set_conversation_id(self):
|
|
|
381 |
with self.data_file.open("a") as f:
|
382 |
f.write(json.dumps(feedback))
|
383 |
|
384 |
+
def _setup_events(self, n_generations) -> None:
|
385 |
submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
|
386 |
+
submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=n_generations)
|
387 |
submit_triggers_one = (
|
388 |
[self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
|
389 |
)
|
390 |
submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
|
391 |
if self.submit_btn_two:
|
392 |
+
submit_fn_two = functools.partial(submit_fn_one, n_generations=n_generations)
|
393 |
submit_triggers_two = [self.submit_btn_two.click]
|
394 |
submit_tuples.append((submit_fn_two, submit_triggers_two))
|
395 |
for _fn, _triggers in submit_tuples:
|