hibalaz commited on
Commit
696d3e8
1 Parent(s): 482f2f0

Upload model_evaluation.py

Browse files
Files changed (1) hide show
  1. model_evaluation.py +173 -0
model_evaluation.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from datasets import Dataset
3
+ from transformers import pipeline, GPT2Tokenizer
4
+ from sentence_transformers import SentenceTransformer, util
5
+
6
+ # Define paths and models
7
+ filename = "output_country_details.txt"
8
+ retrieval_model_name = 'output/sentence-transformer-finetuned/' #using a prefine-tuned model
9
+ gpt2_model_name = "gpt2"
10
+ csv_file_path = "train_dataset.csv"
11
+ output_csv_file_path = "updated_train_dataset.csv"
12
+ val_csv_file_path = "val_dataset.csv"
13
+ output_val_csv_file_path = "updated_val_csv.csv"
14
+
15
+ tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
16
+
17
+ # Initialize models
18
+ try:
19
+ retrieval_model = SentenceTransformer(retrieval_model_name)
20
+ gpt_model = pipeline("text-generation", model=gpt2_model_name)
21
+ print("Models loaded successfully.")
22
+ except Exception as e:
23
+ print(f"Failed to load models: {e}")
24
+
25
+ def load_and_preprocess_text(filename):
26
+ """
27
+ Load and preprocess text data from a file.
28
+
29
+ Parameters:
30
+ - filename (str): Path to the text file.
31
+
32
+ Returns:
33
+ - list[str]: A list of preprocessed text segments.
34
+ """
35
+ try:
36
+ with open(filename, 'r', encoding='utf-8') as file:
37
+ segments = [line.strip() for line in file if line.strip()]
38
+ print("Text loaded and preprocessed successfully.")
39
+ return segments
40
+ except Exception as e:
41
+ print(f"Failed to load or preprocess text: {e}")
42
+ return []
43
+
44
+ segments = load_and_preprocess_text(filename)
45
+
46
+ def find_relevant_segment(user_query, segments):
47
+ """
48
+ Find the most relevant text segment based on a user query.
49
+
50
+ Parameters:
51
+ - user_query (str): The user's query.
52
+ - segments (list[str]): List of text segments to search within.
53
+
54
+ Returns:
55
+ - str: The most relevant text segment.
56
+ """
57
+ try:
58
+ query_embedding = retrieval_model.encode(user_query)
59
+ segment_embeddings = retrieval_model.encode(segments)
60
+ similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
61
+ best_idx = similarities.argmax()
62
+ return segments[best_idx]
63
+ except Exception as e:
64
+ print(f"Error finding relevant segment: {e}")
65
+ return ""
66
+
67
+ def generate_response(question):
68
+ """
69
+ Generate a response to a given question by finding a relevant text segment and
70
+ using it to generate a more complete answer.
71
+
72
+ Parameters:
73
+ - question (str): The user's question.
74
+
75
+ Returns:
76
+ - str: Generated response.
77
+ """
78
+ relevant_segment = find_relevant_segment(question, segments)
79
+ return generate_response_with_context(question, relevant_segment)
80
+
81
+ def generate_response_with_context(user_query, relevant_segment):
82
+ """
83
+ Generate a response based on a user query and a relevant segment.
84
+
85
+ Parameters:
86
+ - user_query (str): The user's query.
87
+ - relevant_segment (str): A relevant fact or detail.
88
+
89
+ Returns:
90
+ - str: Formatted response incorporating the relevant segment.
91
+ """
92
+ try:
93
+ prompt = f"Thank you for your question! Here is an additional fact about your topic: {relevant_segment}"
94
+ max_tokens = len(tokenizer(prompt)['input_ids']) + 50
95
+ response = gpt_model(prompt, max_length=max_tokens, temperature=0.25)[0]['generated_text']
96
+ return clean_up_response(response, relevant_segment)
97
+ except Exception as e:
98
+ print(f"Error generating response: {e}")
99
+ return ""
100
+
101
+ def clean_up_response(response, segment):
102
+ """
103
+ Clean up the generated response to ensure it is tidy and presentable.
104
+
105
+ Parameters:
106
+ - response (str): The initial response generated by the model.
107
+ - segment (str): The segment used to generate the response.
108
+
109
+ Returns:
110
+ - str: A cleaned and formatted response.
111
+ """
112
+ sentences = response.split('.')
113
+ cleaned_sentences = [sentence.strip() for sentence in sentences if sentence.strip() and sentence.strip() not in segment]
114
+ cleaned_response = '. '.join(cleaned_sentences).strip()
115
+ if cleaned_response and not cleaned_response.endswith((".", "!", "?")):
116
+ cleaned_response += "."
117
+ return cleaned_response
118
+
119
+ def process_dataset(csv_file_path, output_csv_file_path):
120
+ """
121
+ Process the dataset by generating responses and evaluating their similarities.
122
+
123
+ Parameters:
124
+ - csv_file_path (str): Path to the CSV file containing the dataset.
125
+ - output_csv_file_path (str): Path where the updated dataset will be saved.
126
+
127
+ Prints:
128
+ - Path to the saved results and the average similarity score.
129
+ """
130
+ df = pd.read_csv(csv_file_path)
131
+ dataset = Dataset.from_pandas(df)
132
+ updated_dataset = add_model_answers(dataset)
133
+ similarities = evaluate_similarity(updated_dataset)
134
+ updated_dataset = updated_dataset.add_column("similarity", similarities)
135
+ results_df = updated_dataset.to_pandas()
136
+ results_df.to_csv(output_csv_file_path, index=False)
137
+ average_similarity = sum(similarities) / len(similarities) if similarities else 0
138
+ print(f"Results saved to {output_csv_file_path}")
139
+ print(f"Average Similarity Score: {average_similarity:.3f}")
140
+
141
+ def add_model_answers(dataset):
142
+ """
143
+ Add generated answers to the dataset.
144
+
145
+ Parameters:
146
+ - dataset (datasets.Dataset): The Hugging Face dataset object.
147
+
148
+ Returns:
149
+ - datasets.Dataset: Updated dataset with added answers.
150
+ """
151
+ answers = [generate_response(q) for q in dataset['Question']]
152
+ dataset = dataset.add_column("Answer", answers)
153
+ return dataset
154
+
155
+ def evaluate_similarity(dataset):
156
+ """
157
+ Evaluate the similarity of generated answers against ground truth answers.
158
+
159
+ Parameters:
160
+ - dataset (datasets.Dataset): The dataset containing both answers and ground truths.
161
+
162
+ Returns:
163
+ - list[float]: List of similarity scores.
164
+ """
165
+ similarities = [util.pytorch_cos_sim(retrieval_model.encode(ans), retrieval_model.encode(gt))[0][0].item()
166
+ for ans, gt in zip(dataset['Answer'], dataset['GroundTruth'])]
167
+ return similarities
168
+
169
+ # Process datasets
170
+ process_dataset(csv_file_path, output_csv_file_path)
171
+ process_dataset(val_csv_file_path, output_val_csv_file_path)
172
+
173
+