|
import logging
|
|
from scripts.helper import adaptive_delay, load_dataset, load_used_data
|
|
from scripts.process_data import process_data
|
|
from scripts.groq_client import GroqClient
|
|
from scripts.prediction import predict
|
|
|
|
|
|
# Configure the root logger once at import time: INFO level and above, with
# each record rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
def get_prediction_result(config, data_file_name, prediction_file_name='', correct_rate=0):
    """Collect prediction records for the first ``num_queries`` dataset entries.

    For every selected dataset entry the function either reuses a previously
    stored record (when pre-calculated values are enabled and the stored
    query/answer still match the dataset entry) or queries the model through
    ``predict`` and builds a new record.

    Args:
        config: Mapping read for the keys ``model_name``, ``models``,
            ``num_queries``, ``UsePreCalculatedValue``, ``noise_rate``,
            ``passage_num`` and ``retry_attempts``.
        data_file_name: Passed to ``load_dataset`` and ``process_data``.
        prediction_file_name: Passed to ``load_used_data`` when
            ``config['UsePreCalculatedValue']`` is truthy.
        correct_rate: Forwarded unchanged to ``process_data``.

    Returns:
        A list with one record per processed dataset entry (reused records
        are appended as loaded; new records are dicts with the keys ``id``,
        ``query``, ``ans``, ``label``, ``prediction``, ``docs``,
        ``noise_rate`` and ``factlabel``), or ``None`` when
        ``config['model_name']`` is not listed in ``config['models']``.

    Side effects:
        Sets ``instance['label']`` on every dataset entry that is evaluated.
    """
    # Prompt layout handed to predict(); {DOCS} and {QUERY} are left as
    # placeholders here and are not substituted by this function.
    prompt_template = "Document:\n{DOCS} \n\nQuestion:\n{QUERY}"

    dataset = load_dataset(data_file_name)
    modelname = config['model_name']
    num_queries = min(config['num_queries'], len(dataset))
    subdataset = dataset[:num_queries]

    # Guard clause: nothing can be evaluated for a model that is not configured.
    if modelname not in config['models']:
        logging.warning(f"Skipping unknown model: {modelname}")
        return None

    model = GroqClient(plm=modelname)

    # Stored records are indexed by instance id below, so start from an empty
    # mapping rather than a list.
    used_data = {}
    if config['UsePreCalculatedValue']:
        logging.info("Trying to use pre calculated values for report generation")
        # Fall back to "nothing cached" if the loader yields an empty/None
        # result, instead of failing on the membership test further down.
        used_data = load_used_data(prediction_file_name) or {}
    else:
        logging.info(f"Running evaluation for {num_queries} queries...")

    # Always attempt at least once: with retry_attempts < 1 the loop body
    # would never run and label/prediction/factlabel would be unbound.
    attempts = max(1, config['retry_attempts'])

    results = []
    for idx, instance in enumerate(subdataset):
        instance_id = instance['id']

        # Reuse a stored record only if it still matches the dataset entry.
        if instance_id in used_data:
            cached = used_data[instance_id]
            if instance['query'] == cached['query'] and instance['answer'] == cached['ans']:
                results.append(cached)
                continue

        logging.info(f"Executing Query {idx + 1} for Model: {modelname}")
        query, ans, docs = process_data(
            instance,
            config['noise_rate'],
            config['passage_num'],
            data_file_name,
            correct_rate,
        )

        for attempt in range(1, attempts + 1):
            # NOTE(review): 0.7 is presumably the sampling temperature --
            # confirm against predict()'s signature.
            label, prediction, factlabel = predict(query, ans, docs, model, prompt_template, 0.7)
            if prediction:
                break
            # Back off after every failed attempt (including the last one, so
            # the pacing before the next query is unchanged).
            adaptive_delay(attempt)
        else:
            # Every attempt came back without a prediction; the (empty)
            # result is still recorded below, but make the failure visible.
            logging.warning(
                f"No prediction returned for Query {idx + 1} after {attempts} attempt(s)"
            )

        is_correct = all(x == 1 for x in label)
        logging.info(f"Model Response: {prediction}")
        logging.info(f"Correctness: {is_correct}")

        # The dataset entry itself is annotated with the label as well.
        instance['label'] = label
        results.append({
            'id': instance_id,
            'query': query,
            'ans': ans,
            'label': label,
            'prediction': prediction,
            'docs': docs,
            'noise_rate': config['noise_rate'],
            'factlabel': factlabel,
        })

    return results
|
|
|