justinhl committed on
Commit 572cacf (1 parent: e735ab2)

Update hybrid_pipe.py

Files changed (1)
hybrid_pipe.py +3 -2
hybrid_pipe.py CHANGED
@@ -61,8 +61,9 @@ class HybridQAModel(nn.Module, PyTorchModelHubMixin):
 
     def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
         max_input_length = min(tokenizer.model_max_length, model.config.max_length)
+        input_text += " Do not output anything but the question's answer."
         messages = [
-            {"role": "user", "content": input_text + " Do not output anything but the question's answer."}
+            {"role": "user", "content": input_text}
         ]
         input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
         generated_outputs = model.generate(input_ids, max_new_tokens=256, temperature=0.5, output_scores=True, return_dict_in_generate=True)
@@ -74,7 +75,7 @@ class HybridQAModel(nn.Module, PyTorchModelHubMixin):
         average_confidence = sum(max_confidence_scores) / len(max_confidence_scores)  # Calculate average confidence
 
         decoded_output = tokenizer.decode(generated_outputs.sequences[0], skip_special_tokens=True)
-        final_output = decoded_output.replace("<|im_end|>", "").split("\n")[-1]
+        final_output = decoded_output[len(input_text):].split("\n")[-1]
         average_confidence, final_output
         return final_output, average_confidence
 
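The hunks above omit the lines that build `max_confidence_scores`, and the new string slice is fragile: `decoded_output` begins with the chat-template wrapper rather than with `input_text`, so `decoded_output[len(input_text):]` lands at an arbitrary offset. Below is a minimal sketch of the revised method with those gaps filled in. The confidence heuristic (max softmax probability per generated step) is an assumption about code outside the diff, and the answer is recovered by decoding only the generated token ids instead of slicing the decoded string.

```python
import torch

def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
    # As in the commit: fold the answer-only instruction into the prompt text.
    input_text += " Do not output anything but the question's answer."
    messages = [{"role": "user", "content": input_text}]
    input_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    generated_outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,          # temperature is ignored unless sampling is enabled
        temperature=0.5,
        output_scores=True,
        return_dict_in_generate=True,
    )

    # Assumed confidence heuristic (these lines fall outside the diff hunks):
    # scores holds one [batch, vocab] logits tensor per generated step; take
    # the max softmax probability at each step as a per-token confidence.
    max_confidence_scores = [
        torch.softmax(step_logits[0], dim=-1).max().item()
        for step_logits in generated_outputs.scores
    ]
    average_confidence = sum(max_confidence_scores) / len(max_confidence_scores)

    # sequences[0] is the prompt ids followed by the generated ids, so decode
    # only the new ids rather than slicing the decoded string by
    # len(input_text), which misses the chat-template text around the message.
    new_token_ids = generated_outputs.sequences[0][input_ids.shape[-1]:]
    final_output = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    return final_output, average_confidence
```

Decoding `sequences[0][input_ids.shape[-1]:]` works because, for decoder-only models, `generate` returns the prompt followed by the new tokens; it also makes the old `<|im_end|>` cleanup redundant, since `skip_special_tokens=True` already drops special tokens. The bare `average_confidence, final_output` expression before the return is a no-op and is omitted in this sketch.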