# CelebChat / celebbot.py
# Author: lhzstar - commit 6bc94ac ("initial commits")
import datetime
import numpy as np
import torch
import torch.nn.functional as F
import os
import json
import speech_recognition as sr
import re
import time
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import pickle
import streamlit as st
from sklearn.metrics.pairwise import cosine_similarity
import run_tts
# Build the AI
class CelebBot:
    """Conversational agent that answers questions in the persona of a celebrity.

    The latest utterance lives in ``self.text``:

    * ``speech_to_text`` fills it with the user's transcribed speech.
    * ``question_answer`` replaces it with the bot's reply.
    * ``text_to_speech`` voices whatever it currently holds.

    Replies are produced by a seq2seq QA model whose prompt includes the
    knowledge sentences most similar to the question. Similarity is measured
    with a sentence-embedding model.
    """

    def __init__(self, name, QA_tokenizer, QA_model, sentTr_tokenizer, sentTr_model, spacy_model, knowledge_sents):
        """Store the persona name, the models and the knowledge base.

        Args:
            name: Celebrity name. It is used in the prompts and the wake phrase
                ("hey <name>"). It is also forwarded to ``run_tts.tts`` with
                spaces replaced by underscores.
            QA_tokenizer: Tokenizer for ``QA_model``.
            QA_model: Seq2seq model whose ``generate`` produces the answers.
            sentTr_tokenizer: Tokenizer for ``sentTr_model``.
            sentTr_model: Encoder used to embed questions and knowledge sentences.
            spacy_model: spaCy pipeline, forwarded to ``run_tts.tts``.
            knowledge_sents: Sequence of knowledge sentences (str) to retrieve from.
        """
        self.name = name
        print("--- starting up", self.name, "---")
        self.text = ""  # latest utterance: user input or bot reply
        self.QA_tokenizer = QA_tokenizer
        self.QA_model = QA_model
        self.sentTr_tokenizer = sentTr_tokenizer
        self.sentTr_model = sentTr_model
        self.spacy_model = spacy_model
        self.all_knowledge = knowledge_sents

    # --- Cached model loaders -------------------------------------------
    # Streamlit excludes parameters whose names start with "_" from the cache
    # key, so the cached functions must take a plain ``model_id``. Otherwise
    # different model ids would share one cache entry. They are static so
    # Streamlit never has to hash ``self``.

    @staticmethod
    @st.cache_resource
    def _load_seq2seq_model(model_id):
        # Load a seq2seq model once per model id.
        return AutoModelForSeq2SeqLM.from_pretrained(model_id)

    @staticmethod
    @st.cache_resource
    def _load_model(model_id):
        # Load a base (encoder) model once per model id.
        return AutoModel.from_pretrained(model_id)

    @staticmethod
    @st.cache_resource
    def _load_tokenizer(model_id):
        # Load a tokenizer once per model id.
        return AutoTokenizer.from_pretrained(model_id)

    def get_seq2seq_model(self, _model_id):
        """Return the seq2seq model for ``_model_id`` (cached with st.cache_resource)."""
        return self._load_seq2seq_model(_model_id)

    def get_model(self, _model_id):
        """Return the base model for ``_model_id`` (cached with st.cache_resource)."""
        return self._load_model(_model_id)

    def get_tokenizer(self, _model_id):
        """Return the tokenizer for ``_model_id`` (cached with st.cache_resource)."""
        return self._load_tokenizer(_model_id)

    # --- Speech I/O -----------------------------------------------------

    def speech_to_text(self):
        """Record one utterance from the default microphone into ``self.text``.

        The audio is transcribed with ``Recognizer.recognize_google``. If the
        speech is unintelligible or the recognition request fails,
        ``self.text`` is set to "" and a message is printed.
        """
        recognizer = sr.Recognizer()
        with sr.Microphone() as mic:
            # Calibrate the energy threshold on 1 s of background noise.
            recognizer.adjust_for_ambient_noise(mic, duration=1)
            time.sleep(1)
            print("listening")
            audio = recognizer.listen(mic)
        try:
            self.text = recognizer.recognize_google(audio)
        except sr.UnknownValueError:
            self.text = ""
            print("me --> No audio recognized")
        except sr.RequestError as err:
            # Network/API failure: not the same as "nothing was said".
            self.text = ""
            print(f"me --> Speech recognition service unavailable: {err}")

    def wake_up(self, text):
        """Return True if ``text`` contains the wake phrase "hey <name>" (case-insensitive)."""
        # Lower-case both sides; lower-casing only ``text`` could never match
        # a name containing capital letters.
        return ("hey " + self.name).lower() in text.lower()

    def text_to_speech(self, autoplay=True):
        """Synthesize ``self.text`` via ``run_tts.tts`` and return its result.

        The persona name is passed with spaces replaced by underscores.
        """
        return run_tts.tts(self.text, self.name.replace(" ", "_"), self.spacy_model, autoplay)

    # --- Retrieval ------------------------------------------------------

    def sentence_embeds_inference(self, texts: list):
        """Embed ``texts`` with the sentence-transformer model.

        Token embeddings are mean-pooled over non-padding positions, then
        L2-normalised, so a dot product between rows equals their cosine
        similarity.

        Returns:
            torch.Tensor with one row per entry of ``texts``.
        """
        def _mean_pooling(model_output, attention_mask):
            # First element of model_output holds the per-token embeddings.
            token_embeddings = model_output[0]
            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # The clamp avoids dividing by zero for a row with no real tokens.
            return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

        encoded_input = self.sentTr_tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.sentTr_model(**encoded_input)
        sentence_embeddings = _mean_pooling(model_output, encoded_input['attention_mask'])
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def retrieve_knowledge_assertions(self, top_k=8):
        """Return the knowledge sentences most relevant to ``self.text``, space-joined.

        The question is embedded as "<name>, <text>" and compared to every
        knowledge sentence. The ``top_k`` best matches (all of them if there
        are fewer) are returned in their original knowledge-base order.

        Args:
            top_k: Maximum number of sentences to return (default 8).

        Returns:
            The selected sentences joined by single spaces. Returns "" if the
            knowledge base is empty or ``top_k`` is not positive.
        """
        k = min(top_k, len(self.all_knowledge))
        if k <= 0:
            # With k == 0 the slice ``[-0:]`` below would select everything.
            return ''
        question_embeddings = self.sentence_embeds_inference([self.name + ', ' + self.text])
        # NOTE(review): the whole knowledge base is re-embedded on every call;
        # consider caching these embeddings if it is large.
        knowledge_embeddings = self.sentence_embeds_inference(self.all_knowledge)
        similarity = cosine_similarity(knowledge_embeddings.cpu(), question_embeddings.cpu()).ravel()
        # argpartition selects the k best indices without a full sort.
        # Sorting those indices restores knowledge-base order.
        best_idx = np.sort(np.argpartition(similarity, -k)[-k:])
        return ' '.join(np.array(self.all_knowledge)[best_idx])

    # --- Dialogue -------------------------------------------------------

    def question_answer(self, instruction1='', knowledge=''):
        """Replace ``self.text`` (the user's utterance) with the bot's reply and return it.

        * Empty input: nothing changes and "" is returned.
        * Wake phrase ("hey <name>"): a fixed greeting is returned.
        * Otherwise: the QA model is prompted with
          "<instruction> [knowledge] <knowledge> [question] <text> <name>!",
          and its decoded output becomes the reply.

        Args:
            instruction1: Prompt instruction. When empty, a persona instruction
                naming the celebrity is used.
            knowledge: Knowledge passage for the prompt. When empty, it is
                retrieved with ``retrieve_knowledge_assertions``.
        """
        if self.text == "":
            return self.text

        if self.wake_up(self.text):
            self.text = f"Hello I am {self.name} the AI, what can I do for you?"
            return self.text

        if not instruction1:
            instruction1 = f'[Instruction] You are a celebrity named {self.name}. You need to answer the question based on knowledge and commonsense.'
        if not knowledge:
            knowledge = self.retrieve_knowledge_assertions()

        query = f"{instruction1} [knowledge] {knowledge} [question] {self.text} {self.name}!"
        input_ids = self.QA_tokenizer(query, return_tensors="pt").input_ids
        outputs = self.QA_model.generate(input_ids, max_length=1024)
        self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.text

    # --- Utilities ------------------------------------------------------

    @staticmethod
    def action_time():
        """Return the current local time formatted as "it's HH:MM"."""
        return f"it's {datetime.datetime.now().strftime('%H:%M')}"

    @staticmethod
    def save_kb(kb, filename):
        """Pickle ``kb`` into ``filename``, overwriting any existing file."""
        with open(filename, "wb") as f:
            pickle.dump(kb, f)

    @staticmethod
    def load_kb(filename):
        """Unpickle and return the object stored in ``filename``.

        Security: ``pickle.load`` can execute arbitrary code. Only load files
        this application wrote itself (e.g. via ``save_kb``), never untrusted
        input.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)