# Spaces: Runtime error
import functools

import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline
# Sentence-embedding model used to rank passages against the question.
model = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')

# Extractive question-answering pipeline used to pull an answer span out of
# each selected passage.
qa_pipeline = pipeline('question-answering', model='distilbert-base-cased-distilled-squad')

# Built-in corpus of security-related passages that questions are matched
# against.
documents = [
    "Data breaches are a common occurrence in today's digital world. Companies must take measures to protect sensitive information from unauthorized access or disclosure.",
    "Phishing attacks are a type of cyberattack that use social engineering to trick users into divulging confidential information. Employees should be trained to recognize and avoid phishing scams.",
    "The use of encryption can help prevent unauthorized access to data by encrypting it so that it can only be accessed by authorized users who have the decryption key.",
    "A firewall is a network security system that monitors and controls incoming and outgoing network traffic based on predetermined security rules. It can help protect against unauthorized access to a network.",
    "Access control is the process of limiting access to resources to only authorized users. It is an important aspect of information security and can be achieved through the use of authentication and authorization mechanisms.",
]

# Example security-related question.
question = "What measures can companies take to protect sensitive information from unauthorized access or disclosure?"
# Number of best-matching passages handed to the extractive QA model.
_TOP_K = 3


@functools.lru_cache(maxsize=8)
def _encode_corpus(corpus):
    """Return the sentence embeddings for ``corpus`` (a tuple of passages).

    The result is cached per distinct corpus so the fixed built-in passages
    are embedded once instead of on every request.
    """
    return model.encode(list(corpus))


def answer_question(document, question):
    """Answer ``question`` from the passages most similar to it.

    The built-in ``documents`` corpus is ranked by cosine similarity between
    each passage embedding and the question embedding.  When ``document`` is
    non-blank and not already part of the corpus it is ranked alongside the
    built-in passages, so user-supplied text can contribute an answer.

    Args:
        document: Optional extra passage supplied by the user. ``None`` or a
            blank string means "use only the built-in corpus".
        question: The question to answer.

    Returns:
        A list of answer strings, one per selected passage (at most
        ``_TOP_K``), ordered from the most to the least similar passage.
        An empty list is returned when ``question`` is ``None`` or blank.
    """
    # Guard first: nothing useful can be ranked or extracted for a blank
    # question, so avoid invoking either model at all.
    if not question or not question.strip():
        return []

    query_embedding = model.encode([question])

    # Score the fixed corpus using its cached embeddings.
    candidates = list(documents)
    scores = list(
        cosine_similarity(query_embedding, _encode_corpus(tuple(candidates)))[0]
    )

    # Rank the user's passage too, unless it merely repeats a corpus entry
    # (which would otherwise show up twice among the top matches).
    extra = document.strip() if document else ""
    if extra and extra not in candidates:
        candidates.append(extra)
        scores.extend(cosine_similarity(query_embedding, model.encode([extra]))[0])

    # Highest similarity first, so the most relevant answer leads the output.
    ranked = sorted(range(len(candidates)), key=scores.__getitem__, reverse=True)

    return [
        qa_pipeline(question=question, context=candidates[idx])['answer']
        for idx in ranked[:_TOP_K]
    ]
# Input and output components for the Gradio app.  The top-level
# ``gr.Textbox`` component is used for both: the legacy ``gr.inputs`` /
# ``gr.outputs`` namespaces are not available in current Gradio releases.
inputs = [
    gr.Textbox(label="Document"),
    gr.Textbox(label="Question"),
]
outputs = gr.Textbox(label="Answer")

# Build the Gradio interface around ``answer_question``; each example row is
# a [document, question] pair matching the order of ``inputs``.
app = gr.Interface(
    fn=answer_question,
    inputs=inputs,
    outputs=outputs,
    title="Security Question Answering",
    description="Enter a security question and a document, and the app will return the answer based on the top 3 most similar documents.",
    examples=[
        ["Data breaches are a common occurrence in today's digital world. Companies must take measures to protect sensitive information from unauthorized access or disclosure.", "What measures can companies take to protect sensitive information?"],
        ["Phishing attacks are a type of cyberattack that use social engineering to trick users into divulging confidential information. Employees should be trained to recognize and avoid phishing scams.", "What is a phishing attack?"],
        ["The use of encryption can help prevent unauthorized access to data by encrypting it so that it can only be accessed by authorized users who have the decryption key.", "What is encryption?"],
        ["A firewall is a network security system that monitors and controls incoming and outgoing network traffic based on predetermined security rules. It can help protect against unauthorized access to a network.", "What is a firewall?"],
    ],
)

# Start the web server for the app.
app.launch()