# Guardian-AI Hugging Face Space — app.py
# (uploaded by princemaxp, commit fbe73ed, verified)
# app.py
import os
import string

import gradio as gr
import pandas as pd
import requests
from datasets import load_dataset, Dataset, concatenate_datasets
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
# =========================
# CONFIG
# =========================
MAIN_DATASET = "princemaxp/guardian-ai-qna"
KEYWORDS_DATASET = "princemaxp/cybersecurity-keywords"
# Tokens from secrets
HF_MODEL_TOKEN = os.environ.get("HF_TOKEN")
MAIN_DATASET_TOKEN = os.environ.get("DATASET_HF_TOKEN")
TRENDYOL_TOKEN = os.environ.get("DATASET_TRENDYOL_TOKEN")
ROW_TOKEN = os.environ.get("DATASET_ROW_CYBERQA_TOKEN")
SHAREGPT_TOKEN = os.environ.get("DATASET_SHAREGPT_TOKEN")
RENDER_API_URL = os.environ.get("RENDER_API_URL")
# =========================
# LOAD DATASETS
# =========================
# Main dataset (writable)
try:
main_ds = load_dataset(MAIN_DATASET, split="train", use_auth_token=MAIN_DATASET_TOKEN)
except Exception:
main_ds = Dataset.from_dict({"question": [], "answer": []})
# External datasets (read-only)
external_datasets = [
("trendyol/cybersecurity-defense-v2", TRENDYOL_TOKEN),
("Rowden/CybersecurityQAA", ROW_TOKEN),
("Nitral-AI/Cybersecurity-ShareGPT", SHAREGPT_TOKEN),
]
ext_ds_list = []
for name, token in external_datasets:
try:
ds = load_dataset(name, split="train", use_auth_token=token)
ext_ds_list.append(ds)
except Exception as e:
print(f"⚠ Could not load {name}: {e}")
# Keyword dataset (CSV)
try:
kw_ds = load_dataset(KEYWORDS_DATASET, split="train", use_auth_token=MAIN_DATASET_TOKEN)
keywords = set(kw_ds["keyword"])
except Exception as e:
print(f"⚠ Could not load keywords dataset: {e}")
keywords = set()
# =========================
# MODELS
# =========================
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
chat_model = pipeline("text-generation", model="gpt2", token=HF_MODEL_TOKEN)
# Precompute embeddings
def compute_embeddings(dataset):
if len(dataset) == 0:
return None
return embedder.encode(dataset["question"], convert_to_tensor=True)
main_embs = compute_embeddings(main_ds)
ext_embs = [compute_embeddings(ds) for ds in ext_ds_list]
# =========================
# HELPERS
# =========================
def is_cybersecurity_question(text: str) -> bool:
words = text.lower().split()
return any(kw.lower() in words for kw in keywords)
def search_dataset(user_query, dataset, dataset_embeddings, threshold=0.7):
if dataset_embeddings is None or len(dataset) == 0:
return None
query_emb = embedder.encode(user_query, convert_to_tensor=True)
cos_sim = util.cos_sim(query_emb, dataset_embeddings)[0]
best_idx = cos_sim.argmax().item()
best_score = cos_sim[best_idx].item()
if best_score >= threshold:
return dataset[best_idx]["answer"]
return None
def call_render(question):
try:
resp = requests.post(RENDER_API_URL, json={"question": question}, timeout=15)
if resp.status_code == 200:
return resp.json().get("answer", "No answer from Render.")
return f"Render error: {resp.status_code}"
except Exception as e:
return f"Render request failed: {e}"
def save_to_main(question, answer):
global main_ds, main_embs
new_row = {"question": [question], "answer": [answer]}
new_ds = Dataset.from_dict(new_row)
main_ds = concatenate_datasets([main_ds, new_ds])
main_embs = compute_embeddings(main_ds)
main_ds.push_to_hub(MAIN_DATASET, token=MAIN_DATASET_TOKEN)
# =========================
# ANSWER FLOW
# =========================
def get_answer(user_query):
if is_cybersecurity_question(user_query):
# 1. Check in main dataset
ans = search_dataset(user_query, main_ds, main_embs)
if ans:
return ans
# 2. Check external datasets
for ds, emb in zip(ext_ds_list, ext_embs):
ans = search_dataset(user_query, ds, emb)
if ans:
save_to_main(user_query, ans)
return ans
# 3. Fallback: Render
ans = call_render(user_query)
save_to_main(user_query, ans)
return ans
else:
# General Q → use chat model
gen = chat_model(user_query, max_length=100, num_return_sequences=1)
return gen[0]["generated_text"]
# =========================
# GRADIO UI
# =========================
def chatbot(user_input):
return get_answer(user_input)
iface = gr.Interface(
fn=chatbot,
inputs=gr.Textbox(lines=2, placeholder="Ask me anything..."),
outputs="text",
title="Guardian AI Chatbot",
description="Cybersecurity-focused chatbot with general fallback"
)
if __name__ == "__main__":
iface.launch(server_name="0.0.0.0", server_port=7860)