davidberenstein1957 HF staff committed on
Commit
88a6c27
·
1 Parent(s): 678a3bb

Update preference_technique

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. chat_interface_preference.py +13 -16
app.py CHANGED
@@ -70,10 +70,10 @@ def generate(
70
 
71
  chat_interface = ChatInterface(
72
  fn=generate,
73
- prefence_techniques="kto",
74
  min_turns=1,
75
  max_turns=10,
76
- repo_id="llm-human-feedback-collector-chat-interface-dpo",
77
  chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
78
  cache_examples=False,
79
  additional_inputs=[
 
70
 
71
  chat_interface = ChatInterface(
72
  fn=generate,
73
+ prefence_technique="kto",
74
  min_turns=1,
75
  max_turns=10,
76
+ repo_id="llm-human-feedback-collector-chat-interface-kto",
77
  chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
78
  cache_examples=False,
79
  additional_inputs=[
chat_interface_preference.py CHANGED
@@ -11,7 +11,7 @@ import json
11
  import random
12
  import re
13
  import uuid
14
- from typing import AsyncGenerator, Callable, List, Literal, Union, cast
15
 
16
  import anyio
17
  from gradio.blocks import Blocks
@@ -27,8 +27,9 @@ from gradio.components import (
27
  get_component_instance,
28
  )
29
  from gradio.events import Dependency, on
30
- from gradio.helpers import Error, Info, special_args
31
  from gradio.helpers import create_examples as Examples # noqa: N812
 
32
  from gradio.layouts import Accordion, Group, Row
33
  from gradio.routes import Request
34
  from gradio.themes import ThemeClass as Theme
@@ -65,7 +66,7 @@ class ChatInterface(Blocks):
65
  self,
66
  fn: Callable,
67
  *,
68
- prefence_techniques: str | List[str] | None = None,
69
  min_turns: int = 1,
70
  max_turns: int = 1,
71
  repo_id: None | str,
@@ -126,14 +127,9 @@ class ChatInterface(Blocks):
126
  raise ValueError("`max_turns` should be larger than `min_turns`")
127
  self.max_turns = max_turns
128
  self.min_turns = min_turns
129
- if isinstance(prefence_techniques, str):
130
- prefence_techniques = [prefence_techniques]
131
- elif prefence_techniques is None:
132
- prefence_techniques = ["sft"]
133
- self.prefence_techniques = [technique.lower() for technique in prefence_techniques]
134
 
135
  optional_techniques = ["kto", "sft", "spin", "dpo", "simpo", "rlhf", "orpo"]
136
- if any([technique for technique in self.prefence_techniques if technique not in optional_techniques]):
137
  raise ValueError(f"Supported techniques are {optional_techniques}")
138
  submit_btn_one = "Generate"
139
  submit_btn_two = None
@@ -145,11 +141,12 @@ class ChatInterface(Blocks):
145
  stop_btn = "Stop"
146
  undo_btn = "↩️ Undo"
147
  clear_btn = "🗑️ Clear"
148
- if "kto" in prefence_techniques:
 
149
  submit_btn_good = "The response 👍"
150
  submit_btn_bad = "The response 👎"
151
- if any([technique for technique in ["dpo", "simpo", "rlhf", "orpo"] if technique in self.prefence_techniques]):
152
- submit_btn_two = None
153
  submit_btn_a = "A is better than B"
154
  submit_btn_b = "B is better than A"
155
  submit_btn_ab = "A and B are similar"
@@ -367,7 +364,7 @@ class ChatInterface(Blocks):
367
  self.saved_input = State()
368
  self.chatbot_state = State(self.chatbot.value) if self.chatbot.value else State([])
369
 
370
- self._setup_events()
371
  self._setup_api()
372
 
373
  def _set_conversation_id(self):
@@ -384,15 +381,15 @@ class ChatInterface(Blocks):
384
  with self.data_file.open("a") as f:
385
  f.write(json.dumps(feedback))
386
 
387
- def _setup_events(self) -> None:
388
  submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
389
- submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=2)
390
  submit_triggers_one = (
391
  [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
392
  )
393
  submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
394
  if self.submit_btn_two:
395
- submit_fn_two = functools.partial(submit_fn_one, n_generations=1)
396
  submit_triggers_two = [self.submit_btn_two.click]
397
  submit_tuples.append((submit_fn_two, submit_triggers_two))
398
  for _fn, _triggers in submit_tuples:
 
11
  import random
12
  import re
13
  import uuid
14
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
15
 
16
  import anyio
17
  from gradio.blocks import Blocks
 
27
  get_component_instance,
28
  )
29
  from gradio.events import Dependency, on
30
+ from gradio.helpers import Error, Info
31
  from gradio.helpers import create_examples as Examples # noqa: N812
32
+ from gradio.helpers import special_args
33
  from gradio.layouts import Accordion, Group, Row
34
  from gradio.routes import Request
35
  from gradio.themes import ThemeClass as Theme
 
66
  self,
67
  fn: Callable,
68
  *,
69
+ prefence_technique: str = None,
70
  min_turns: int = 1,
71
  max_turns: int = 1,
72
  repo_id: None | str,
 
127
  raise ValueError("`max_turns` should be larger than `min_turns`")
128
  self.max_turns = max_turns
129
  self.min_turns = min_turns
 
 
 
 
 
130
 
131
  optional_techniques = ["kto", "sft", "spin", "dpo", "simpo", "rlhf", "orpo"]
132
+ if prefence_technique not in optional_techniques:
133
  raise ValueError(f"Supported techniques are {optional_techniques}")
134
  submit_btn_one = "Generate"
135
  submit_btn_two = None
 
141
  stop_btn = "Stop"
142
  undo_btn = "↩️ Undo"
143
  clear_btn = "🗑️ Clear"
144
+ n_generations = 1
145
+ if "kto" == prefence_technique:
146
  submit_btn_good = "The response 👍"
147
  submit_btn_bad = "The response 👎"
148
+ if prefence_technique in ["dpo", "simpo", "rlhf", "orpo"]:
149
+ n_generations = 2
150
  submit_btn_a = "A is better than B"
151
  submit_btn_b = "B is better than A"
152
  submit_btn_ab = "A and B are similar"
 
364
  self.saved_input = State()
365
  self.chatbot_state = State(self.chatbot.value) if self.chatbot.value else State([])
366
 
367
+ self._setup_events(n_generations)
368
  self._setup_api()
369
 
370
  def _set_conversation_id(self):
 
381
  with self.data_file.open("a") as f:
382
  f.write(json.dumps(feedback))
383
 
384
+ def _setup_events(self, n_generations) -> None:
385
  submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
386
+ submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=n_generations)
387
  submit_triggers_one = (
388
  [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
389
  )
390
  submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
391
  if self.submit_btn_two:
392
+ submit_fn_two = functools.partial(submit_fn_one, n_generations=n_generations)
393
  submit_triggers_two = [self.submit_btn_two.click]
394
  submit_tuples.append((submit_fn_two, submit_triggers_two))
395
  for _fn, _triggers in submit_tuples: