Spaces:
Sleeping
Sleeping
import os | |
import json | |
import csv | |
from io import StringIO | |
from functools import wraps | |
from datetime import datetime | |
from flask import Flask, request, jsonify | |
from flask_cors import CORS | |
from dotenv import load_dotenv | |
# Firebase Admin SDK | |
import firebase_admin | |
from firebase_admin import credentials, auth, storage, db as firebase_db, initialize_app | |
# Exa.ai | |
from exa_py import Exa | |
# Google GenAI (Gemini) | |
from google import genai | |
# βββ Load environment βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
load_dotenv() | |
EXA_API_KEY = os.getenv("EXA_API_KEY") | |
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") | |
FIREBASE_JSON = os.getenv("FIREBASE") | |
FIREBASE_DB_URL = os.getenv("Firebase_DB") | |
STORAGE_BUCKET = os.getenv("Firebase_Storage") | |
if not (EXA_API_KEY and GEMINI_API_KEY and FIREBASE_JSON and FIREBASE_DB_URL and STORAGE_BUCKET): | |
raise RuntimeError("Missing one or more required env vars.") | |
# βββ Initialize Firebase ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
cred = credentials.Certificate(json.loads(FIREBASE_JSON)) | |
initialize_app(cred, { | |
"storageBucket": STORAGE_BUCKET, | |
"databaseURL": FIREBASE_DB_URL | |
}) | |
bucket = storage.bucket() | |
# βββ Ensure dummy admin exists βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
try: | |
admin_user = auth.get_user_by_email("gsamukange@yahoo.com") | |
except firebase_admin._auth_utils.UserNotFoundError: | |
admin_user = auth.create_user( | |
email="gsamukange@yahoo.com", | |
email_verified=True, | |
password="marco2025", | |
display_name="Admin" | |
) | |
auth.set_custom_user_claims(admin_user.uid, {"admin": True}) | |
# βββ Initialize Exa.ai ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
exa = Exa(EXA_API_KEY) | |
# βββ Initialize Gemini Client βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
client = genai.Client(api_key=GEMINI_API_KEY) | |
MODEL = "gemini-2.0-flash-001" | |
# βββ Flask App ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
app = Flask(__name__) | |
CORS(app) | |
# βββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def verify_id_token(f): | |
def wrapper(*args, **kwargs): | |
auth_header = request.headers.get("Authorization", "") | |
if not auth_header.startswith("Bearer "): | |
return jsonify({"error": "Missing or invalid Authorization header"}), 401 | |
id_token = auth_header.split(" ")[1] | |
try: | |
decoded = auth.verify_id_token(id_token) | |
request.user = decoded | |
except Exception: | |
return jsonify({"error": "Invalid or expired token"}), 401 | |
return f(*args, **kwargs) | |
return wrapper | |
def require_admin(f): | |
def wrapper(*args, **kwargs): | |
if not getattr(request, "user", None) or not request.user.get("admin", False): | |
return jsonify({"error": "Admin privileges required"}), 403 | |
return f(*args, **kwargs) | |
return wrapper | |
def increment_query_count(ip): | |
sanitized_ip = ip.replace(".", "_") | |
ref = firebase_db.reference(f"ip_queries/{sanitized_ip}") | |
current = ref.get() | |
count = current["count"] + 1 if current else 1 | |
ref.set({"count": count}) | |
return count | |
def store_chat(ip, message, response): | |
ref = firebase_db.reference("chats") | |
ref.push({ | |
"ip": ip, | |
"message": message, | |
"response": response, | |
"timestamp": datetime.utcnow().isoformat() | |
}) | |
def get_user_record(uid): | |
return firebase_db.reference(f"users/{uid}").get() | |
# βββ Routes βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def chat(): | |
user_ip = request.remote_addr or "0.0.0.0" | |
user_ip = user_ip.replace(".", "_") | |
count = increment_query_count(user_ip) | |
need_login = count > 5 | |
token = request.headers.get("Authorization", "") | |
user_info = None | |
if token.startswith("Bearer "): | |
try: | |
user_info = auth.verify_id_token(token.split(" ")[1]) | |
except: | |
user_info = None | |
if need_login and not user_info: | |
return jsonify({ | |
"error": "Please log in to continue after 5 free queries", | |
"login_required": True | |
}), 403 | |
data = request.get_json() | |
user_message = data.get("message", "").strip() | |
classify_prompt = ( | |
"Categorize the user's intent as JSON list of one or both:\n" | |
"['self-help','product_search']\n\n" | |
f"User: \"{user_message}\"" | |
) | |
classify_resp = client.models.generate_content(model=MODEL, contents=classify_prompt) | |
try: | |
intents = json.loads(classify_resp.text) | |
except: | |
intents = ["self-help"] | |
parts = [] | |
if "self-help" in intents: | |
help_prompt = f"You are a helpful assistant. Give step-by-step guidance for: \"{user_message}\"" | |
help_resp = client.models.generate_content(model=MODEL, contents=help_prompt) | |
parts.append(help_resp.text.strip()) | |
search_q = user_message + " ingredients" | |
exa_res = exa.search_and_contents(search_q, type="auto", text=True) | |
links = [r.url for r in exa_res.results[:3]] | |
if links: | |
parts.append("Here are some useful links to get supplies:\n" + "\n".join(f"- {u}" for u in links)) | |
if "product_search" in intents and "self-help" not in intents: | |
exa_res = exa.search_and_contents(user_message, type="auto", text=True) | |
links = [r.url for r in exa_res.results[:5]] | |
rec_prompt = ( | |
"You are a shopping assistant. Suggest products for:\n" | |
f"\"{user_message}\"\n" | |
"Always include links:\n" + "\n".join(links) | |
) | |
rec_resp = client.models.generate_content(model=MODEL, contents=rec_prompt) | |
parts.append(rec_resp.text.strip()) | |
if not parts: | |
fallback = client.models.generate_content(model=MODEL, contents=f"Help the user with: \"{user_message}\"") | |
parts.append(fallback.text.strip()) | |
final_response = "\n\n".join(parts) | |
store_chat(user_ip, user_message, final_response) | |
return jsonify({"response": final_response}) | |
def signup(): | |
data = request.get_json() | |
email = data.get("email") | |
pwd = data.get("password") | |
try: | |
user = auth.create_user(email=email, password=pwd) | |
firebase_db.reference(f"users/{user.uid}").set({ | |
"email": email, | |
"preferences": {} | |
}) | |
return jsonify({"message": "User created"}), 201 | |
except Exception as e: | |
return jsonify({"error": str(e)}), 400 | |
def user_dashboard(): | |
uid = request.user["uid"] | |
user_data = get_user_record(uid) | |
if not user_data: | |
return jsonify({"error": "User not found"}), 404 | |
return jsonify({ | |
"email": user_data.get("email"), | |
"preferences": user_data.get("preferences", {}) | |
}) | |
def update_preferences(): | |
uid = request.user["uid"] | |
prefs = request.get_json().get("preferences", {}) | |
ref = firebase_db.reference(f"users/{uid}/preferences") | |
ref.set(prefs) | |
return jsonify({"message": "Preferences updated"}) | |
# ββ Admin Routes βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def upload_links(): | |
if "file" not in request.files: | |
return jsonify({"error": "CSV file required"}), 400 | |
f = request.files["file"] | |
stream = StringIO(f.stream.read().decode("utf-8")) | |
reader = csv.DictReader(stream) | |
ref = firebase_db.reference("sponsoredLinks") | |
for row in reader: | |
ref.push({ | |
"keyword": row.get("keyword"), | |
"url": row.get("url") | |
}) | |
return jsonify({"message": "Links uploaded from CSV"}) | |
def add_link(): | |
data = request.get_json() | |
keyword = data.get("keyword") | |
url = data.get("url") | |
if not (keyword and url): | |
return jsonify({"error": "Both 'keyword' and 'url' are required"}), 400 | |
ref = firebase_db.reference("sponsoredLinks") | |
ref.push({"keyword": keyword, "url": url}) | |
return jsonify({"message": "Link added successfully"}) | |
def update_link(link_id): | |
data = request.get_json() | |
keyword = data.get("keyword") | |
url = data.get("url") | |
if not (keyword and url): | |
return jsonify({"error": "Both 'keyword' and 'url' are required"}), 400 | |
ref = firebase_db.reference(f"sponsoredLinks/{link_id}") | |
if not ref.get(): | |
return jsonify({"error": "Link not found"}), 404 | |
ref.update({"keyword": keyword, "url": url}) | |
return jsonify({"message": "Link updated successfully"}) | |
def get_sponsored_links(): | |
ref = firebase_db.reference("sponsoredLinks") | |
data = ref.get() or {} | |
# Convert to a list of objects with id included | |
links = [{"id": key, "keyword": val.get("keyword"), "url": val.get("url")} | |
for key, val in data.items()] | |
return jsonify({"sponsored_links": links}) | |
def delete_link(link_id): | |
ref = firebase_db.reference(f"sponsoredLinks/{link_id}") | |
if not ref.get(): | |
return jsonify({"error": "Link not found"}), 404 | |
ref.delete() | |
return jsonify({"message": "Link deleted successfully"}) | |
def stats(): | |
users = auth.list_users().users | |
total_users = len(users) | |
ip_data = firebase_db.reference("ip_queries").get() or {} | |
total_queries = sum(val.get("count", 0) for val in ip_data.values()) | |
return jsonify({ | |
"total_users": total_users, | |
"total_queries": total_queries | |
}) | |
# βββ Run βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
if __name__ == "__main__": | |
app.run(host="0.0.0.0", port=7860, debug=True) |