OlivierDehaene commited on
Commit
040958c
1 Parent(s): ef366f8

truncate history

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -6,13 +6,13 @@ from text_generation import Client, InferenceAPIClient
6
 
7
 
8
  def get_client(model: str):
9
- if model == "Rallio67/joi_20B_instruct_alpha":
10
  return Client(os.getenv("API_URL"))
11
  return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
12
 
13
 
14
  def get_usernames(model: str):
15
- if model == "Rallio67/joi_20B_instruct_alpha":
16
  return "User: ", "Joi: "
17
  return "User: ", "Assistant: "
18
 
@@ -48,7 +48,8 @@ def predict(
48
  inputs = user_name + inputs
49
 
50
  total_inputs = "".join(past) + inputs + "\n\n" + assistant_name
51
- print(total_inputs)
 
52
 
53
  partial_words = ""
54
 
@@ -59,8 +60,8 @@ def predict(
59
  repetition_penalty=repetition_penalty,
60
  watermark=watermark,
61
  temperature=temperature,
62
- max_new_tokens=1000,
63
- stop_sequences=["User:"],
64
  )):
65
  if response.token.special:
66
  continue
@@ -105,9 +106,9 @@ with gr.Blocks(
105
  gr.HTML(title)
106
  with gr.Column(elem_id="col_container"):
107
  model = gr.Radio(
108
- value="Rallio67/joi_20B_instruct_alpha",
109
  choices=[
110
- "Rallio67/joi_20B_instruct_alpha",
111
  "google/flan-t5-xxl",
112
  "google/flan-ul2",
113
  "bigscience/bloom",
 
6
 
7
 
8
  def get_client(model: str):
9
+ if model == "Rallio67/joi2_20B_instruct_alpha":
10
  return Client(os.getenv("API_URL"))
11
  return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
12
 
13
 
14
  def get_usernames(model: str):
15
+ if model == "Rallio67/joi2_20B_instruct_alpha":
16
  return "User: ", "Joi: "
17
  return "User: ", "Assistant: "
18
 
 
48
  inputs = user_name + inputs
49
 
50
  total_inputs = "".join(past) + inputs + "\n\n" + assistant_name
51
+ # truncate total_inputs
52
+ total_inputs = total_inputs[-1000:]
53
 
54
  partial_words = ""
55
 
 
60
  repetition_penalty=repetition_penalty,
61
  watermark=watermark,
62
  temperature=temperature,
63
+ max_new_tokens=500,
64
+ stop_sequences=[user_name.rstrip()],
65
  )):
66
  if response.token.special:
67
  continue
 
106
  gr.HTML(title)
107
  with gr.Column(elem_id="col_container"):
108
  model = gr.Radio(
109
+ value="Rallio67/joi2_20B_instruct_alpha",
110
  choices=[
111
+ "Rallio67/joi2_20B_instruct_alpha",
112
  "google/flan-t5-xxl",
113
  "google/flan-ul2",
114
  "bigscience/bloom",