yhavinga commited on
Commit
0d30451
β€’
1 Parent(s): 4a76d43

Add diversity penalty, pin tokenizers on older version

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. generator.py +0 -1
  3. requirements.txt +8 -7
app.py CHANGED
@@ -122,7 +122,7 @@ def main():
122
  "Num beam groups", min_value=1, max_value=10, value=1
123
  )
124
  length_penalty = st.sidebar.number_input(
125
- "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
126
  )
127
  diversity_penalty = st.sidebar.number_input(
128
  "Diversity penalty", min_value=0.0, max_value=2.0, value=0.1, step=0.1
@@ -136,7 +136,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
136
  "num_beams": num_beams,
137
  "num_beam_groups": num_beam_groups,
138
  "diversity_penalty": diversity_penalty if num_beam_groups > 1 else 0.0,
139
- "length_penalty": length_penalty,
140
  "early_stopping": True,
141
  }
142
 
 
122
  "Num beam groups", min_value=1, max_value=10, value=1
123
  )
124
  length_penalty = st.sidebar.number_input(
125
+ "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
126
  )
127
  diversity_penalty = st.sidebar.number_input(
128
  "Diversity penalty", min_value=0.0, max_value=2.0, value=0.1, step=0.1
 
136
  "num_beams": num_beams,
137
  "num_beam_groups": num_beam_groups,
138
  "diversity_penalty": diversity_penalty if num_beam_groups > 1 else 0.0,
139
+ "length_penalty": length_penalty if num_beams > 1 else 1.0,
140
  "early_stopping": True,
141
  }
142
 
generator.py CHANGED
@@ -1,4 +1,3 @@
1
- import _thread
2
  import os
3
  import re
4
 
 
 
1
  import os
2
  import re
3
 
requirements.txt CHANGED
@@ -1,11 +1,12 @@
1
- streamlit
2
- torch
3
- transformers
 
4
  langdetect
5
  psutil
6
- jax<0.4.0
7
- jaxlib<0.4.0
8
- chex>=0.1.4
9
- flax<0.6.0
10
  sentencepiece
11
  nltk
 
1
+ streamlit~=1.25.0
2
+ torch~=2.0.0
3
+ transformers~=4.30.0
4
+ tokenizers~=0.13.3
5
  langdetect
6
  psutil
7
+ jax==0.4.13
8
+ jaxlib==0.4.13
9
+ chex
10
+ flax
11
  sentencepiece
12
  nltk