import os
import json

import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import CrossEncoder
# from pyserini.search.lucene import LuceneSearcher
import pyterrier as pt
from pyterrier_t5 import MonoT5ReRanker, DuoT5ReRanker

if not pt.started():
    pt.init()

def extract_context(json_data, number, turn_id):
    # Find the topic entry with the given number
    data = None
    for item in json_data:
        if item["number"] == number:
            data = item
            break

    # If we couldn't find the data for the given number
    if not data:
        print("No data found for the given number.")
        return "No data found for the given number.", None

    # Extract the utterance and response values up to the current turn
    texts = []
    current_utterance = ""
    for turn in data["turns"]:
        if turn["turn_id"] < turn_id:
            texts.append(turn["utterance"])
            texts.append(turn["response"])
        elif turn["turn_id"] == turn_id:
            current_utterance = turn["utterance"]
            texts.append(current_utterance)

    # Join the texts with the "|||" separator
    context = "|||".join(texts)
    return current_utterance, context
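
# For example, for turn_id 3 the returned context looks like
# "utt1|||resp1|||utt2|||resp2|||utt3", which is the input format the
# castorini/t5-base-canard rewriter expects.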

def escape_special_characters(query):
    # Despite the name, this strips (rather than escapes) characters that are
    # special to the Terrier query parser and would otherwise raise parse errors
    special_chars = ["?", "&", "|", "!", "{", "}", "[", "]", "^", "~", "*", ":", '"', "+", "-", "(", ")"]
    for char in special_chars:
        query = query.replace(char, "")
    return query

def str_to_df_query(query):
    if isinstance(query, str):
        query = escape_special_characters(query)
        return pd.DataFrame([[1, query]], columns=["qid", "query"])
    elif isinstance(query, list):
        query = [escape_special_characters(q) for q in query]
        return pd.DataFrame([[i + 1, q] for i, q in enumerate(query)], columns=["qid", "query"])
    else:
        raise ValueError("The query must be a string or a list of strings.")
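
# Example: str_to_df_query("Can you compare mozzarella with plant-based cheese?")
# yields a one-row DataFrame with qid 1 and the trailing "?" stripped.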

def retrieve_and_rerank(query, pipeline):
    query_df = str_to_df_query(query)
    res = pipeline.transform(query_df)
    candidate_set = []
    for _, row in res.iterrows():
        candidate_set.append(
            {
                "passage_id": row["docno"],
                "rank": row["rank"] + 1,  # PyTerrier ranks are 0-based
                "score": row["score"],
                "passage_text": row["text"],
            }
        )
    return candidate_set

def rerank_passages(query, passages, reranker):
    # Score each (query, passage) pair and sort by descending reranker score
    query_passage_pairs = [[query, passage["passage_text"]] for passage in passages]
    scores = reranker.predict(query_passage_pairs)
    for passage, score in zip(passages, scores):
        passage["reranker_score"] = score
    return sorted(passages, key=lambda x: x["reranker_score"], reverse=True)
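
# Hypothetical usage with the sentence-transformers CrossEncoder imported above
# (the model name is an illustrative example, not one fixed by this file):
#   reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
#   top = rerank_passages(rewrite, candidate_set, reranker)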

def rag(rewrite, top_n_passages=3):
    # Set up index
    index_path = os.path.join("/root/nfs/iKAT/2023/ikat_index/index_pyterrier_with_text", "data.properties")
    index = pt.IndexFactory.of(index_path)

    # Retrieval-then-reranking pipeline: BM25 top-10, then monoT5 top-5
    bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
    monoT5 = MonoT5ReRanker()
    pipeline = (bm25 % 10) >> pt.text.get_text(index, "text") >> (monoT5 % 5) >> pt.text.get_text(index, "text")

    # Passage retrieval and reranking
    reranked_passages = retrieve_and_rerank(rewrite, pipeline)
    passages = [{"passage_id": p["passage_id"], "passage_text": p["passage_text"]} for p in reranked_passages][:top_n_passages]
    return passages

def retrieve_passage(resolved_query, history, RAG, top_n_passages=3):
    # TODO: RAG function
    if RAG:
        rag_context = rag(resolved_query, top_n_passages)
    else:
        rag_context = "No Context"
    return rag_context

def get_length_without_special_tokens(text, tokenizer):
    # Round-trip the text through the tokenizer, dropping special tokens, and
    # return the character length of the result; this is used to slice the
    # prompt off the decoded generation in response_generation
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids[0]
    decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)
    return len(decoded_text)
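
# rewrite_query is called in __main__ below but is not defined in this file.
# The following is a minimal sketch, assuming the castorini/t5-base-canard
# convention: the model takes the "|||"-separated conversation history produced
# by extract_context and generates a self-contained rewrite of the last utterance.
def rewrite_query(context, rewriter, rewriter_tokenizer, device, max_length=128):
    inputs = rewriter_tokenizer(context, return_tensors="pt", truncation=True).to(device)
    outputs = rewriter.generate(**inputs, max_length=max_length, num_beams=4)
    return rewriter_tokenizer.decode(outputs[0], skip_special_tokens=True)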

def response_generation(messages, model, tokenizer, device, terminators, max_tokens=512, temperature=0.0, top_p=0.9):
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=False,  # Greedy decoding, to be deterministic; temperature and top_p have no effect
        top_p=top_p,
    )
    # Strip the prompt from the decoded output by character length
    prompt_length = get_length_without_special_tokens(prompt, tokenizer)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)[prompt_length:]
    return response.strip(), messages + [{"role": "assistant", "content": response.strip()}]
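
# generate_response is called in __main__ below but is not defined in this file.
# A minimal sketch, assuming each retrieved passage is summarized independently
# with the mrm8488/t5-base-finetuned-summarize-news model loaded in __main__;
# the "summarize: " prefix is the standard T5 convention and an assumption here.
def generate_response(passages, model, tokenizer, max_length=150):
    responses = []
    for passage in passages:
        inputs = tokenizer("summarize: " + passage, return_tensors="pt", truncation=True, max_length=512).to(model.device)
        outputs = model.generate(**inputs, max_length=max_length, num_beams=4)
        responses.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return responses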

if __name__ == "__main__":
    # Set up
    device = "cuda" if torch.cuda.is_available() else "cpu"
    demo_path = "/nfs/primary/iKAT/2023/"
    with open(os.path.join(demo_path, "ikat_demo/test.json"), "r") as f:
        topics = json.load(f)

    # Set up index
    index_path = os.path.join("/root/nfs/iKAT/2023/index_pyterrier_with_text", "data.properties")
    index = pt.IndexFactory.of(index_path)

    # Retrieval-then-reranking pipeline: BM25 top-10, then monoT5 top-5
    bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
    monoT5 = MonoT5ReRanker()
    pipeline = (bm25 % 10) >> pt.text.get_text(index, "text") >> (monoT5 % 5) >> pt.text.get_text(index, "text")

    # Example raw user utterance; the query actually used for retrieval is the
    # rewrite generated from the conversation context below
    query = "Can you compare mozzarella with plant-based cheese?"

    # Query rewriting: load model and tokenizer from HuggingFace
    rewriter = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard").to(device).eval()
    rewriter_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    number_to_search = "10-1"
    turn_id_to_search = 6
    utterance, context = extract_context(topics, number_to_search, turn_id_to_search)
    rewrite = rewrite_query(context, rewriter, rewriter_tokenizer, device)

    # Passage retrieval and reranking
    reranked_passages = retrieve_and_rerank(rewrite, pipeline)

    # Response generation
    summarizer = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
    summarizer_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")

    # We use the top-3 reranked passages to generate a response
    passages = [passage["passage_text"] for passage in reranked_passages][:3]
    print(json.dumps(passages, indent=4))
    responses = generate_response(passages, summarizer, summarizer_tokenizer)
    print("Done")