joaogante HF staff commited on
Commit
426c6f1
1 Parent(s): fd9a520
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -44,21 +44,22 @@ if __name__ == "__main__":
44
  generated_ids = outputs.sequences[:, input_length:]
45
  generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])
46
 
47
- # On decoder-only models, you might want to initialize the highlighted output with the prompt (wo labels)
 
48
  if model.config.is_encoder_decoder:
49
  highlighted_out = []
50
  else:
51
  input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids)
52
- highlighted_out = [(token.replace("_", " "), None) for token in input_tokens]
53
  # Get the (decoded_token, label) pairs for the generated tokens
54
- for token, proba in zip(generated_tokens[0], transition_proba[0]):
55
  this_label = None
56
  assert 0. <= proba <= 1.0
57
  for min_proba, label in probs_to_label:
58
  if proba >= min_proba:
59
  this_label = label
60
  break
61
- highlighted_out.append((token.replace("_", " "), this_label))
62
 
63
  return highlighted_out
64
 
 
44
  generated_ids = outputs.sequences[:, input_length:]
45
  generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])
46
 
47
+ # Important: you might need to find a tokenization character to replace (e.g. "Ġ" for BPE) and get the correct
48
+ # spacing into the final output 👼
49
  if model.config.is_encoder_decoder:
50
  highlighted_out = []
51
  else:
52
  input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids)
53
+ highlighted_out = [(token.replace("", " "), None) for token in input_tokens]
54
  # Get the (decoded_token, label) pairs for the generated tokens
55
+ for token, proba in zip(generated_tokens, transition_proba[0]):
56
  this_label = None
57
  assert 0. <= proba <= 1.0
58
  for min_proba, label in probs_to_label:
59
  if proba >= min_proba:
60
  this_label = label
61
  break
62
+ highlighted_out.append((token.replace("", " "), this_label))
63
 
64
  return highlighted_out
65