import json
import os

import pandas as pd
import torch
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import pyterrier as pt
from pyterrier_t5 import MonoT5ReRanker

# PyTerrier needs its JVM started before any index or retrieval calls
if not pt.started():
    pt.init()


def extract_context(json_data, number, turn_id):
    """Return the current utterance and the conversation context: all
    utterances and responses before `turn_id` plus the current utterance,
    joined with "|||" (the input format the t5-base-canard rewriter expects).
    """
    # Find the topic entry with the given number
    data = None
    for item in json_data:
        if item["number"] == number:
            data = item
            break

    # If we couldn't find the data for the given number
    if not data:
        print("No data found for the given number.")
        return "No data found for the given number.", None

    # Extract the utterance and response values
    texts = []
    current_utterance = ""
    for turn in data["turns"]:
        if turn["turn_id"] < turn_id:
            texts.append(turn["utterance"])
            texts.append(turn["response"])
        elif turn["turn_id"] == turn_id:
            current_utterance = turn["utterance"]
            texts.append(current_utterance)

    # Join the texts with "|||" separator
    context = "|||".join(texts)

    return current_utterance, context
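

# NOTE: `rewrite_query` is called in __main__ below but is not defined in this
# file. The following is a minimal sketch, assuming the usual usage of
# castorini/t5-base-canard: the model reads the "|||"-joined history and
# generates a self-contained rewrite of the last utterance.
def rewrite_query(context, rewriter, rewriter_tokenizer, device, max_length=128):
    inputs = rewriter_tokenizer(context, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        outputs = rewriter.generate(**inputs, max_length=max_length, num_beams=4)
    return rewriter_tokenizer.decode(outputs[0], skip_special_tokens=True)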


def escape_special_characters(query):
    # Remove (rather than escape) characters the Terrier query parser treats as operators
    special_chars = ["?", "&", "|", "!", "{", "}", "[", "]", "^", "~", "*", ":", '"', "+", "-", "(", ")"]
    for char in special_chars:
        query = query.replace(char, "")
    return query


def str_to_df_query(query):
    if isinstance(query, str):
        query = escape_special_characters(query)
        return pd.DataFrame([[1, query]], columns=["qid", "query"])
    elif isinstance(query, list):
        query = [escape_special_characters(q) for q in query]
        return pd.DataFrame([[i + 1, q] for i, q in enumerate(query)], columns=["qid", "query"])
    else:
        raise ValueError("The query must be a string or a list of strings.")
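
# Example: str_to_df_query("what is mozzarella?") returns a one-row DataFrame
# with columns ["qid", "query"] -> (1, "what is mozzarella"), the input format
# PyTerrier retrieval pipelines expect.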


def retrieve_and_rerank(query, pipeline):
    query_df = str_to_df_query(query)
    res = pipeline.transform(query_df)
    candidate_set = []
    # Enumerate positionally: the DataFrame index is not guaranteed to be 0..n-1
    for i, (_, row) in enumerate(res.iterrows()):
        candidate_set.append(
            {
                "passage_id": row["docno"],
                "rank": i + 1,
                "score": row["score"],
                "passage_text": row["text"],
            }
        )
    return candidate_set


def rerank_passages(query, passages, reranker):
    # Score each (query, passage) pair with the cross-encoder reranker
    query_passage_pairs = [[query, passage["passage_text"]] for passage in passages]
    scores = reranker.predict(query_passage_pairs)
    for passage, score in zip(passages, scores):
        passage["reranker_score"] = score
    # Highest reranker score first
    return sorted(passages, key=lambda x: x["reranker_score"], reverse=True)
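
# Example (assumed) usage with the CrossEncoder imported above; the model name
# below is only an illustration, not something this script pins down:
#   reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
#   reranked = rerank_passages(rewrite, candidate_set, reranker)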


def rag(rewrite, top_n_passages=3):
    # Set up the index (rebuilt on every call, which is acceptable for a demo)
    index_path = os.path.join("/root/nfs/iKAT/2023/ikat_index/index_pyterrier_with_text", "data.properties")
    index = pt.IndexFactory.of(index_path)
    # BM25 retrieves the top 10 passages, monoT5 reranks them and keeps the top 5
    bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
    monoT5 = MonoT5ReRanker()
    pipeline = (bm25 % 10) >> pt.text.get_text(index, "text") >> (monoT5 % 5) >> pt.text.get_text(index, "text")
    # Retrieve, rerank, and keep the top-n passages
    reranked_passages = retrieve_and_rerank(rewrite, pipeline)
    return [
        {"passage_id": p["passage_id"], "passage_text": p["passage_text"]}
        for p in reranked_passages[:top_n_passages]
    ]


def retrieve_passage(resolved_query, history, RAG, top_n_passages=3):
    # TODO: use `history` to adapt retrieval; for now every RAG call retrieves
    # top_n_passages for the resolved query
    if RAG:
        rag_context = rag(resolved_query, top_n_passages)
    else:
        rag_context = "No Context"
    return rag_context


def get_length_without_special_tokens(text, tokenizer):
    # Tokenize the prompt and get input IDs
    inputs = tokenizer(text, return_tensors="pt")
    # Extract the input IDs from the tokenized output
    input_ids = inputs.input_ids[0]
    # Decode the input IDs to a string, skipping special tokens
    decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)

    return len(decoded_text)


def response_generation(messages, model, tokenizer, device, terminators, max_tokens=512, temperature=0.0, top_p=0.9):
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=False,  # greedy decoding keeps generation deterministic; `temperature` is ignored
        top_p=top_p,
    )
    # Cut the prompt prefix off the decoded output so only the new response remains
    prompt_length = get_length_without_special_tokens(prompt, tokenizer)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)[prompt_length:]
    return response.strip(), messages + [{"role": "assistant", "content": response.strip()}]
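

# NOTE: `generate_response` is called in __main__ below but is not defined in
# this file. A minimal sketch, assuming the intent is to summarize each
# retrieved passage with the T5 summarizer loaded in __main__.
def generate_response(passages, summarizer, summarizer_tokenizer, max_length=150):
    responses = []
    for passage in passages:
        inputs = summarizer_tokenizer(
            "summarize: " + passage, return_tensors="pt", truncation=True
        ).to(summarizer.device)
        with torch.no_grad():
            outputs = summarizer.generate(**inputs, max_length=max_length, num_beams=4)
        responses.append(summarizer_tokenizer.decode(outputs[0], skip_special_tokens=True))
    return responses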


if __name__ == "__main__":
    # Set up
    device = "cuda" if torch.cuda.is_available() else "cpu"
    demo_path = "/nfs/primary/iKAT/2023/"
    with open(os.path.join(demo_path, "ikat_demo/test.json"), "r") as f:
        topics = json.load(f)

    # Set up Index
    index_path = os.path.join("/root/nfs/iKAT/2023/index_pyterrier_with_text", "data.properties")
    index = pt.IndexFactory.of(index_path)

    # Set up Pipeline for retrieval and reranking
    bm25 = pt.BatchRetrieve(index, wmodel="BM25", metadata=["docno", "text"])
    monoT5 = MonoT5ReRanker()
    pipeline = (bm25 % 10) >> pt.text.get_text(index, "text") >> (monoT5 % 5) >> pt.text.get_text(index, "text")

    # Example current utterance: "Can you compare mozzarella with plant-based cheese?"

    # Query rewriting
    rewriter = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard").to(device).eval()
    rewriter_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    number_to_search = "10-1"
    turn_id_to_search = 6
    utterance, context = extract_context(topics, number_to_search, turn_id_to_search)
    rewrite = rewrite_query(context, rewriter, rewriter_tokenizer, device)

    # Passage Retrieval and Reranking
    reranked_passages = retrieve_and_rerank(rewrite, pipeline)

    # Response generation
    summarizer = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-summarize-news").to(device).eval()
    summarizer_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
    # We use the top-3 reranked passages to generate a response
    passages = [passage["passage_text"] for passage in reranked_passages][:3]
    print(json.dumps(passages, indent=4))
    responses = generate_response(passages, summarizer, summarizer_tokenizer)
    print("Done")