import gradio as gr
import pickle
import numpy as np
import glob
import tqdm
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, set_seed
from peft import PeftModel
import logging
import os
import json
import spaces
import ir_datasets
import pytrec_eval
from huggingface_hub import login
import transformers
import peft
import faiss
import sys
from collections import defaultdict

set_seed(42)

# Set up logging with time printing
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)

# Authenticate with HF_TOKEN
login(token=os.environ['HF_TOKEN'])

# Global variables
CUR_MODEL = "Samaya-AI/Promptriever-Llama2-v1"
BASE_MODEL = "meta-llama/Llama-2-7b-hf"
tokenizer = None
model = None
retrievers = {}
corpus_lookups = {}  # dataset name -> list of doc ids, row-aligned with the FAISS index
queries = {}         # dataset name -> list of query texts
query_ids = {}       # dataset name -> list of query ids, position-aligned with `queries`
q_lookups = {}       # dataset name -> {query id: query text}
qrels = {}           # dataset name -> {query id: {doc id: relevance}}
query2qid = {}       # dataset name -> {query text: query id}
datasets = ["scifact"]
current_dataset = "scifact"
faiss_index = None


def log_system_info():
    """Log Python, key package versions and CUDA/GPU details."""
    logger.info("System Information:")
    logger.info(f"Python version: {sys.version}")
    logger.info("\nPackage Versions:")
    logger.info(f"torch: {torch.__version__}")
    logger.info(f"transformers: {transformers.__version__}")
    logger.info(f"peft: {peft.__version__}")
    logger.info(f"faiss: {faiss.__version__}")
    logger.info(f"gradio: {gr.__version__}")
    logger.info(f"ir_datasets: {ir_datasets.__version__}")

    if torch.cuda.is_available():
        logger.info("\nCUDA Information:")
        logger.info("CUDA available: Yes")
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"cuDNN version: {torch.backends.cudnn.version()}")
        logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        logger.info("\nCUDA Information:")
        logger.info("CUDA available: No")


log_system_info()


def pool(last_hidden_states, attention_mask, pool_type="last"):
    """Reduce per-token hidden states to one vector per sequence.

    Only "last" pooling is implemented: the hidden state of the last
    non-padding token of each row is returned.

    Args:
        last_hidden_states: Token states indexed as (batch, seq, dim).
        attention_mask: (batch, seq) mask, non-zero for real tokens.
        pool_type: Pooling strategy; only "last" is supported.

    Returns:
        Tensor of shape (batch, dim).

    Raises:
        ValueError: If ``pool_type`` is not "last".
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    if pool_type != "last":
        raise ValueError(f"pool_type {pool_type} not supported")

    # If every row has a real token in the final column, that column is the
    # last real token for all rows (left padding or no padding at all).
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden[:, -1]

    # Right padding: index each row at its own last real position.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden.shape[0]
    return last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]


def create_batch_dict(tokenizer, input_texts, always_add_eos="last", max_length=512):
    """Tokenize ``input_texts``, optionally append EOS, and pad into tensors.

    Texts are truncated to ``max_length - 1`` tokens so the appended EOS still
    fits within ``max_length``. Padding is to the longest sequence, rounded
    up to a multiple of 8.

    Returns:
        The tokenizer's padded batch (``input_ids`` and ``attention_mask``
        as PyTorch tensors).
    """
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length - 1,
        return_token_type_ids=False,
        return_attention_mask=False,
        padding=False,
        truncation=True
    )
    if always_add_eos == "last":
        batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
    return tokenizer.pad(
        batch_dict,
        padding=True,
        pad_to_multiple_of=8,
        return_attention_mask=True,
        return_tensors="pt",
    )


class RepLlamaModel:
    """Dense encoder: a base causal-LM backbone with a merged PEFT adapter.

    Args:
        model_name_or_path: PEFT adapter to load on top of the base model.
        base_model_name_or_path: Backbone model id; defaults to ``BASE_MODEL``.
    """

    def __init__(self, model_name_or_path, base_model_name_or_path=BASE_MODEL):
        self.base_model = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model)
        self.tokenizer.model_max_length = 2048
        # Reuse EOS as the pad token and pad on the right; `pool` then finds
        # the last real token (the appended EOS) through the attention mask.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        self.model = self.get_model(model_name_or_path)
        self.model.config.max_length = 2048

    def get_model(self, peft_model_name):
        """Load the backbone, apply and merge the PEFT adapter, return it in eval mode."""
        base_model = AutoModel.from_pretrained(self.base_model)
        model = PeftModel.from_pretrained(base_model, peft_model_name)
        model = model.merge_and_unload()
        model.eval()
        return model

    def encode(self, texts, batch_size=48, **kwargs):
        """Embed ``texts`` into L2-normalised vectors on the GPU.

        Args:
            texts: List of strings to embed.
            batch_size: Number of texts per forward pass.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``np.ndarray`` with one row per input text.
        """
        # Lazily move the model to the GPU on first use.
        if self.model.device.type != "cuda":
            self.model = self.model.cuda()

        if len(texts) == 0:
            # np.concatenate([]) would raise; return an empty (0, dim) matrix.
            return np.zeros((0, self.model.config.hidden_size), dtype=np.float32)

        all_embeddings = []
        for start in tqdm.tqdm(range(0, len(texts), batch_size)):
            batch_texts = texts[start:start + batch_size]
            batch_dict = create_batch_dict(self.tokenizer, batch_texts, always_add_eos="last")
            batch_dict = {key: value.cuda() for key, value in batch_dict.items()}
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast().
            with torch.autocast(device_type="cuda"), torch.no_grad():
                outputs = self.model(**batch_dict)
                embeddings = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                embeddings = F.normalize(embeddings, p=2, dim=-1)
                logger.info(f"Encoded shape: {embeddings.shape}, Norm of first embedding: {torch.norm(embeddings[0]).item()}")
            all_embeddings.append(embeddings.cpu().numpy())

        return np.concatenate(all_embeddings, axis=0)


def load_corpus_embeddings(dataset_name):
    """Load and concatenate the sharded corpus embeddings of ``dataset_name``.

    Shards are pickles matching ``{dataset_name}/corpus_emb.<n>.pkl``. Each one
    holds an ``(embeddings, doc_ids)`` pair. Shards are read in numeric order
    of ``<n>`` so embedding rows stay aligned with the returned doc ids.

    Returns:
        Tuple ``(embeddings, doc_ids)``.

    Raises:
        FileNotFoundError: If no shard matches the pattern.
    """
    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
    index_files = glob.glob(corpus_path)
    if not index_files:
        raise FileNotFoundError(f"No corpus embedding shards found matching {corpus_path}")
    # Numeric (not lexical) order, so shard 10 follows shard 9.
    index_files.sort(key=lambda x: int(x.split('.')[-2]))

    shard_embeddings = []
    doc_ids = []
    for file in index_files:
        # NOTE(security): pickle.load can execute arbitrary code; these shards
        # must come from a trusted source.
        with open(file, 'rb') as f:
            embeddings, p_lookup = pickle.load(f)
        shard_embeddings.append(embeddings)
        doc_ids.extend(p_lookup)

    all_embeddings = np.concatenate(shard_embeddings, axis=0)
    logger.info(f"Loaded corpus embeddings for {dataset_name}. Shape: {all_embeddings.shape}")
    return all_embeddings, doc_ids


def create_faiss_index(embeddings):
    """Build an exact inner-product FAISS index over ``embeddings`` rows.

    Inner product equals cosine similarity only if the corpus vectors are
    L2-normalised (assumed for the precomputed shards; TODO confirm).
    """
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    logger.info(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")
    return index


def load_or_create_faiss_index(dataset_name):
    """Return ``(index, doc_ids)`` for ``dataset_name``.

    The index is always rebuilt from the embedding shards; nothing is cached
    on disk.
    """
    embeddings, doc_ids = load_corpus_embeddings(dataset_name)
    index = create_faiss_index(embeddings)
    return index, doc_ids


def initialize_faiss_and_corpus(dataset_name):
    """Build the FAISS index, store the doc-id lookup globally, and return the index."""
    global corpus_lookups
    index, corpus_lookups[dataset_name] = load_or_create_faiss_index(dataset_name)
    logger.info(f"Initialized FAISS index and corpus lookups for {dataset_name}")
    return index


def search_queries(dataset_name, q_reps, depth=100):
    """Retrieve the top-``depth`` documents for each query embedding.

    Args:
        dataset_name: Key into ``corpus_lookups`` for mapping rows to doc ids.
        q_reps: Query embeddings, one row per query.
        depth: Number of documents to retrieve per query.

    Returns:
        Tuple ``(scores, doc_ids)`` of arrays shaped (num_queries, k), where
        ``k = min(depth, index size)``. ``doc_ids`` holds strings.
    """
    logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
    # FAISS pads results with label -1 when k exceeds the number of indexed
    # vectors, and lookup[-1] would silently return the *last* doc id. Clamp k.
    depth = min(depth, faiss_index.ntotal)
    all_scores, all_indices = faiss_index.search(q_reps, depth)
    logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
    logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")

    lookup = corpus_lookups[dataset_name]
    psg_indices = [[str(lookup[x]) for x in q_dd] for q_dd in all_indices]
    return all_scores, np.array(psg_indices)


def load_queries(dataset_name):
    """Load queries and relevance judgements for ``dataset_name`` via ir_datasets.

    For SciFact the ``/test`` split is used. This populates ``queries``,
    ``query_ids`` (position-aligned with ``queries``), ``q_lookups``,
    ``query2qid`` and ``qrels`` under the key ``dataset_name``.
    """
    global queries, query_ids, q_lookups, qrels, query2qid
    name = dataset_name.lower()
    dataset = ir_datasets.load(f"beir/{name}" + ("/test" if name == "scifact" else ""))

    queries[dataset_name] = []
    query_ids[dataset_name] = []
    query2qid[dataset_name] = {}
    q_lookups[dataset_name] = {}
    for query in dataset.queries_iter():
        queries[dataset_name].append(query.text)
        query_ids[dataset_name].append(query.query_id)
        q_lookups[dataset_name][query.query_id] = query.text
        # Text -> id is lossy when two queries share identical text;
        # prefer `query_ids` for positional lookups.
        query2qid[dataset_name][query.text] = query.query_id

    grouped = defaultdict(dict)
    for qrel in dataset.qrels_iter():
        grouped[qrel.query_id][qrel.doc_id] = qrel.relevance
    qrels[dataset_name] = dict(grouped)

    logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
    logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")


def evaluate(qrels, results, k_values):
    """Compute mean NDCG@k and Recall@k with pytrec_eval.

    Args:
        qrels: {query id: {doc id: relevance}}.
        results: {query id: {doc id: score}}.
        k_values: Cutoffs to evaluate.

    Returns:
        Dict such as {"NDCG@10": 0.7, "Recall@10": 0.8}, rounded to 3 decimals.
        The @100 entries are computed but dropped from the returned dict.

    Raises:
        ValueError: If no query in ``results`` has judgements in ``qrels``.
    """
    qrels = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in qrels.items()}
    results = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in results.items()}

    evaluator = pytrec_eval.RelevanceEvaluator(
        qrels,
        {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
    )
    scores = evaluator.evaluate(results)
    if not scores:
        raise ValueError("No query in results has relevance judgements in qrels; cannot compute metrics.")

    metrics = {}
    for k in k_values:
        ndcg_scores = [query_scores[f"ndcg_cut_{k}"] for query_scores in scores.values()]
        recall_scores = [query_scores[f"recall_{k}"] for query_scores in scores.values()]
        metrics[f"NDCG@{k}"] = round(np.mean(ndcg_scores), 3)
        metrics[f"Recall@{k}"] = round(np.mean(recall_scores), 3)
        logger.info(f"NDCG@{k}: mean={metrics[f'NDCG@{k}']}, min={min(ndcg_scores)}, max={max(ndcg_scores)}")
        logger.info(f"Recall@{k}: mean={metrics[f'Recall@{k}']}, min={min(recall_scores)}, max={max(recall_scores)}")

    # Drop NDCG@100 and Recall@100 from the report. pop() tolerates k_values
    # without 100 (plain `del` raised KeyError).
    # NOTE(review): the original comment said "Recall@10" here; confirm
    # whether Recall@100 was meant to be kept instead.
    metrics.pop("NDCG@100", None)
    metrics.pop("Recall@100", None)
    return metrics


@spaces.GPU
def run_evaluation(dataset, postfix):
    """Encode every query with ``postfix`` appended, retrieve, and score.

    Args:
        dataset: Dataset key previously loaded with ``load_queries``.
        postfix: Instruction text appended to each query.

    Returns:
        The metrics dict produced by ``evaluate``.
    """
    global current_dataset
    current_dataset = dataset

    input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[dataset]]
    logger.info(f"Number of input texts: {len(input_texts)}")
    logger.info(f"Sample input text: {input_texts[0]}")

    q_reps = model.encode(input_texts)
    logger.info(f"Encoded query first five: {q_reps[0][:5]}")
    logger.info(f"Encoded query representations shape: {q_reps.shape}")

    all_scores, psg_indices = search_queries(dataset, q_reps)
    logger.info(f"Number of queries in q_lookups: {len(q_lookups[dataset])}")
    logger.info("Size of all_scores: " + str(len(all_scores)))
    logger.info("Size of psg_indices: " + str(len(psg_indices)))

    # Pair search rows with query ids by position. Going through the query
    # text would merge queries whose text is identical.
    results = {}
    for qid, scores, doc_ids in zip(query_ids[dataset], all_scores, psg_indices):
        qid_str = str(qid)
        results[qid_str] = {str(doc_id): float(score) for doc_id, score in zip(doc_ids, scores)}
        if not results[qid_str]:  # If no results for this query
            logger.warning(f"No results for query {qid_str}")

    logger.info(f"Number of queries in results: {len(results)}")
    logger.info(f"Sample result: {next(iter(results.items()))}")
    logger.info(f"Number of queries in qrels: {len(qrels[dataset])}")
    logger.info(f"Sample qrel: {next(iter(qrels[dataset].items()))}")

    # Check for mismatches
    qrels_keys = {str(qid) for qid in qrels[dataset]}
    results_keys = set(results)
    logger.info(f"Queries in qrels but not in results: {qrels_keys - results_keys}")
    logger.info(f"Queries in results but not in qrels: {results_keys - qrels_keys}")

    return evaluate(qrels[dataset], results, k_values=[10, 100])


@spaces.GPU
def gradio_interface(dataset, postfix):
    """Gradio entry point; delegates to ``run_evaluation``."""
    return run_evaluation(dataset, postfix)


if model is None:
    model = RepLlamaModel(model_name_or_path=CUR_MODEL)
load_queries(current_dataset)
faiss_index = initialize_faiss_and_corpus(current_dataset)

# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=datasets, label="Dataset", value="scifact"),
        gr.Textbox(label="Prompt")
    ],
    outputs=gr.JSON(label="Evaluation Results"),
    title="Promptriever Demo",
    description="Enter a prompt to evaluate the model's performance on SciFact. Note: it takes between **10-30 seconds** to evaluate.",
    examples=[
        ["scifact", ""],
        ["scifact", "Think carefully about these conditions when determining relevance"]
    ],
    cache_examples=False,
)

# Launch the interface
iface.launch(share=False)