frankaging committed on
Commit
86478a5
1 Parent(s): 0a5510e

align with paper

Browse files
Files changed (1) hide show
  1. app.py +3 -27
app.py CHANGED
@@ -101,12 +101,9 @@ def generate(
101
  "max_new_tokens": max_new_tokens,
102
  "eos_token_id": tokenizer.eos_token_id,
103
  "early_stopping": True,
 
104
  "repetition_penalty": repetition_penalty,
105
- "do_sample": True,
106
- "top_p": top_p,
107
- "top_k": top_k,
108
- "temperature": temperature,
109
- "num_beams": 1,
110
  }
111
 
112
  t = Thread(target=reft_model.generate, kwargs=generate_kwargs)
@@ -128,33 +125,12 @@ chat_interface = gr.ChatInterface(
128
  step=1,
129
  value=DEFAULT_MAX_NEW_TOKENS,
130
  ),
131
- gr.Slider(
132
- label="Temperature",
133
- minimum=0.1,
134
- maximum=2.0,
135
- step=0.1,
136
- value=0.3,
137
- ),
138
- gr.Slider(
139
- label="Top-p (nucleus sampling)",
140
- minimum=0.05,
141
- maximum=1.0,
142
- step=0.05,
143
- value=0.9,
144
- ),
145
- gr.Slider(
146
- label="Top-k",
147
- minimum=1,
148
- maximum=1000,
149
- step=1,
150
- value=50,
151
- ),
152
  gr.Slider(
153
  label="Repetition penalty",
154
  minimum=1.0,
155
  maximum=2.0,
156
  step=0.05,
157
- value=1.2,
158
  ),
159
  ],
160
  stop_btn=None,
 
101
  "max_new_tokens": max_new_tokens,
102
  "eos_token_id": tokenizer.eos_token_id,
103
  "early_stopping": True,
104
+ "no_repeat_ngram_size": 5,
105
  "repetition_penalty": repetition_penalty,
106
+ "do_sample": False,
 
 
 
 
107
  }
108
 
109
  t = Thread(target=reft_model.generate, kwargs=generate_kwargs)
 
125
  step=1,
126
  value=DEFAULT_MAX_NEW_TOKENS,
127
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  gr.Slider(
129
  label="Repetition penalty",
130
  minimum=1.0,
131
  maximum=2.0,
132
  step=0.05,
133
+ value=1.1,
134
  ),
135
  ],
136
  stop_btn=None,