Add latest OA Pythia

#13
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -23,7 +23,7 @@ def get_usernames(model: str):
23
  Returns:
24
  (str, str, str, str): pre-prompt, username, bot name, separator
25
  """
26
- if model == "OpenAssistant/oasst-sft-1-pythia-12b":
27
  return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
28
  if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
29
  return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
@@ -65,7 +65,7 @@ def predict(
65
 
66
  partial_words = ""
67
 
68
- if model == "OpenAssistant/oasst-sft-1-pythia-12b":
69
  iterator = client.generate_stream(
70
  total_inputs,
71
  typical_p=typical_p,
@@ -122,7 +122,7 @@ def radio_on_change(
122
  repetition_penalty,
123
  watermark,
124
  ):
125
- if value == "OpenAssistant/oasst-sft-1-pythia-12b":
126
  typical_p = typical_p.update(value=0.2, visible=True)
127
  top_p = top_p.update(visible=False)
128
  top_k = top_k.update(visible=False)
@@ -187,8 +187,9 @@ with gr.Blocks(
187
  gr.Markdown(text_generation_inference, visible=True)
188
  with gr.Column(elem_id="col_container"):
189
  model = gr.Radio(
190
- value="OpenAssistant/oasst-sft-1-pythia-12b",
191
  choices=[
 
192
  "OpenAssistant/oasst-sft-1-pythia-12b",
193
  # "togethercomputer/GPT-NeoXT-Chat-Base-20B",
194
  "google/flan-t5-xxl",
23
  Returns:
24
  (str, str, str, str): pre-prompt, username, bot name, separator
25
  """
26
+ if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
27
  return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
28
  if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
29
  return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
65
 
66
  partial_words = ""
67
 
68
+ if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
69
  iterator = client.generate_stream(
70
  total_inputs,
71
  typical_p=typical_p,
122
  repetition_penalty,
123
  watermark,
124
  ):
125
+ if value in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
126
  typical_p = typical_p.update(value=0.2, visible=True)
127
  top_p = top_p.update(visible=False)
128
  top_k = top_k.update(visible=False)
187
  gr.Markdown(text_generation_inference, visible=True)
188
  with gr.Column(elem_id="col_container"):
189
  model = gr.Radio(
190
+ value="OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
191
  choices=[
192
+ "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
193
  "OpenAssistant/oasst-sft-1-pythia-12b",
194
  # "togethercomputer/GPT-NeoXT-Chat-Base-20B",
195
  "google/flan-t5-xxl",