danseith commited on
Commit
2ce1788
1 Parent(s): ca69fee

Implemented temperature scaling and changed output sampling to uniform.

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -71,7 +71,7 @@ class TempScalePipe(FillMaskPipeline):
71
  masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
72
  # Fill mask pipeline supports only one ${mask_token} per sample
73
 
74
- logits = outputs[0, masked_index, :] / 1.2
75
  probs = logits.softmax(dim=-1)
76
  sampling = False
77
  if sampling:
@@ -114,17 +114,18 @@ PIPELINE_REGISTRY.register_pipeline(
114
  )
115
  scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
116
 
117
- def unmask(text, temp):
118
  # text = add_mask(text)
119
  split_text = text.split()
120
- res = scrambler(text)
121
  mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
122
  out = {item["token_str"]: item["score"] for item in res}
123
  score_to_str = {out[k]:k for k in out.keys()}
124
- print(score_to_str)
125
- print(out)
126
  score_list = list(score_to_str.keys())
127
- idx = np.argmax(np.random.multinomial(1, score_list, 1))
 
 
 
128
  score = score_list[idx]
129
  new_token = score_to_str[score]
130
  split_text[mask_pos] = new_token
@@ -132,7 +133,7 @@ def unmask(text, temp):
132
 
133
  textbox = gr.Textbox(label="Type language here", lines=5)
134
  textbox2 = gr.Textbox(placeholder="Type here...", lines=4)
135
- temp_slider = gr.Slider(1.0, 1.5, value=1.0, label='Creativity')
136
 
137
  demo = gr.Interface(
138
  fn=unmask,
 
71
  masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
72
  # Fill mask pipeline supports only one ${mask_token} per sample
73
 
74
+ logits = outputs[0, masked_index, :] / temp
75
  probs = logits.softmax(dim=-1)
76
  sampling = False
77
  if sampling:
 
114
  )
115
  scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
116
 
117
+ def unmask(text, temp, sampling='uniform'):
118
  # text = add_mask(text)
119
  split_text = text.split()
120
+ res = scrambler(text, temp=temp, top_k=3)
121
  mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
122
  out = {item["token_str"]: item["score"] for item in res}
123
  score_to_str = {out[k]:k for k in out.keys()}
 
 
124
  score_list = list(score_to_str.keys())
125
+ if sampling == 'multi':
126
+ idx = np.argmax(np.random.multinomial(1, score_list, 1))
127
+ else:
128
+ idx = np.random.randint(0, len(score_list))
129
  score = score_list[idx]
130
  new_token = score_to_str[score]
131
  split_text[mask_pos] = new_token
 
133
 
134
  textbox = gr.Textbox(label="Type language here", lines=5)
135
  textbox2 = gr.Textbox(placeholder="Type here...", lines=4)
136
+ temp_slider = gr.Slider(1.0, 2.0, value=1.0, label='Temperature Scale')
137
 
138
  demo = gr.Interface(
139
  fn=unmask,