import json
import re
import time

from beartype.typing import Any, Dict, List, Tuple, Optional
from transformers import GPT2Tokenizer

from utils import model_prompting, f1_score, exact_match_score, get_bert_score

# Initialize tokenizer for token counting (used in cost calculation)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


class LLMEngine:
    """
    A class to manage interactions with multiple language models and evaluate their performance.

    Handles model selection, querying, cost calculation, and performance evaluation
    using various metrics for different tasks.
    """

    def __init__(self, llm_names: List[str], llm_description: Dict[str, Dict[str, Any]]):
        """
        Initialize the LLM engine with the available models and their descriptions.

        Args:
            llm_names: List of language model names available in the engine.
            llm_description: Dictionary containing model configurations and pricing details.
                Structure: {
                    "model_name": {
                        "model": "api_identifier",
                        "input_price": cost_per_input_token,
                        "output_price": cost_per_output_token,
                        ...
                    },
                    ...
                }
        """
        self.llm_names = llm_names
        self.llm_description = llm_description

    def compute_cost(self, llm_idx: int, input_text: str, output_size: int) -> float:
        """
        Calculate the cost of a model query based on input and output token counts.

        Args:
            llm_idx: Index of the model in the llm_names list.
            input_text: The input prompt sent to the model.
            output_size: Number of tokens in the model's response.

        Returns:
            float: The calculated cost in currency units.
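
        Example (a minimal sketch; the model name and per-token prices below
        are illustrative assumptions, not real rates):
            >>> engine = LLMEngine(
            ...     ["demo-llm"],
            ...     {"demo-llm": {"model": "demo-llm",
            ...                   "input_price": 1e-6, "output_price": 2e-6}},
            ... )
            >>> cost = engine.compute_cost(0, "What is 2 + 2?", output_size=5)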
""" | |
# Count input tokens | |
input_size = len(tokenizer(input_text)['input_ids']) | |
# Get pricing information for the selected model | |
llm_name = self.llm_names[llm_idx] | |
input_price = self.llm_description[llm_name]["input_price"] | |
output_price = self.llm_description[llm_name]["output_price"] | |
# Calculate total cost | |
cost = input_size * input_price + output_size * output_price | |
return cost | |

    def get_llm_response(self, query: str, llm_idx: int) -> str:
        """
        Send a query to a language model and get its response.

        Args:
            query: The prompt text to send to the model.
            llm_idx: Index of the model in the llm_names list.

        Returns:
            str: The model's text response.

        Note:
            Includes a retry mechanism with a 2-second delay if the first attempt fails.
        """
        llm_name = self.llm_names[llm_idx]
        model = self.llm_description[llm_name]["model"]
        try:
            response = model_prompting(llm_model=model, prompt=query)
        except Exception:
            # If the request fails, wait and retry once
            time.sleep(2)
            response = model_prompting(llm_model=model, prompt=query)
        return response

    def eval(self, prediction: str, ground_truth: str, metric: str) -> float:
        """
        Evaluate the model's prediction against the ground truth using the specified metric.

        Args:
            prediction: The model's output text.
            ground_truth: The correct expected answer.
            metric: The evaluation metric to use (e.g., 'em', 'f1_score', 'GSM8K').

        Returns:
            float: Evaluation score (typically between 0 and 1).
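
        Example (GSM8K sketch; the engine needs no model configuration for this
        metric, and the "<answer>" / "#### answer" formats follow the parsing below):
            >>> engine = LLMEngine([], {})
            >>> engine.eval("The final answer is <42>", "Step 1: ... #### 42", "GSM8K")
            1.0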
""" | |
        # Exact match evaluation
        if metric == 'em':
            result = exact_match_score(prediction, ground_truth)
            return float(result)
        # Multiple-choice exact match
        elif metric == 'em_mc':
            result = exact_match_score(prediction, ground_truth, normal_method="mc")
            return float(result)
        # BERT-based semantic similarity score
        elif metric == 'bert_score':
            result = get_bert_score([prediction], [ground_truth])
            return result
        # GSM8K math problem evaluation:
        # extract the final answer in the format "<answer>" and check it against the ground truth
        elif metric == 'GSM8K':
            # Extract the final answer from the ground truth (after the "####" delimiter)
            ground_truth = ground_truth.split("####")[-1].strip()
            # Look for an answer enclosed in angle brackets, e.g. <42>
            match = re.search(r'<(\d+)>', prediction)
            if match:
                if match.group(1) == ground_truth:
                    return 1.0  # Correct answer
                else:
                    return 0.0  # Incorrect answer
            else:
                return 0.0  # No answer in the expected format
        # F1 score for partial matching (used in QA tasks)
        elif metric == 'f1_score':
            f1, prec, recall = f1_score(prediction, ground_truth)
            return f1
        # Default case for unrecognized metrics
        else:
            return 0.0
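

if __name__ == "__main__":
    # Minimal usage sketch. The model name, API identifier, and per-token
    # prices below are illustrative assumptions, not real models or rates.
    demo_description = {
        "demo-llm": {
            "model": "demo-llm",   # API identifier passed to model_prompting
            "input_price": 1e-6,   # assumed cost per input token
            "output_price": 2e-6,  # assumed cost per output token
        }
    }
    engine = LLMEngine(llm_names=["demo-llm"], llm_description=demo_description)

    # Cost estimation only needs the local GPT-2 tokenizer, not API access.
    prompt = "What is 2 + 2? Wrap the final answer in angle brackets, e.g. <4>."
    print("Estimated cost:", engine.compute_cost(0, prompt, output_size=10))

    # Querying a model goes through utils.model_prompting, which needs a
    # configured backend (API keys, network), so the calls are left commented out.
    # response = engine.get_llm_response(prompt, 0)
    # print("GSM8K score:", engine.eval(response, "#### 4", "GSM8K"))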