# Guardian AI Q&A chatbot — Hugging Face Space (app.py)
# app.py
import os

import requests
import pandas as pd
from datasets import load_dataset, Dataset, concatenate_datasets
from sentence_transformers import SentenceTransformer, util
import gradio as gr
from transformers import pipeline
# =========================
# CONFIG
# =========================
MAIN_DATASET = "princemaxp/guardian-ai-qna"
KEYWORDS_DATASET = "princemaxp/cybersecurity-keywords"

# Access tokens and endpoints, injected via Space secrets (None when unset).
_env = os.environ.get
HF_MODEL_TOKEN = _env("HF_TOKEN")
MAIN_DATASET_TOKEN = _env("DATASET_HF_TOKEN")
TRENDYOL_TOKEN = _env("DATASET_TRENDYOL_TOKEN")
ROW_TOKEN = _env("DATASET_ROW_CYBERQA_TOKEN")
SHAREGPT_TOKEN = _env("DATASET_SHAREGPT_TOKEN")
RENDER_API_URL = _env("RENDER_API_URL")
# =========================
# LOAD DATASETS
# =========================
# Main Q&A dataset (the writable store for learned answers).
try:
    # `token=` replaces the deprecated `use_auth_token=` parameter
    # (consistent with the `pipeline(..., token=...)` call below).
    main_ds = load_dataset(MAIN_DATASET, split="train", token=MAIN_DATASET_TOKEN)
except Exception:
    # Dataset missing or unreachable: start with an empty question/answer schema.
    main_ds = Dataset.from_dict({"question": [], "answer": []})
# External read-only datasets, searched after the main dataset misses.
external_datasets = [
    ("trendyol/cybersecurity-defense-v2", TRENDYOL_TOKEN),
    ("Rowden/CybersecurityQAA", ROW_TOKEN),
    ("Nitral-AI/Cybersecurity-ShareGPT", SHAREGPT_TOKEN),
]
ext_ds_list = []
for name, token in external_datasets:
    try:
        # `token=` replaces the deprecated `use_auth_token=` parameter.
        ds = load_dataset(name, split="train", token=token)
        ext_ds_list.append(ds)
    except Exception as e:
        # Best-effort: skip a source that fails to load instead of crashing startup.
        print(f"⚠ Could not load {name}: {e}")
# Keyword dataset (CSV): vocabulary used to classify questions as cybersecurity.
try:
    # `token=` replaces the deprecated `use_auth_token=` parameter.
    kw_ds = load_dataset(KEYWORDS_DATASET, split="train", token=MAIN_DATASET_TOKEN)
    keywords = set(kw_ds["keyword"])
except Exception as e:
    # Without keywords every question falls through to the general chat model.
    print(f"⚠ Could not load keywords dataset: {e}")
    keywords = set()
# =========================
# MODELS
# =========================
# Sentence encoder used for semantic Q&A similarity lookup.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Small generative model that handles general (non-cybersecurity) chat.
chat_model = pipeline("text-generation", model="gpt2", token=HF_MODEL_TOKEN)
# Precompute question embeddings so lookups only need to encode the query.
def compute_embeddings(dataset):
    """Encode a dataset's "question" column into a tensor of embeddings.

    Returns None for empty datasets or datasets without a "question" column,
    so similarity search skips them instead of crashing.
    """
    if len(dataset) == 0:
        return None
    # NOTE(review): some external sources (e.g. ShareGPT-style conversations)
    # may not expose a "question" column — guard to avoid a KeyError.
    # TODO confirm the external dataset schemas.
    columns = getattr(dataset, "column_names", None)
    if columns is not None and "question" not in columns:
        return None
    return embedder.encode(dataset["question"], convert_to_tensor=True)
# Cache embeddings for every searchable dataset at startup.
main_embs = compute_embeddings(main_ds)
ext_embs = [compute_embeddings(source) for source in ext_ds_list]
# =========================
# HELPERS
# =========================
def is_cybersecurity_question(text: str, keyword_set=None) -> bool:
    """Return True if *text* contains any known cybersecurity keyword.

    Matching is case-insensitive. Single-word keywords are matched against
    whole, punctuation-stripped words (so "phishing?" matches "phishing");
    multi-word keywords (e.g. "sql injection") are matched as substrings of
    the lowered text — the previous word-split check could never match them.

    keyword_set: optional iterable of keywords; defaults to the module-level
    `keywords` set loaded from the keywords dataset.
    """
    kws = keywords if keyword_set is None else keyword_set
    text_lower = text.lower()
    # Strip common punctuation so trailing "?", "." etc. don't break matches.
    words = {w.strip(".,!?;:\"'()[]") for w in text_lower.split()}
    for kw in kws:
        kw_l = kw.lower()
        if " " in kw_l:
            if kw_l in text_lower:
                return True
        elif kw_l in words:
            return True
    return False
def search_dataset(user_query, dataset, dataset_embeddings, threshold=0.7):
    """Return the stored answer whose question best matches *user_query*.

    Uses cosine similarity between the query embedding and the dataset's
    precomputed question embeddings; returns None when the dataset is empty,
    has no embeddings, or the best match scores below *threshold*.
    """
    if dataset_embeddings is None or len(dataset) == 0:
        return None
    query_vec = embedder.encode(user_query, convert_to_tensor=True)
    scores = util.cos_sim(query_vec, dataset_embeddings)[0]
    top = scores.argmax().item()
    if scores[top].item() < threshold:
        return None
    return dataset[top]["answer"]
def call_render(question):
    """Forward *question* to the Render fallback API and return its answer.

    Never raises: returns a human-readable error string when the endpoint is
    unconfigured, unreachable, or responds with a non-200 status.
    """
    # Fail fast with a clear message when the secret is missing, instead of
    # letting requests raise on a None URL.
    if not RENDER_API_URL:
        return "Render request failed: RENDER_API_URL is not configured"
    try:
        resp = requests.post(RENDER_API_URL, json={"question": question}, timeout=15)
        if resp.status_code == 200:
            return resp.json().get("answer", "No answer from Render.")
        return f"Render error: {resp.status_code}"
    except Exception as e:
        return f"Render request failed: {e}"
def save_to_main(question, answer):
    """Append a Q&A pair to the main dataset, refresh embeddings, and persist.

    Side effects: rebinds the module-level `main_ds` and `main_embs`, and
    pushes the updated dataset to the Hub (best-effort).
    """
    global main_ds, main_embs
    new_ds = Dataset.from_dict({"question": [question], "answer": [answer]})
    main_ds = concatenate_datasets([main_ds, new_ds])
    # NOTE: re-encodes every question on each save; fine for small datasets,
    # consider incremental embedding if the dataset grows large.
    main_embs = compute_embeddings(main_ds)
    try:
        main_ds.push_to_hub(MAIN_DATASET, token=MAIN_DATASET_TOKEN)
    except Exception as e:
        # Persistence is best-effort: a Hub outage must not break answering.
        # (Previously an unguarded push crashed the whole answer flow.)
        print(f"⚠ Could not push main dataset to the Hub: {e}")
# =========================
# ANSWER FLOW
# =========================
def get_answer(user_query):
    """Answer *user_query*.

    Cybersecurity questions are resolved in order: main dataset → external
    datasets → Render API; answers found in the last two are persisted to
    the main dataset. Anything else goes to the local chat model.
    """
    if not is_cybersecurity_question(user_query):
        # General question → local generative model.
        gen = chat_model(user_query, max_length=100, num_return_sequences=1)
        return gen[0]["generated_text"]
    # 1. Main (learned) dataset.
    ans = search_dataset(user_query, main_ds, main_embs)
    if ans:
        return ans
    # 2. External read-only datasets; cache any hit into the main dataset.
    for ds, emb in zip(ext_ds_list, ext_embs):
        ans = search_dataset(user_query, ds, emb)
        if ans:
            save_to_main(user_query, ans)
            return ans
    # 3. Fallback: Render API. Do NOT persist error strings (see call_render's
    # "Render error:"/"Render request failed:" returns) — previously transient
    # failures were cached into the dataset as canonical answers.
    ans = call_render(user_query)
    if not ans.startswith(("Render error:", "Render request failed:")):
        save_to_main(user_query, ans)
    return ans
# =========================
# GRADIO UI
# =========================
def chatbot(user_input):
    """Gradio callback: route the user's message through the answer pipeline."""
    return get_answer(user_input)


iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Ask me anything..."),
    outputs="text",
    title="Guardian AI Chatbot",
    description="Cybersecurity-focused chatbot with general fallback",
)

if __name__ == "__main__":
    # Bind to all interfaces on the standard Spaces port.
    iface.launch(server_name="0.0.0.0", server_port=7860)