# Scraped file-viewer header (not part of the program):
#   dwhitena's picture
#   Add Application file
#   f14e709
#   raw history blame
#   No virus
#   2.49 kB
from statistics import mean
import random
import torch
from transformers import BertModel, BertTokenizerFast
import numpy as np
import torch.nn.functional as F
import gradio as gr
# Minimum mean similarity a message must reach against an intent's example
# phrases before that intent is selected.
threshold = 0.4

# Load the LaBSE tokenizer and encoder once, at import time. The encoder is
# put into evaluation mode because it is only used for inference here.
tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
model = BertModel.from_pretrained("setu4993/LaBSE").eval()
# Example utterances for the "order food" intent. Each phrase is embedded
# separately and compared with the incoming message.
order_food_ex = [
    "food",
    "I am hungry, I want to order food",
    "How do I order food",
    "What are the food options",
    "I need dinner",
    "I want lunch",
    "What are the menu options",
    "I want a hamburger",
]
# Example utterances for the "talk to a human" intent. Each phrase is
# embedded separately and compared with the incoming message.
talk_to_human_ex = [
    "I need to talk to someone",
    "Connect me with a human",
    "I need to speak with a person",
    "Put me on with a human",
    "Connect me with customer service",
    "human",
]
def embed(text, tokenizer, model):
    """Return the pooled embedding that ``model`` produces for ``text``.

    Args:
        text: The text to encode.
        tokenizer: Callable that converts ``text`` into a mapping of model
            inputs. It is invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``.
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object exposing ``pooler_output``.

    Returns:
        The ``pooler_output`` attribute of the model's output, computed with
        gradient tracking disabled.
    """
    # truncation=True asks the tokenizer to clip over-long input to its
    # configured maximum length, so an arbitrarily long chat message is not
    # handed to the encoder unclipped.
    encoded = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**encoded)
    return outputs.pooler_output
def similarity(embeddings_1, embeddings_2):
    """Return the matrix of cosine similarities between two embedding sets.

    Each input is L2-normalised with ``F.normalize`` (its default ``dim=1``),
    then the first is multiplied by the transpose of the second. Entry
    ``[i, j]`` of the result is therefore the dot product of unit-length row
    ``i`` of ``embeddings_1`` with unit-length row ``j`` of ``embeddings_2``.

    Assumes both inputs are 2-D ``(rows, features)`` tensors, since the
    transpose swaps dimensions 0 and 1 -- confirm against callers.
    """
    unit_rows_1 = F.normalize(embeddings_1, p=2)
    unit_rows_2 = F.normalize(embeddings_2, p=2)
    return unit_rows_1 @ unit_rows_2.transpose(0, 1)
def _embed_examples(examples):
    """Embed every phrase in ``examples`` separately; return the list."""
    return [embed(phrase, tokenizer, model) for phrase in examples]


# Pre-compute the example embeddings once at import time so that each chat
# turn only has to embed the incoming message.
order_food_embed = _embed_examples(order_food_ex)
talk_to_human_embed = _embed_examples(talk_to_human_ex)
def chat(message, history):
    """Choose a canned reply for ``message`` and record the turn.

    The message is embedded and compared with every pre-computed example
    embedding of each intent. The "order food" intent is checked first: if
    the mean similarity to its examples exceeds ``threshold`` a food reply is
    chosen at random. Otherwise the "talk to a human" intent is checked the
    same way. If neither intent passes, a fallback reply is used.

    Args:
        message: The user's latest message.
        history: List of ``(message, response)`` tuples from earlier turns,
            or a falsy value (e.g. ``None``) on the first turn. A non-empty
            list is appended to in place.

    Returns:
        ``(history, history)`` -- the updated history twice, once for each
        of the two outputs declared on the interface.
    """
    history = history or []
    message_embed = embed(message, tokenizer, model)

    # One similarity score per example phrase; each comparison yields a
    # single-element tensor that is converted to a plain float.
    order_sim = [
        float(similarity(em, message_embed)) for em in order_food_embed
    ]
    human_sim = [
        float(similarity(em, message_embed)) for em in talk_to_human_embed
    ]

    if mean(order_sim) > threshold:
        response = random.choice([
            "We have hamburgers or pizza! Which one do you want?",
            "Do you want a hamburger or a pizza?",
        ])
    elif mean(human_sim) > threshold:
        response = random.choice([
            "Sure, a customer service agent will jump into this convo shortly!",
            "No problem. Let me forward on this conversation to a person that can respond.",
        ])
    else:
        # Fixed typo in the user-facing fallback ("Could your rephrase?").
        response = "Sorry, I didn't catch that. Could you rephrase?"

    history.append((message, response))
    return history, history
# Build the web UI around ``chat``: a text box plus a state slot as inputs,
# and a chatbot widget plus the same state slot as outputs. Screenshots and
# flagging are turned off.
iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    allow_screenshot=False,
    allow_flagging="never",
)

# Start serving the interface.
iface.launch()