# answer-machine / app.py
# Last updated by johnnyfivefingers, commit 739e3b5 ("Update app.py")
import functools

import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline
# Sentence-embedding model that scores passages by similarity to a question.
# NOTE(review): this checkpoint is reportedly marked deprecated upstream
# (low-quality embeddings). Confirm this, and consider a current
# sentence-transformers model; swapping it will change the passage ranking.
_EMBEDDER_NAME = 'distilbert-base-nli-stsb-mean-tokens'
# Extractive QA reader that pulls an answer span out of a single passage.
_READER_NAME = 'distilbert-base-cased-distilled-squad'

model = SentenceTransformer(_EMBEDDER_NAME)
qa_pipeline = pipeline(task='question-answering', model=_READER_NAME)

# Built-in knowledge base searched by the app: one short security passage per
# entry.
documents = [
    "Data breaches are a common occurrence in today's digital world. Companies must take measures to protect sensitive information from unauthorized access or disclosure.",
    "Phishing attacks are a type of cyberattack that use social engineering to trick users into divulging confidential information. Employees should be trained to recognize and avoid phishing scams.",
    "The use of encryption can help prevent unauthorized access to data by encrypting it so that it can only be accessed by authorized users who have the decryption key.",
    "A firewall is a network security system that monitors and controls incoming and outgoing network traffic based on predetermined security rules. It can help protect against unauthorized access to a network.",
    "Access control is the process of limiting access to resources to only authorized users. It is an important aspect of information security and can be achieved through the use of authentication and authorization mechanisms.",
]
# Sample security question. Not referenced by the code shown here:
# answer_question receives its own ``question`` argument.
question = "What measures can companies take to protect sensitive information from unauthorized access or disclosure?"
# Define the function to process the inputs and return the answer
def answer_question(document, question):
# Generate embeddings for the documents
doc_embeddings = model.encode(documents)
# Generate the query embedding
query_embedding = model.encode([question])
# Compute the cosine similarity between the query and the document embeddings
similarity_scores = cosine_similarity(query_embedding, doc_embeddings)
# Find the top 3 most similar documents
most_similar_idxs = similarity_scores.argsort()[0][-3:]
# Extract the answers from the top 3 most similar documents
answers = []
for idx in most_similar_idxs:
answer = qa_pipeline({'context': documents[idx], 'question': question})
answers.append(answer['answer'])
# Return the answers
return answers
def _answer_as_text(document, question):
    """Gradio callback: run ``answer_question`` and render its answers as text.

    ``answer_question`` returns a list. A Textbox output stringifies whatever
    it receives, so a raw list would show up as ``['...', '...']``. Instead,
    show one answer per line.
    """
    return "\n".join(answer_question(document, question))


# Input/output components for the Gradio app. The ``gr.inputs`` / ``gr.outputs``
# namespaces were deprecated in Gradio 3 and removed in Gradio 4; top-level
# components such as ``gr.Textbox`` are the supported spelling.
inputs = [
    gr.Textbox(label="Document"),
    gr.Textbox(label="Question"),
]
outputs = gr.Textbox(label="Answer")
# Create the Gradio app. Each example row is [document, question], matching
# the order of ``inputs``.
app = gr.Interface(
    fn=_answer_as_text,
    inputs=inputs,
    outputs=outputs,
    title="Security Question Answering",
    description="Enter a security question and a document, and the app will return the answer based on the top 3 most similar documents.",
    examples=[
        ["Data breaches are a common occurrence in today's digital world. Companies must take measures to protect sensitive information from unauthorized access or disclosure.", "What measures can companies take to protect sensitive information?"],
        ["Phishing attacks are a type of cyberattack that use social engineering to trick users into divulging confidential information. Employees should be trained to recognize and avoid phishing scams.", "What is a phishing attack?"],
        ["The use of encryption can help prevent unauthorized access to data by encrypting it so that it can only be accessed by authorized users who have the decryption key.", "What is encryption?"],
        ["A firewall is a network security system that monitors and controls incoming and outgoing network traffic based on predetermined security rules. It can help protect against unauthorized access to a network.", "What is a firewall?"],
    ],
)
# Run the app (kept at module level: the Space executes this file directly).
app.launch()