CelebChat / celebbot.py
lhzstar
new commits
56e0560
raw
history blame
No virus
6.49 kB
import datetime
import numpy as np
import torch
import torch.nn.functional as F
import os
import json
import speech_recognition as sr
import re
import time
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import pickle
import streamlit as st
from sklearn.metrics.pairwise import cosine_similarity
import run_tts
# Build the AI
class CelebBot:
    """Conversational agent that answers questions in the persona of a celebrity.

    The most recent utterance lives in ``self.text``: ``speech_to_text`` writes
    the user's words there, ``question_answer`` overwrites it with the generated
    reply, and ``text_to_speech`` voices whatever it currently holds.
    """

    # Safety preamble shared by both prompt variants in ``question_answer``.
    _SAFETY_PROMPT = (
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, "
        "toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature. "
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    )

    def __init__(self, name, QA_tokenizer, QA_model, sentTr_tokenizer, sentTr_model, spacy_model, knowledge_sents):
        """Store the persona name, the models, and the knowledge sentences.

        Args:
            name: celebrity name; used for the wake phrase, prompts and retrieval.
            QA_tokenizer / QA_model: tokenizer and seq2seq model used to generate answers.
            sentTr_tokenizer / sentTr_model: tokenizer and encoder used for sentence embeddings.
            spacy_model: passed through unchanged to ``run_tts.tts``.
            knowledge_sents: sequence of sentences searched by
                ``retrieve_knowledge_assertions``.
        """
        self.name = name
        print("--- starting up", self.name, "---")
        self.text = ""
        self.QA_tokenizer = QA_tokenizer
        self.QA_model = QA_model
        self.sentTr_tokenizer = sentTr_tokenizer
        self.sentTr_model = sentTr_model
        self.spacy_model = spacy_model
        self.all_knowledge = knowledge_sents
        # (snapshot of all_knowledge, its embeddings); filled lazily by _knowledge_embeddings.
        self._kb_cache = None

    def speech_to_text(self):
        """Record one utterance from the default microphone into ``self.text``.

        If the audio cannot be transcribed, or the recognition service fails,
        ``self.text`` is set to "" so ``question_answer`` treats it as silence.
        """
        recognizer = sr.Recognizer()
        with sr.Microphone() as mic:
            # Calibrate the energy threshold against one second of background noise.
            recognizer.adjust_for_ambient_noise(mic, duration=1)
            time.sleep(1)
            print("listening")
            audio = recognizer.listen(mic)
        # Transcribe after the microphone is released; this is a network call.
        try:
            self.text = recognizer.recognize_google(audio)
        except sr.UnknownValueError:
            # Audio was captured but could not be understood.
            self.text = ""
            print("me --> No audio recognized")
        except sr.RequestError as err:
            # The recognition service was unreachable or rejected the request.
            self.text = ""
            print(f"me --> Speech recognition service unavailable: {err}")

    def wake_up(self, text):
        """Return True if *text* contains the phrase "hey <name>", ignoring case on both sides."""
        return f"hey {self.name}".lower() in text.lower()

    def text_to_speech(self, autoplay=True):
        """Voice ``self.text`` via ``run_tts.tts`` and return whatever it returns.

        The name is passed with spaces replaced by underscores
        (e.g. "Taylor Swift" -> "Taylor_Swift"). Presumably this selects the
        voice; confirm in run_tts.
        """
        return run_tts.tts(self.text, self.name.replace(" ", "_"), self.spacy_model, autoplay)

    def sentence_embeds_inference(self, texts: list):
        """Embed *texts* with the sentence encoder and return L2-normalised rows.

        Token embeddings are averaged over positions where the attention mask
        is 1. Each pooled row is then scaled to unit length, so the dot product
        of two rows equals their cosine similarity.
        """
        def _mean_pooling(model_output, attention_mask):
            # First element of model_output contains all token embeddings.
            token_embeddings = model_output[0]
            # Broadcast the mask over the trailing (presumably hidden) dimension.
            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # Sum over dim 1 (the token axis); the clamp guards an all-padding row.
            return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

        encoded_input = self.sentTr_tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.sentTr_model(**encoded_input)
        sentence_embeddings = _mean_pooling(model_output, encoded_input['attention_mask'])
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def _knowledge_embeddings(self):
        """Return embeddings of ``self.all_knowledge``, recomputing only when its contents change."""
        snapshot = tuple(self.all_knowledge)
        # getattr keeps instances built without __init__ working.
        cache = getattr(self, "_kb_cache", None)
        if cache is None or cache[0] != snapshot:
            cache = (snapshot, self.sentence_embeds_inference(list(snapshot)))
            self._kb_cache = cache
        return cache[1]

    def retrieve_knowledge_assertions(self, top_k=8):
        """Return the *top_k* knowledge sentences most similar to the current question.

        The query is embedded as "<name>, <text>" so the persona name steers
        retrieval toward facts about the celebrity. The chosen sentences keep
        their original knowledge-base order.

        Args:
            top_k: maximum number of sentences to return (default 8).

        Returns:
            The selected sentences joined by single spaces, or "" when the
            knowledge base is empty or ``top_k`` is not positive.
        """
        # len() rather than truthiness: all_knowledge may be a numpy array.
        if len(self.all_knowledge) == 0 or top_k <= 0:
            return ""
        question_embeddings = self.sentence_embeds_inference([self.name + ', ' + self.text])
        knowledge_embeddings = self._knowledge_embeddings()
        similarity = cosine_similarity(knowledge_embeddings.cpu(), question_embeddings.cpu()).ravel()
        k = min(top_k, len(similarity))
        # argpartition selects the k best without a full sort; sorting the
        # indices restores knowledge-base order.
        top_idx = np.sort(np.argpartition(similarity, -k)[-k:])
        return ' '.join(np.array(self.all_knowledge)[top_idx])

    def question_answer(self, instruction1='', knowledge=''):
        """Replace ``self.text`` with the bot's reply to it and return that reply.

        - Empty input: returned unchanged.
        - Wake phrase ("hey <name>"): a fixed greeting.
        - Text that addresses the bot ("you", "your", or its name): the prompt
          adopts the celebrity persona and includes retrieved knowledge.
        - Anything else: the prompt uses only the generic safety instruction
          plus the caller-supplied *knowledge*.

        Note: *instruction1* is always overwritten by a built-in prompt; the
        parameter is kept for backward compatibility only.
        """
        if not self.text:
            return self.text

        if self.wake_up(self.text):
            self.text = f"Hello I am {self.name} the AI, what can I do for you?"
            return self.text

        # Escape the name so characters like "$" or "." are matched literally.
        addressed = re.search(
            rf'\b(you|your|{re.escape(self.name)})\b', self.text, flags=re.IGNORECASE
        )
        if addressed is not None:
            instruction1 = f"You are a celebrity named {self.name}. {self._SAFETY_PROMPT}"
            knowledge = self.retrieve_knowledge_assertions()
        else:
            instruction1 = self._SAFETY_PROMPT

        query = f"Context: {instruction1} {knowledge}\n\nQuestion: {self.text}\n\nAnswer:"
        input_ids = self.QA_tokenizer(query, return_tensors="pt").input_ids
        outputs = self.QA_model.generate(input_ids, max_length=1024)
        self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.text

    @staticmethod
    def action_time():
        """Return the current local wall-clock time formatted as "it's HH:MM"."""
        return f"it's {datetime.datetime.now():%H:%M}"

    @staticmethod
    def save_kb(kb, filename):
        """Pickle *kb* to *filename*, overwriting any existing file."""
        with open(filename, "wb") as f:
            pickle.dump(kb, f)

    @staticmethod
    def load_kb(filename):
        """Unpickle and return the object stored in *filename*.

        SECURITY: pickle.load can execute arbitrary code. Only load files this
        application wrote itself, never user-supplied uploads.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)