import json
import tqdm
import logging
from scripts.get_prediction_file import get_prediction_file
from scripts.groq_client import GroqClient
from scripts.helper import adaptive_delay, ensure_directory_exists, load_used_data, update_config
from scripts.prompt import get_factual_prompt

def evaluate_factual_robustness(config):
    """Evaluates factual/counterfactual robustness for a given model under multiple correct_rate/noise_rate conditions."""
    model_name = config['model_name']
    
    if model_name in config['models']:
        model = GroqClient(plm=model_name)
    else:
        logging.warning(f"Skipping unknown model: {model_name}")
        return

    # Define the conditions to test
    conditions = [
        {"correct_rate": 1.0, "noise_rate": 0.2, "label": "factual_only"},  # factual documents with some noisy documents
        {"correct_rate": 0.0, "noise_rate": 0.4, "label": "counterfactual"}  # Counterfactual + noise
    ]

    base_path = "results/Counterfactual Robustness"
    result_file = f"{base_path}/scores_{config['output_file_extension']}.json"
    final_scores = {"conditions": []}

    def process_query(model, data, used_data, output_file):
        """Processes a single query, generates evaluation, and writes the result."""
        if data['id'] in used_data and data['query'] == used_data[data['id']]['query'] and data['ans'] == used_data[data['id']]['ans']:
            output_file.write(json.dumps(used_data[data['id']], ensure_ascii=False) + '\n')
            return used_data[data['id']]

        try:
            instruction = get_factual_prompt(data['query'], data['prediction'])
            # Retry up to three times, backing off between attempts, in case the API returns an empty response.
            evaluation = None
            for attempt in range(1, 4):
                evaluation = model.generate(instruction)
                if evaluation:
                    break
                adaptive_delay(attempt)

            if not evaluation:
                logging.warning(f"No evaluation received for query {data['id']}; skipping.")
                return None

            data['evaluation'] = evaluation
            logging.info(f"Model Response for Factual robustness: {evaluation}")
            output_file.write(json.dumps(data, ensure_ascii=False) + '\n')
            return data

        except Exception as e:
            logging.error(f"Error processing query {data.get('id', '?')}: {e}")
            return None

    def calculate_scores(results, condition):
        """Calculates and returns rejection rates and other metrics."""
        rejecttt = 0    # responses where the evaluator reports that factual errors were identified (drives ED)
        tt = 0          # responses whose label list contains a 1 and no 0, i.e. fully correct answers
        correct_tt = 0  # responses that are both flagged as containing errors and fully correct (drives CR)
        for i in results:
            # The evaluation prompt is expected to answer "Yes" / "... has identified ..." when the
            # model spotted the injected factual errors.
            if "has identified" in i['evaluation'] or "Yes" in i['evaluation']:
                rejecttt += 1
                if 0 not in i['label'] and 1 in i['label']:
                    correct_tt += 1
            if 0 not in i['label'] and 1 in i['label']:
                tt += 1

        scores = {
            'reject_rate': rejecttt / len(results) if len(results) > 0 else 0,  # Error Detection Rate (ED)
            'all_rate': tt / len(results) if len(results) > 0 else 0,           # share of fully correct answers
            'correct_rate': correct_tt / rejecttt if rejecttt > 0 else 0,       # Error Correction Rate (CR)
            'tt': tt,
            'rejecttt': rejecttt,
            'correct_tt': correct_tt,
            'nums': len(results),
            'noise_rate': condition['noise_rate'],
            'condition_label': condition['label']
        }
        return scores
 
    for condition in conditions:
        logging.info(f"\nEvaluating condition: {condition['label']} (correct_rate={condition['correct_rate']}, noise_rate={condition['noise_rate']})")
        
        # Update config with current condition's noise_rate
        config['noise_rate'] = condition['noise_rate']
        #config['passage_num'] = 10
        update_config(config)
       
        # File paths with condition-specific suffixes so the two conditions do not overwrite each other's outputs
        pred_file = get_prediction_file(config, condition['correct_rate'])
        output_file = f"{base_path}/output_{condition['label']}_{config['output_file_extension']}.json"
       
        ensure_directory_exists(output_file)
        
        logging.info(f"Factual pred file for {condition['label']}: {pred_file}")

        # Load or recalculate data
        used_data = {}   # previously computed evaluations keyed by query id
        results = []
        if config['UsePreCalculatedValue']:
            logging.info(f"Trying to use pre-calculated values for {condition['label']}")
            used_data = load_used_data(output_file)
        else:
            logging.info(f"Recalculating the metrics for {condition['label']}...")

        with open(output_file, 'w', encoding='utf-8') as f_out, open(pred_file, 'r', encoding='utf-8') as f_eval:
            for line in tqdm.tqdm(f_eval):
                data = json.loads(line)
                processed_data = process_query(model, data, used_data, f_out)
                if processed_data:
                    results.append(processed_data)

        # Compute and save scores
        scores = calculate_scores(results, condition)
        final_scores["conditions"].append(scores)
        logging.info(f"Counterfactual Robustness Score for {condition['label']}: {scores}")

    with open(result_file, 'w', encoding='utf-8') as f_result:
        json.dump(final_scores, f_result, ensure_ascii=False, indent=4)
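
# Minimal usage sketch (not part of the original script; the concrete values are assumptions).
# It shows the config keys this function reads directly: 'model_name', 'models',
# 'output_file_extension', 'noise_rate', and 'UsePreCalculatedValue'. Keys consumed only by the
# helpers in scripts.helper or scripts.get_prediction_file are not listed and may also be required.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    example_config = {
        'model_name': 'llama3-70b-8192',        # model to evaluate (hypothetical choice)
        'models': ['llama3-70b-8192'],          # models recognised as GroqClient models
        'output_file_extension': 'llama3_70b',  # suffix for the output and score file names
        'noise_rate': 0.2,                      # overwritten per condition inside the loop
        'UsePreCalculatedValue': False,         # True reuses evaluations already written to the output file
    }
    evaluate_factual_robustness(example_config)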