Update hybrid_pipe.py
Browse files- hybrid_pipe.py +3 -2
hybrid_pipe.py
CHANGED
@@ -61,8 +61,9 @@ class HybridQAModel(nn.Module, PyTorchModelHubMixin):
|
|
61 |
|
62 |
def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
|
63 |
max_input_length = min(tokenizer.model_max_length, model.config.max_length)
|
|
|
64 |
messages = [
|
65 |
-
{"role": "user", "content": input_text
|
66 |
]
|
67 |
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
68 |
generated_outputs = model.generate(input_ids, max_new_tokens=256, temperature=0.5, output_scores=True, return_dict_in_generate=True)
|
@@ -74,7 +75,7 @@ class HybridQAModel(nn.Module, PyTorchModelHubMixin):
|
|
74 |
average_confidence = sum(max_confidence_scores) / len(max_confidence_scores) # Calculate average confidence
|
75 |
|
76 |
decoded_output = tokenizer.decode(generated_outputs.sequences[0], skip_special_tokens=True)
|
77 |
-
final_output = decoded_output
|
78 |
average_confidence, final_output
|
79 |
return final_output, average_confidence
|
80 |
|
|
|
61 |
|
62 |
def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
|
63 |
max_input_length = min(tokenizer.model_max_length, model.config.max_length)
|
64 |
+
input_text += " Do not output anything but the question's answer."
|
65 |
messages = [
|
66 |
+
{"role": "user", "content": input_text}
|
67 |
]
|
68 |
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
69 |
generated_outputs = model.generate(input_ids, max_new_tokens=256, temperature=0.5, output_scores=True, return_dict_in_generate=True)
|
|
|
75 |
average_confidence = sum(max_confidence_scores) / len(max_confidence_scores) # Calculate average confidence
|
76 |
|
77 |
decoded_output = tokenizer.decode(generated_outputs.sequences[0], skip_special_tokens=True)
|
78 |
+
final_output = decoded_output[len(input):].split("\n")[-1]
|
79 |
average_confidence, final_output
|
80 |
return final_output, average_confidence
|
81 |
|