import os
import json

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import CrossEncoder

import pyterrier as pt
from pyterrier_t5 import MonoT5ReRanker, DuoT5ReRanker

# Start the PyTerrier JVM once per process.
if not pt.started():
    pt.init()

def extract_context(json_data, number, turn_id):
    """Return the current utterance at `turn_id` and the '|||'-joined
    conversation history up to and including that utterance."""
    data = None
    for item in json_data:
        if item["number"] == number:
            data = item
            break

    if not data:
        print("No data found for the given number.")
        return "No data found for the given number.", None

    texts = []
    current_utterance = ""
    for turn in data["turns"]:
        if turn["turn_id"] < turn_id:
            texts.append(turn["utterance"])
            texts.append(turn["response"])
        elif turn["turn_id"] == turn_id:
            current_utterance = turn["utterance"]
            texts.append(current_utterance)

    context = "|||".join(texts)
    return current_utterance, context

def escape_special_characters(query):
    """Remove characters that the Terrier query parser treats as
    operators (despite the name, they are stripped rather than escaped)."""
    special_chars = ["?", "&", "|", "!", "{", "}", "[", "]", "^", "~", "*", ":", '"', "+", "-", "(", ")"]
    for char in special_chars:
        query = query.replace(char, "")
    return query

def str_to_df_query(query):
    if isinstance(query, str):
        query = escape_special_characters(query)
        return pd.DataFrame([[1, query]], columns=["qid", "query"])
    elif isinstance(query, list):
        query = [escape_special_characters(q) for q in query]
        return pd.DataFrame([[i + 1, q] for i, q in enumerate(query)], columns=["qid", "query"])
    else:
        raise ValueError("The query must be a string or a list of strings.")

def retrieve_and_rerank(query, pipeline):
    """Run the retrieval pipeline on a query and collect the ranked
    passages into a list of dicts."""
    query_df = str_to_df_query(query)
    res = pipeline.transform(query_df)
    candidate_set = []
    # Enumerate positionally so ranks are 1-based even if the result
    # DataFrame's index is not contiguous.
    for i, (_, row) in enumerate(res.iterrows()):
        candidate_set.append(
            {
                "passage_id": row["docno"],
                "rank": i + 1,
                "score": row["score"],
                "passage_text": row["text"],
            }
        )
    return candidate_set

def rerank_passages(query, passages, reranker):
    """Score (query, passage) pairs with a cross-encoder and sort the
    passages by that score, highest first."""
    query_passage_pairs = [[query, passage["passage_text"]] for passage in passages]
    scores = reranker.predict(query_passage_pairs)

    for passage, score in zip(passages, scores):
        passage["reranker_score"] = score

    return sorted(passages, key=lambda x: x["reranker_score"], reverse=True)

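# rerank_passages() expects an object with a CrossEncoder-style .predict();
# none is instantiated anywhere in this script. A minimal usage sketch,
# assuming the public ms-marco checkpoint (the model choice here is an
# assumption, not taken from the original code):
#
#     reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
#     ranked = rerank_passages(rewrite, candidate_set, reranker)
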
def rag(rewrite, top_n_passages=3):
    """Retrieve with BM25, rerank with monoT5, and return the top-n passages."""
    # Note: the index and pipeline are rebuilt on every call; hoist them out
    # of this function if it is called repeatedly.
    index_path = os.path.join("/root/nfs/iKAT/2023/ikat_index/index_pyterrier_with_text", "data.properties")
    index = pt.IndexFactory.of(index_path)

    bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
    monoT5 = MonoT5ReRanker()
    pipeline = (bm25 % 10) >> pt.text.get_text(index, "text") >> (monoT5 % 5) >> pt.text.get_text(index, "text")

    reranked_passages = retrieve_and_rerank(rewrite, pipeline)
    passages = [{"passage_id": p["passage_id"], "passage_text": p["passage_text"]} for p in reranked_passages][:top_n_passages]
    return passages

def retrieve_passage(resolved_query, history, RAG, top_n_passages=3):
    if RAG:
        if len(history) >= 1:
            rag_context = rag(resolved_query, top_n_passages)
        else:
            rag_context = rag(resolved_query)
    else:
        rag_context = "No Context"
    return rag_context

def get_length_without_special_tokens(text, tokenizer):
    """Character length of `text` after a tokenize/decode round trip with
    special tokens stripped; used to slice the prompt off decoded output."""
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids[0]
    decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)
    return len(decoded_text)

def response_generation(messages, model, tokenizer, device, terminators, max_tokens=512, temperature=0.0, top_p=0.9):
    """Generate an assistant reply for a chat-formatted message list and
    return it along with the updated message history."""
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=False,  # greedy decoding: temperature/top_p have no effect
        top_p=top_p,
    )

    # Cut the prompt off the decoded output by character offset.
    prompt_length = get_length_without_special_tokens(prompt, tokenizer)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)[prompt_length:]

    return response.strip(), messages + [{"role": "assistant", "content": response.strip()}]

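# The __main__ block below calls rewrite_query() and generate_response(),
# which are not defined anywhere in this script. The two functions that
# follow are minimal sketches, not the original implementations:
# rewrite_query() assumes the castorini/t5-base-canard convention of
# feeding the '|||'-separated history straight to the seq2seq model, and
# generate_response() assumes the standard T5 "summarize: " prefix for the
# mrm8488/t5-base-finetuned-summarize-news summarizer.
def rewrite_query(context, model, tokenizer, device, max_length=128):
    inputs = tokenizer(context, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def generate_response(passages, model, tokenizer, max_length=150):
    summaries = []
    for passage in passages:
        inputs = tokenizer("summarize: " + passage, return_tensors="pt", truncation=True)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_length=max_length)
        summaries.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    return summaries
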
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    demo_path = "/nfs/primary/iKAT/2023/"
    with open(os.path.join(demo_path, "ikat_demo/test.json"), "r") as f:
        topics = json.load(f)

    index_path = os.path.join("/root/nfs/iKAT/2023/index_pyterrier_with_text", "data.properties")
    index = pt.IndexFactory.of(index_path)

    # BM25 first stage (top 10) followed by a monoT5 rerank (top 5), with
    # passage text attached for the downstream models.
    bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
    monoT5 = MonoT5ReRanker()
    pipeline = (bm25 % 10) >> pt.text.get_text(index, "text") >> (monoT5 % 5) >> pt.text.get_text(index, "text")

    query = "Can you compare mozzarella with plant-based cheese?"

    # Rewrite the conversational utterance into a self-contained query.
    rewriter = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard").to(device).eval()
    rewriter_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    number_to_search = "10-1"
    turn_id_to_search = 6
    utterance, context = extract_context(topics, number_to_search, turn_id_to_search)
    rewrite = rewrite_query(context, rewriter, rewriter_tokenizer, device)

    reranked_passages = retrieve_and_rerank(rewrite, pipeline)

    # Summarize the top passages into a final response.
    summarizer = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
    summarizer_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")

    passages = [passage["passage_text"] for passage in reranked_passages][:3]
    print(json.dumps(passages, indent=4))
    responses = generate_response(passages, summarizer, summarizer_tokenizer)
    print("Done")