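"""
Score video question-answer predictions with an Azure OpenAI chat model acting as a judge.

Each line of --pred-path is expected to be a dict literal with keys
'id', 'question', 'answer' and 'pred', for example (illustrative values):

    {'id': 'v_0001_q0', 'question': '...', 'answer': '...', 'pred': '...'}

For every sample the judge returns a yes/no verdict plus an integer score in [0, 5];
per-sample results are written to --output-dir, combined into --output-json, and the
overall accuracy and average score are printed.
"""
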
import os
import ast
import json
import time
import argparse
import traceback

from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

from openai import AzureOpenAI


def init():
    # Build the Azure OpenAI client from the environment variables set in __main__.
    client = AzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_KEY"),
        api_version="2024-02-15-preview",
    )
    return client


def interaction(client, message_text):
    # Send a single chat-completion request to the configured Azure deployment.
    completion = client.chat.completions.create(
        model=os.getenv("AZURE_OPENAI_DEPLOYNAME"),
        messages=message_text,
        temperature=0.7,
        max_tokens=800,
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )
    return completion


def prompt_gpt(question, answer, pred, key, qa_set, output_dir):
    # Build the judge prompt. Note: `client` is the module-level client created in __main__.
    message = [
        {
            "role": "system",
            "content":
                "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
                "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
                "------"
                "##INSTRUCTIONS: "
                "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
                "- Consider synonyms or paraphrases as valid matches.\n"
                "- Evaluate the correctness of the prediction compared to the answer."
        },
        {
            "role": "user",
            "content":
                "Please evaluate the following video-based question-answer pair:\n\n"
                f"Question: {question}\n"
                f"Correct Answer: {answer}\n"
                f"Predicted Answer: {pred}\n\n"
                "Provide your evaluation only as a yes/no and a score, where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
                "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where the value of 'pred' is a string of 'yes' or 'no' and the value of 'score' is an INTEGER, not a STRING. "
                "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
                "For example, your response should look like this: {'pred': 'yes', 'score': 4}."
        }
    ]
    completion = interaction(client, message)

    # Convert the judge's response (a Python dict literal) into a dictionary.
    response_message = completion.choices[0].message.content
    response_dict = ast.literal_eval(response_message)
    result_qa_pair = [response_dict, qa_set]

    # Save the judged question-answer pair to its own json file.
    with open(f"{output_dir}/{key}.json", "w") as f:
        json.dump(result_qa_pair, f)
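
# For reference, one saved "<output_dir>/<key>.json" file holds the judge's verdict
# followed by the original QA set, e.g. (hypothetical values):
#
#     [{"pred": "yes", "score": 4},
#      {"q": "What is the man holding?", "a": "A guitar.", "p": "He is holding a guitar."}]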


def annotate(task_arg):
    """
    Evaluate question-answer pairs with the Azure OpenAI judge
    and write one score file per sample.
    """
    prediction_set, caption_files, output_dir, args = task_arg
    for file in tqdm(caption_files):
        key = file[:-5]  # Strip the ".json" extension.
        qa_set = prediction_set[key]
        question = qa_set['q']
        answer = qa_set['a']
        pred = qa_set['p']
        try:
            prompt_gpt(question, answer, pred, key, qa_set, output_dir)
        except Exception:
            # Retry once with a truncated prediction (e.g. when the original request
            # fails), log the error, and back off briefly before the next sample.
            prompt_gpt(question, answer, pred[:50], key, qa_set, output_dir)
            traceback.print_exc()
            time.sleep(1)


def main(args):
    # Each line of the prediction file is a dict literal with keys
    # 'id', 'question', 'answer' and 'pred'.
    with open(args.pred_path) as file:
        new_pred_contents = [ast.literal_eval(line.strip()) for line in file]

    # Generate the list of ids and the corresponding per-sample file names.
    id_list = [x['id'] for x in new_pred_contents]
    caption_files = [f"{id}.json" for id in id_list]

    output_dir = args.output_dir
    # Create the output directory if it does not exist.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Prepare the dictionary of question-answer sets, keyed by sample id.
    prediction_set = {}
    for sample in new_pred_contents:
        id = sample['id']
        question = sample['question']
        answer = sample['answer']
        pred = sample['pred']
        qa_set = {"q": question, "a": answer, "p": pred}
        prediction_set[id] = qa_set

    num_tasks = args.num_tasks

    # Loop until every sample has a score file in the output directory.
    while True:
        try:
            # Files that have already been processed.
            completed_files = os.listdir(output_dir)
            print(f"completed_files: {len(completed_files)}")

            # Files that have not been processed yet.
            incomplete_files = [f for f in caption_files if f not in completed_files]
            print(f"incomplete_files: {len(incomplete_files)}")

            # Break the loop when there are no incomplete files.
            if len(incomplete_files) == 0:
                break
            if len(incomplete_files) <= num_tasks:
                num_tasks = 1

            # Split the remaining files into parts, one per worker.
            part_len = len(incomplete_files) // num_tasks
            all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
            task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]

            # Use a pool of worker threads to process the parts in parallel.
            with ThreadPoolExecutor(max_workers=args.num_tasks) as executor:
                list(tqdm(executor.map(annotate, task_args), total=len(task_args)))
        except Exception as e:
            print(f"Error: {e}")

    # Combine the per-sample json files (thread pool, not multiprocessing).
    def combine_json(file_name):
        file_path = os.path.join(output_dir, file_name)
        with open(file_path, "r") as json_file:
            content = json.load(json_file)
        return (file_name[:-5], content)

    files = os.listdir(output_dir)
    with ThreadPoolExecutor(max_workers=64) as executor:
        combined_contents = dict(tqdm(executor.map(combine_json, files), total=len(files)))

    # Save the combined results to the final output json.
    with open(args.output_json, "w") as f:
        json.dump(combined_contents, f)

    # Calculate the average score and accuracy.
    score_sum = 0
    count = 0
    yes_count = 0
    no_count = 0
    for key, result in tqdm(combined_contents.items()):
        try:
            # Accumulate the score.
            count += 1
            score = int(result[0]['score'])
            score_sum += score

            # Accumulate yes/no verdicts for the accuracy computation.
            pred = result[0]['pred']
            if "yes" in pred.lower():
                yes_count += 1
            elif "no" in pred.lower():
                no_count += 1
        except Exception:
            print(result)

    average_score = score_sum / count
    accuracy = yes_count / (yes_count + no_count)
    print("Yes count:", yes_count)
    print("No count:", no_count)
    print("Accuracy:", accuracy)
    print("Average score:", average_score)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="question-answer-evaluation-using-gpt")
    parser.add_argument("--pred-path", required=True, help="Path to the file containing predictions.")
    parser.add_argument("--output-dir", required=True, help="Directory in which to save the per-sample annotation json files.")
    parser.add_argument("--output-json", required=True, help="Path to save the final combined annotation json file.")
    parser.add_argument("--num-tasks", required=True, type=int, help="Number of parallel splits.")
    parser.add_argument("--api-key", required=True, type=str, help="Azure OpenAI API key.")
    parser.add_argument("--api-endpoint", required=True, type=str, help="Azure OpenAI API endpoint.")
    parser.add_argument("--api-deployname", required=True, type=str, help="Azure OpenAI deployment name.")
    args = parser.parse_args()

    # Export the Azure OpenAI settings so init() and interaction() can read them.
    os.environ["AZURE_OPENAI_KEY"] = args.api_key
    os.environ["AZURE_OPENAI_ENDPOINT"] = args.api_endpoint
    os.environ["AZURE_OPENAI_DEPLOYNAME"] = args.api_deployname

    client = init()
    main(args)
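
# Example invocation (hypothetical script/file names and deployment; adjust to your setup):
#
#     python evaluate_qa_azure.py \
#         --pred-path preds.jsonl \
#         --output-dir gpt_eval \
#         --output-json gpt_eval/combined.json \
#         --num-tasks 8 \
#         --api-key <your-azure-openai-key> \
#         --api-endpoint https://<your-resource>.openai.azure.com/ \
#         --api-deployname <your-deployment-name>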