# NOTE: non-code artifact lines ("Spaces:", "Runtime error") from a pasted
# Hugging Face Spaces log were removed here so the module parses.
from collections import deque
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
from utils import generate_response
import pandas as pd
import pickle
from utils import encode_rag, cosine_sim_rag, top_candidates
class ChatBot:
    """Retrieval-augmented chatbot.

    Candidate scripted answers are ranked with a sentence-transformer
    encoder; the final reply is produced by a FLAN-T5 seq2seq model,
    optionally conditioned on the best-matching scripted answer (RAG).
    Call :meth:`load` once before :meth:`generate_response`.
    """

    # Minimum cosine similarity for trusting a retrieved scripted answer.
    RAG_THRESHOLD = 0.89

    def __init__(self):
        # Rolling window of the last 10 turns (utterances + answers),
        # used as conversational context for ranking and generation.
        self.conversation_history = deque([], maxlen=10)
        self.generative_model = None      # set by load()
        self.generative_tokenizer = None  # set by load()
        self.vect_data = []               # precomputed script embeddings (pickle)
        self.scripts = []                 # DataFrame with an "answer" column after load()
        self.ranking_model = None         # SentenceTransformer, set by load()

    def load(self):
        """Load all datasets and models used by the chat bot.

        Data files are read from the local ``data`` folder; models are
        downloaded from the Hugging Face hub. Must be called before
        :meth:`generate_response`.
        """
        with open("data/scripts_vectors.pkl", "rb") as fp:
            self.vect_data = pickle.load(fp)
        self.scripts = pd.read_pickle("data/scripts.pkl")
        self.ranking_model = SentenceTransformer(
            "Shakhovak/chatbot_sentence-transformer"
        )
        self.generative_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Shakhovak/flan-t5-base-sheldon-chat-v2"
        )
        self.generative_tokenizer = AutoTokenizer.from_pretrained(
            "Shakhovak/flan-t5-base-sheldon-chat-v2"
        )

    def generate_response(self, utterance):
        """Generate a reply to *utterance* and update the conversation history.

        NOTE: this method name shadows ``utils.generate_response`` at class
        level; the bare ``generate_response(...)`` calls below still resolve
        to the imported module-level function, not this method.

        :param utterance: the user's input text
        :return: the bot's answer string
        """
        query_encoding = encode_rag(
            texts=utterance,
            model=self.ranking_model,
            contexts=self.conversation_history,
        )
        bot_cosine_scores = cosine_sim_rag(
            self.vect_data,
            query_encoding,
        )
        top_scores, top_indexes = top_candidates(
            bot_cosine_scores, initial_data=self.scripts
        )
        if top_scores[0] >= self.RAG_THRESHOLD:
            # Fix: the original looped over every candidate, regenerating the
            # answer each time and keeping only the LAST iteration's result —
            # identical output, wasted generation calls. Use the last index
            # directly to preserve that behavior exactly.
            # TODO(review): top_indexes[0] (the best candidate) was likely the
            # intent — confirm ordering of top_candidates() before changing.
            rag_answer = self.scripts.iloc[top_indexes[-1]]["answer"]
            answer = generate_response(
                model=self.generative_model,
                tokenizer=self.generative_tokenizer,
                question=utterance,
                context=self.conversation_history,
                top_p=0.9,
                temperature=0.95,
                rag_answer=rag_answer,
            )
        else:
            # Below the similarity threshold: generate without RAG grounding.
            answer = generate_response(
                model=self.generative_model,
                tokenizer=self.generative_tokenizer,
                question=utterance,
                context=self.conversation_history,
                top_p=0.9,
                temperature=0.95,
            )
        # deque(maxlen=10) silently drops the oldest turns.
        self.conversation_history.append(utterance)
        self.conversation_history.append(answer)
        return answer
# Example usage:
# katya = ChatBot()
# katya.load()
# print(katya.generate_response("What is he doing there?"))