from statistics import mean
import random

import torch
from transformers import BertModel, BertTokenizerFast
import numpy as np
import torch.nn.functional as F
import gradio as gr

# Minimum mean cosine similarity between a message and an intent's example
# phrases for that intent to be treated as a match.
threshold = 0.4

# LaBSE sentence-embedding model. eval() puts the model in inference mode
# (e.g. dropout disabled) so the same text always yields the same embedding.
tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
model = BertModel.from_pretrained("setu4993/LaBSE")
model = model.eval()

# Example phrasings for each supported intent; an incoming message is
# compared against every phrase of an intent.
order_food_ex = [
    "food",
    "I am hungry, I want to order food",
    "How do I order food",
    "What are the food options",
    "I need dinner",
    "I want lunch",
    "What are the menu options",
    "I want a hamburger",
]
talk_to_human_ex = [
    "I need to talk to someone",
    "Connect me with a human",
    "I need to speak with a person",
    "Put me on with a human",
    "Connect me with customer service",
    "human",
]

# Canned replies, one picked at random when the matching intent fires.
order_food_responses = [
    "We have hamburgers or pizza! Which one do you want?",
    "Do you want a hamburger or a pizza?",
]
talk_to_human_responses = [
    "Sure, a customer service agent will jump into this convo shortly!",
    "No problem. Let me forward on this conversation to a person that can respond.",
]


def embed(text, tokenizer, model):
    """Encode ``text`` with ``model`` and return its pooled embedding.

    Args:
        text: A string (or list of strings) to encode.
        tokenizer: Tokenizer matching ``model``.
        model: A ``BertModel``-compatible model exposing ``pooler_output``.

    Returns:
        The model's ``pooler_output`` tensor, one row per input text.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.pooler_output


def similarity(embeddings_1, embeddings_2):
    """Return the cosine-similarity matrix between two sets of row vectors.

    Each row of both inputs is L2-normalised along dim 1, so entry ``[i, j]``
    of the result is the cosine similarity of row ``i`` of ``embeddings_1``
    and row ``j`` of ``embeddings_2``.
    """
    normalized_embeddings_1 = F.normalize(embeddings_1, p=2)
    normalized_embeddings_2 = F.normalize(embeddings_2, p=2)
    return torch.matmul(
        normalized_embeddings_1, normalized_embeddings_2.transpose(0, 1)
    )


# Embed every example phrase once at startup (each phrase on its own, so the
# vectors are exactly the per-phrase embeddings).
order_food_embed = [embed(x, tokenizer, model) for x in order_food_ex]
talk_to_human_embed = [embed(x, tokenizer, model) for x in talk_to_human_ex]

# Stack each intent's example embeddings into a single matrix so a message can
# be scored against all of them with one matmul instead of a Python loop.
_order_food_matrix = torch.cat(order_food_embed)
_talk_to_human_matrix = torch.cat(talk_to_human_embed)


def _mean_similarity(example_matrix, message_embed):
    """Return the mean cosine similarity of ``message_embed`` to each example row."""
    # (N, hidden) x (1, hidden)^T -> (N, 1); flatten to N per-example scores.
    scores = similarity(example_matrix, message_embed).squeeze(1)
    return mean(scores.tolist())


def chat(message, history):
    """Handle one chat turn: classify ``message`` by intent and reply.

    The message is compared with each intent's example phrases. The first
    intent (food ordering, then human hand-off) whose mean cosine similarity
    exceeds ``threshold`` selects the reply; otherwise a fallback asks the
    user to rephrase.

    Args:
        message: The user's text input.
        history: List of ``(message, response)`` pairs from earlier turns, or
            a falsy value (e.g. ``None``) on the first turn.

    Returns:
        ``(history, history)``: the updated history, once for the chatbot
        output and once to be stored back as Gradio state.
    """
    history = history or []
    message_embed = embed(message, tokenizer, model)

    # Food ordering is checked first, so it wins if both intents match.
    if _mean_similarity(_order_food_matrix, message_embed) > threshold:
        response = random.choice(order_food_responses)
    elif _mean_similarity(_talk_to_human_matrix, message_embed) > threshold:
        response = random.choice(talk_to_human_responses)
    else:
        response = "Sorry, I didn't catch that. Could you rephrase?"

    history.append((message, response))
    return history, history


# NOTE(review): `allow_screenshot` / `allow_flagging` and the "state"/"chatbot"
# string components follow an older Gradio API; confirm they are accepted by
# the gradio version this app is pinned to.
iface = gr.Interface(
    chat,
    ["text", "state"],
    ["chatbot", "state"],
    allow_screenshot=False,
    allow_flagging="never",
)
iface.launch()