import os
import argparse
import time
import json
import nltk
from rank_bm25 import BM25Okapi
import numpy as np
import torch
from vllm import LLM, SamplingParams
from datetime import datetime, timedelta


def download_nltk_data(package_name, download_dir='nltk_data'):
    # Ensure the download directory exists
    os.makedirs(download_dir, exist_ok=True)

    # Set NLTK data path
    nltk.data.path.append(download_dir)

    try:
        # Try to find the resource
        nltk.data.find(f'tokenizers/{package_name}')
        print(f"Package '{package_name}' is already downloaded")
    except LookupError:
        # If the resource isn't found, download it
        print(f"Downloading {package_name}...")
        nltk.download(package_name, download_dir=download_dir)
        print(f"Successfully downloaded {package_name}")


def format_time(seconds):
    """Format a time duration nicely, as H:MM:SS."""
    return str(timedelta(seconds=round(seconds)))


def claim2prompts(example):
    claim = example["claim"]

    claim_str = "Example [NUMBER]:||Claim: " + claim + "||Evidence: "

    for question in example["questions"]:
        q_text = question["question"].strip()
        if len(q_text) == 0:
            continue

        if q_text[-1] != "?":
            q_text += "?"

        answer_strings = []
        for a in question["answers"]:
            if a["answer_type"] in ["Extractive", "Abstractive"]:
                answer_strings.append(a["answer"])
            if a["answer_type"] == "Boolean":
                answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip())

        for a_text in answer_strings:
            # Guard against empty answer strings before checking punctuation
            if a_text and a_text[-1] not in [".", "!", ":", "?"]:
                a_text += "."

            prompt_lookup_str = a_text
            this_q_claim_str = claim_str + a_text.strip() + "||Question: " + q_text
            yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n")[:1500])
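
# Each yielded pair is (lookup string, few-shot prompt). After the "||"
# separators are rewritten as newlines, a prompt looks roughly like this
# (values are illustrative):
#   Example [NUMBER]:
#   Claim: <training claim>
#   Evidence: <answer text>
#   Question: <training question>?
# The "[NUMBER]" placeholder is filled in later, and each prompt is
# truncated to 1500 characters.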


def main(args):
    script_start = time.time()
    start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"Script started at: {start_time}")
    print(f"Loading model: {args.model}")

    # Make sure the NLTK tokenizer resources are available
    download_nltk_data('punkt')
    download_nltk_data('punkt_tab')

    # Load and prepare reference corpus
    corpus_start = time.time()
    with open(args.reference_corpus, "r", encoding="utf-8") as json_file:
        train_examples = json.load(json_file)

    prompt_corpus, tokenized_corpus = [], []
    for example in train_examples:
        for lookup_str, prompt in claim2prompts(example):
            entry = nltk.word_tokenize(lookup_str)
            tokenized_corpus.append(entry)
            prompt_corpus.append(prompt)

    prompt_bm25 = BM25Okapi(tokenized_corpus)
    print(f"Reference corpus processed in: {format_time(time.time() - corpus_start)}")

    # Initialize vLLM with optimized settings
    gpu_count = torch.cuda.device_count()
    print(f"Using {gpu_count} GPU{'s' if gpu_count > 1 else ''}")
    model_start = time.time()
    llm = LLM(
        model=args.model,
        tensor_parallel_size=gpu_count,
        max_model_len=4096,
        gpu_memory_utilization=0.95,
        enforce_eager=True,
        trust_remote_code=True,
        # dtype="half",
    )
    # The Llama 3 tokenizer ships without a pad token; reuse its end-of-text token
    llm.get_tokenizer().pad_token = "<|end_of_text|>"
    print(f"Model loaded in: {format_time(time.time() - model_start)}")

    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        top_k=1,
        skip_special_tokens=False,
        max_tokens=512,
        stop=['<|end_of_text|>', '</s>', '<|im_end|>', '[INST]', '[/INST]', '<|eot_id|>', '<|end|>', '<|endoftext|>'],
    )
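    # Note: top_k=1 makes decoding effectively greedy, so the temperature and
    # top_p settings above have no practical effect on the output.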

    processing_start = time.time()

    # Load target data (one JSON object per line)
    target_examples = []
    with open(args.top_k_target_knowledge, "r", encoding="utf-8") as json_file:
        for line in json_file:
            target_examples.append(json.loads(line))

    if args.end == -1:
        args.end = len(target_examples)
    print(f"Processing {args.end} examples")

    # Process in batches
    with torch.no_grad():
        with open(args.output_questions, "w", encoding="utf-8") as output_file:
            for idx in range(0, args.end, args.batch_size):
                batch_end = min(idx + args.batch_size, args.end)
                current_batch = target_examples[idx:batch_end]
                print(f"\nProcessing batch {idx}-{batch_end}...")

                for example in current_batch:
                    batch_start = time.time()
                    claim = example["claim"]
                    claim_id = example["claim_id"]
                    top_k_sentences_urls = example[f"top_{args.top_k}"]

                    batch_prompts = []
                    batch_metadata = []
                    # Prepare one prompt per retrieved sentence for the current example
                    for sentences_urls in top_k_sentences_urls:
                        prompt_lookup_str = sentences_urls["sentence"]
                        url = sentences_urls["url"]

                        # Retrieve the 10 most similar training examples as few-shot demonstrations
                        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
                        prompt_n = 10
                        prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
                        prompt_docs = [prompt_corpus[i] for i in prompt_top_n]

                        temp_prompt = "\n\n".join(prompt_docs)
                        # Number the few-shot examples sequentially: Example 1, Example 2, ...
                        for k in range(1, temp_prompt.count("[NUMBER]") + 1):
                            temp_prompt = temp_prompt.replace("[NUMBER]", f"{k}", 1)

                        claim_prompt = "Your task is to generate a question based on the given claim and evidence. The question should clarify the relationship between the evidence and the claim.\n\n"
                        evidence = prompt_lookup_str.replace("\n", " ")
                        full_prompt = claim_prompt + temp_prompt + "\n\nNow, generate a question that links the following claim and evidence:" + f"\n\nClaim: {claim}" + f"\nEvidence: {evidence}"
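                        # The assembled prompt is: instruction, numbered few-shot
                        # examples, then the target claim/evidence pair awaiting
                        # its question.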
| if "OLMo" in args.model: | |
| inputs = [full_prompt] | |
| else: | |
| messages = [{"role":"user", "content":full_prompt}] | |
| inputs = llm.get_tokenizer().apply_chat_template(messages, tokenize=False) | |
| inputs += "<|start_header_id|>assistant<|end_header_id|>\n\nQuestion: " | |
| batch_prompts.append(inputs) | |
| batch_metadata.append((url, prompt_lookup_str)) | |

                    # Process batch
                    outputs = llm.generate(batch_prompts, sampling_params)

                    # Process outputs: keep the text up to the first question mark
                    evidence = []
                    for output, (url, sent) in zip(outputs, batch_metadata):
                        question = output.outputs[0].text.strip().split("?")[0].replace("\n", " ") + "?"
                        evidence.append({
                            "question": question,
                            "answer": sent,
                            "url": url,
                        })

                    # Write results
                    json_data = {
                        "claim_id": claim_id,
                        "claim": claim,
                        "evidence": evidence,
                    }
                    output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
                    output_file.flush()

                    batch_time = time.time() - batch_start
                    print(f"Processed example {claim_id}. Time elapsed: {batch_time:.2f}s")

    # Calculate and display timing information
    total_time = time.time() - script_start
    processing_time = time.time() - processing_start
    end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    print("\nTiming Summary:")
    print(f"Start time: {start_time}")
    print(f"End time: {end_time}")
    print(f"Total runtime: {format_time(total_time)}")
    print(f"Setup time: {format_time(processing_start - script_start)}")
    print(f"Processing time: {format_time(processing_time)}")
    print(f"Results written to: {args.output_questions}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Use a few-shot prompt to generate questions that could be answered by top-k retrieved evidence. Output the generated questions."
    )
    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
    parser.add_argument("--reference_corpus", default="baseline/train.json")
    parser.add_argument(
        "-i",
        "--top_k_target_knowledge",
        default="data_store/dev_reranking_top_k.json",
        help="JSONL file containing the top-k retrieved sentences for each claim.",
    )
    parser.add_argument(
        "-o",
        "--output_questions",
        default="data_store/dev_top_k_qa.json",
        help="JSONL file to which the generated question-answer pairs are written.",
    )
    parser.add_argument(
        "--top_k",
        default=10,
        type=int,
        help="Number of retrieved sentences per claim to generate questions for.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Number of examples to process in each batch",
    )
    parser.add_argument(
        "-e",
        "--end",
        type=int,
        default=-1,
        help="Index of the last example to process; -1 means process all examples.",
    )

    args = parser.parse_args()
    main(args)
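
# Example invocation (the script filename is illustrative; paths follow the
# defaults above):
#   python question_generation_top_k.py \
#       --model meta-llama/Meta-Llama-3-8B-Instruct \
#       -i data_store/dev_reranking_top_k.json \
#       -o data_store/dev_top_k_qa.json \
#       --top_k 10 --batch_size 4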