import os
import json

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # model/tokenizer loading from HuggingFace
from sentence_transformers import CrossEncoder

# from pyserini.search.lucene import LuceneSearcher
import pyterrier as pt
from pyterrier_t5 import MonoT5ReRanker, DuoT5ReRanker

# PyTerrier must be initialized before any index/retrieval calls
if not pt.started():
    pt.init()

import ipdb  # debugger; only used by the commented-out breakpoint below

def extract_context(json_data, number, turn_id):
    # Find the topic entry with the given number
    data = None
    for item in json_data:
        if item["number"] == number:
            data = item
            break
    # If we couldn't find the data for the given number
    if not data:
        print("No data found for the given number.")
        return "No data found for the given number.", None
    # Collect the utterance/response pairs of all earlier turns,
    # plus the utterance of the current turn
    texts = []
    current_utterance = ""
    for turn in data["turns"]:
        if turn["turn_id"] < turn_id:
            texts.append(turn["utterance"])
            texts.append(turn["response"])
        elif turn["turn_id"] == turn_id:
            current_utterance = turn["utterance"]
            texts.append(current_utterance)
    # Join the texts with the "|||" separator (the history format expected
    # by castorini/t5-base-canard)
    context = "|||".join(texts)
    return current_utterance, context

def escape_special_characters(query):
    # Despite the name, this removes (rather than escapes) characters that are
    # special to the Terrier query parser
    special_chars = ["?", "&", "|", "!", "{", "}", "[", "]", "^", "~", "*", ":", '"', "+", "-", "(", ")"]
    for char in special_chars:
        query = query.replace(char, "")
    return query

def str_to_df_query(query):
    # Wrap a single query (or a batch of queries) into the qid/query DataFrame
    # format that PyTerrier transformers expect
    if isinstance(query, str):
        query = escape_special_characters(query)
        return pd.DataFrame([[1, query]], columns=["qid", "query"])
    elif isinstance(query, list):
        query = [escape_special_characters(q) for q in query]
        return pd.DataFrame([[i + 1, q] for i, q in enumerate(query)], columns=["qid", "query"])
    else:
        raise ValueError("The query must be a string or a list of strings.")

def retrieve_and_rerank(query, pipeline):
    query_df = str_to_df_query(query)
    res = pipeline.transform(query_df)
    # The pipeline returns results already sorted; re-number ranks from 1
    candidate_set = []
    for i, (_, row) in enumerate(res.iterrows()):
        candidate_set.append(
            {
                "passage_id": row["docno"],
                "rank": i + 1,
                "score": row["score"],
                "passage_text": row["text"],
            }
        )
    return candidate_set

def rerank_passages(query, passages, reranker):
    # Score each (query, passage) pair with the cross-encoder, attach the score,
    # and sort the passages by it in descending order
    query_passage_pairs = [[query, passage["passage_text"]] for passage in passages]
    scores = reranker.predict(query_passage_pairs)
    for passage, score in zip(passages, scores):
        passage["reranker_score"] = score
    ranked_passages = sorted(passages, key=lambda x: x["reranker_score"], reverse=True)
    return ranked_passages
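
# Example usage of rerank_passages (not executed in this script). The CrossEncoder
# import above is otherwise unused, so presumably it pairs with this function; the
# model name below is illustrative, not taken from this file.
#
#   reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
#   candidates = retrieve_and_rerank(rewrite, pipeline)
#   ranked = rerank_passages(rewrite, candidates, reranker)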

def rag(rewrite, top_n_passages=3):
    # Set up the index (note: this path differs from the one used in __main__)
    index_path = os.path.join("/root/nfs/iKAT/2023/ikat_index/index_pyterrier_with_text", "data.properties")
    index = pt.IndexFactory.of(index_path)
    # Retrieval-and-reranking pipeline: BM25 top-10 -> fetch text -> monoT5 top-5 -> fetch text
    bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
    monoT5 = MonoT5ReRanker()
    pipeline = (bm25 % 10) >> pt.text.get_text(index, "text") >> (monoT5 % 5) >> pt.text.get_text(index, "text")
    # Passage retrieval and reranking
    reranked_passages = retrieve_and_rerank(rewrite, pipeline)
    passages = [{"passage_id": p["passage_id"], "passage_text": p["passage_text"]} for p in reranked_passages][:top_n_passages]
    return passages

def retrieve_passage(resolved_query, history, RAG, top_n_passages=3):
    # TODO: RAG function
    if RAG:
        if len(history) >= 1:
            rag_context = rag(resolved_query, top_n_passages)
        else:
            # First turn: fall back to the default number of passages
            rag_context = rag(resolved_query)
    else:
        rag_context = "No Context"
    return rag_context

def get_length_without_special_tokens(text, tokenizer):
    # Character length of `text` after a tokenize/decode round trip that skips
    # special tokens; used below to strip the prompt from the decoded output
    inputs = tokenizer(text, return_tensors="pt")
    # Extract the input IDs from the tokenized output
    input_ids = inputs.input_ids[0]
    # Decode the input IDs to a string, skipping special tokens
    decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)
    return len(decoded_text)
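
# NOTE: rewrite_query is called in __main__ below but was not defined in this
# file. This is a minimal sketch, assuming the castorini/t5-base-canard input
# convention of a "|||"-separated conversation history (the format produced by
# extract_context); verify it against the original implementation.
def rewrite_query(context, model, tokenizer, device, max_length=128):
    inputs = tokenizer(context, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_length)
    # Decode the generated ids into a self-contained (de-contextualized) query
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)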

def response_generation(messages, model, tokenizer, device, terminators, max_tokens=512, temperature=0.0, top_p=0.9):
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=False,  # greedy decoding, to be deterministic
        # temperature=temperature,  # unused while do_sample=False
        top_p=top_p,  # has no effect with do_sample=False
    )
    # Slice the prompt off the decoded output so only the new response remains
    prompt_length = get_length_without_special_tokens(prompt, tokenizer)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)[prompt_length:]
    # ipdb.set_trace()
    return response.strip(), messages + [{"role": "assistant", "content": response.strip()}]
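
# NOTE: generate_response is called in __main__ below but was not defined in
# this file. A minimal sketch, assuming each top passage is summarized
# independently by the T5 summarizer loaded below; only the function name and
# call signature are taken from the call site.
def generate_response(passages, model, tokenizer, max_length=128):
    responses = []
    for passage in passages:
        inputs = tokenizer(passage, return_tensors="pt", truncation=True)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=max_length)
        responses.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    return responses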

if __name__ == "__main__":
    # Set up
    device = "cuda" if torch.cuda.is_available() else "cpu"
    demo_path = "/nfs/primary/iKAT/2023/"
    with open(os.path.join(demo_path, "ikat_demo/test.json"), "r") as f:
        topics = json.load(f)
    # Set up the index (note: this path differs from the one used in rag())
    index_path = os.path.join("/root/nfs/iKAT/2023/index_pyterrier_with_text", "data.properties")
    index = pt.IndexFactory.of(index_path)
    # Retrieval-and-reranking pipeline: BM25 top-10 -> fetch text -> monoT5 top-5 -> fetch text
    bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
    monoT5 = MonoT5ReRanker()
    pipeline = (bm25 % 10) >> pt.text.get_text(index, "text") >> (monoT5 % 5) >> pt.text.get_text(index, "text")
    # Example raw utterance (the demo below instead pulls turn 6 of topic "10-1" from test.json)
    query = "Can you compare mozzarella with plant-based cheese?"
    # Query rewriting: resolve the current utterance against the conversation history
    rewriter = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard").to(device).eval()
    rewriter_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    number_to_search = "10-1"
    turn_id_to_search = 6
    utterance, context = extract_context(topics, number_to_search, turn_id_to_search)
    rewrite = rewrite_query(context, rewriter, rewriter_tokenizer, device)
    # Passage retrieval and reranking
    reranked_passages = retrieve_and_rerank(rewrite, pipeline)
    # Response generation
    summarizer = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
    summarizer_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
    # Use the top-3 reranked passages to generate a response
    passages = [passage["passage_text"] for passage in reranked_passages][:3]
    print(json.dumps(passages, indent=4))
    responses = generate_response(passages, summarizer, summarizer_tokenizer)
    print("Done")