Spaces:
Running
Running
from flask import Flask, request, jsonify, render_template, session, redirect, url_for | |
from flask_session import Session | |
import google.generativeai as genai | |
import json | |
import uuid | |
import os | |
import logging | |
from utils.ai_helpers import generate_notebook, stream_notebook_generation, stream_notebook_edit, edit_notebook | |
from utils.notebook_helpers import format_notebook, extract_notebook_info | |
# Configure logging: INFO and above, tagged with the app name in every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - notegenie - %(levelname)s - %(message)s'
)
# Use a named application logger rather than the root logger. Records still
# propagate to the root handler installed by basicConfig above. Level or
# handler changes made on this logger won't leak into every third-party
# library that logs through the root.
logger = logging.getLogger("notegenie")
app = Flask(__name__)
# The secret key signs the session-id cookie (SESSION_USE_SIGNER below). The
# built-in fallback is public, so make its use visible in the logs.
if not os.environ.get("SECRET_KEY"):
    logger.warning(
        "SECRET_KEY is not set; falling back to the built-in development key. "
        "Set SECRET_KEY in production."
    )
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY") or "notegenie-secret-key-change-in-production"
# Server-side sessions stored as files on local disk.
app.config["SESSION_TYPE"] = "filesystem"
app.config["SESSION_PERMANENT"] = True
app.config["SESSION_USE_SIGNER"] = True
app.config["PERMANENT_SESSION_LIFETIME"] = 60 * 60 * 24 * 30  # 30 days, in seconds
app.config["SESSION_FILE_DIR"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), "flask_session")
os.makedirs(app.config["SESSION_FILE_DIR"], exist_ok=True)  # Ensure directory exists
# Session files contain each user's Gemini API key (set_api_key stores it in
# the session), so keep them owner-only. 0o600 is Flask-Session's default; the
# previous 0o666 made them readable and writable by every local user.
app.config["SESSION_FILE_MODE"] = 0o600
Session(app)
# Translation table: model label used by the front end -> model id sent to the Gemini API.
MODEL_MAPPING = {
    "gemini-2.0-pro": "gemini-2.0-pro-exp-02-05",
    "gemini-2.0-flash": "gemini-2.0-flash",
    "gemini-2.0-flash-thinking": "gemini-2.0-flash-thinking-exp-01-21",
}


def get_api_model_name(frontend_model_name):
    """Resolve a front-end model label to the Gemini API model id.

    Labels not present in ``MODEL_MAPPING`` are returned unchanged, so callers
    may also pass a raw API model id directly.
    """
    if frontend_model_name in MODEL_MAPPING:
        return MODEL_MAPPING[frontend_model_name]
    return frontend_model_name
# Resolve the Gemini API key for the current request.
def get_api_key():
    """Return the API key for the current request, or a falsy value if none.

    Sources are checked in order:
      1. ``session["api_key"]`` (set by the key-validation endpoint);
      2. the ``X-API-Key`` request header;
      3. the ``api_key`` query-string parameter.
    Steps 2-3 serve direct API calls that carry no session cookie.
    NOTE(review): keys passed in the query string can end up in access logs
    and browser history; consider dropping that fallback once clients allow it.
    """
    stored_key = session.get("api_key")
    if stored_key:
        return stored_key
    return request.headers.get("X-API-Key") or request.args.get("api_key")
def index():
    """Render the landing page.

    Before rendering, sets a ``session_test`` flag in the session as a
    session-functionality smoke test, then logs the session's current keys.
    """
    session["session_test"] = True
    present_keys = list(session)
    logger.info(f"Session check - variables: {present_keys}")
    return render_template("index.html")
def set_api_key():
    """Validate a Gemini API key from form data and store it in the session.

    Reads the ``api_key`` form field and strips surrounding whitespace, since
    pasted keys often carry a trailing newline. The key is then validated by
    issuing a small ``generate_content`` request with it.

    Returns:
        ``{"success": True}`` when the key works; the key is then saved in a
        permanent session. ``({"success": False, "message": ...}, 400)`` when
        the field is missing/blank or the validation call raises.
    """
    api_key = (request.form.get("api_key") or "").strip()
    if not api_key:
        return jsonify({"success": False, "message": "API key is required"}), 400
    try:
        # NOTE(review): genai.configure appears to set process-wide client state;
        # concurrent requests using different keys may race. Confirm before scaling.
        genai.configure(api_key=api_key)
        # Validate against the same model the front end's default label maps to.
        model = genai.GenerativeModel(get_api_model_name("gemini-2.0-pro"))
        # Any exception (invalid key, quota, unavailable model) counts as a failed validation.
        model.generate_content("Say 'API key is valid'")
    except Exception as e:
        logger.error(f"API key validation error: {str(e)}")
        return jsonify({"success": False, "message": str(e)}), 400
    # Store API key in session only (never persisted elsewhere by this handler).
    session.permanent = True
    session["api_key"] = api_key
    logger.info("API key successfully set and validated")
    return jsonify({"success": True})
def generate_notebook_route():
    """Generate a notebook from a prompt, optionally streaming the result.

    GET: ``prompt``, ``model`` (default ``"gemini-2.0-pro"``) and ``stream``
    (default ``"false"``; only the string ``"true"``, case-insensitive,
    enables it) are read from the query string.
    POST: ``prompt``, ``model``, ``stream`` (default ``False``) and
    ``format_only`` (default ``False``) are read from the JSON body. With
    ``format_only``, ``prompt`` is treated as already-generated notebook text
    and is only formatted; no AI call is made.

    Returns:
        401 when no API key is available; 400 when ``prompt`` is missing.
        Otherwise a JSON payload with ``notebook``, ``name`` and ``description``,
        or, when streaming, whatever ``stream_notebook_generation`` returns
        (presumably a streaming response). 500 if generation/formatting raises.
    """
    api_key = get_api_key()
    if not api_key:
        logger.warning("Generate notebook request without API key")
        return jsonify({"success": False, "message": "API key not set"}), 401
    # Always configure genai with the API key for each request.
    # NOTE(review): this looks like process-global state shared across requests.
    genai.configure(api_key=api_key)
    # Bound for both methods so the branch below never depends on request.method.
    format_only = False
    if request.method == "GET":
        prompt = request.args.get("prompt")
        model_name = request.args.get("model", "gemini-2.0-pro")
        stream = request.args.get("stream", "false").lower() == "true"
    else:
        # silent=True: a missing or malformed JSON body (or a non-object JSON
        # value) falls through to the 400 below instead of an unhandled error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        prompt = data.get("prompt")
        model_name = data.get("model", "gemini-2.0-pro")
        stream = data.get("stream", False)
        format_only = data.get("format_only", False)
    if not prompt:
        return jsonify({"success": False, "message": "Prompt is required"}), 400
    api_model_name = get_api_model_name(model_name)
    try:
        if format_only:
            # The client already holds the AI response; just format it.
            return _notebook_json_response(prompt)
        if stream:
            # NOTE(review): errors raised while the stream is consumed (after
            # this function returns) are not caught by this handler.
            return stream_notebook_generation(prompt, api_model_name)
        return _notebook_json_response(generate_notebook(prompt, api_model_name))
    except Exception as e:
        logger.exception(f"Error generating notebook: {str(e)}")
        return jsonify({"success": False, "message": str(e)}), 500


def _notebook_json_response(notebook_content):
    """Format raw notebook text and wrap it in the standard success JSON payload."""
    notebook_json = format_notebook(notebook_content)
    notebook_info = extract_notebook_info(notebook_content)
    return jsonify({
        "success": True,
        "notebook": notebook_json,
        "name": notebook_info["name"],
        "description": notebook_info["description"],
    })
def prepare_edit_notebook():
    """Store the notebook in the session for editing.

    Expects a JSON body ``{"notebook": ...}``. The notebook is serialized with
    ``json.dumps`` and saved as ``session["current_notebook"]``. GET requests to
    ``edit_notebook_route`` read it back from there.

    Returns:
        ``{"success": True}`` on success; 401 when no API key is available;
        400 when the body is missing/invalid or has no (truthy) ``notebook``.
    """
    if not get_api_key():
        return jsonify({"success": False, "message": "API key not set"}), 401
    # silent=True: a non-JSON or empty body yields None rather than raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    notebook_json = data.get("notebook")
    if not notebook_json:
        return jsonify({"success": False, "message": "Notebook content is required"}), 400
    # Stored as a JSON string; edit_notebook_route json.loads it on GET.
    session["current_notebook"] = json.dumps(notebook_json)
    return jsonify({"success": True})
def edit_notebook_route():
    """Apply an AI edit to a notebook, streaming by default on GET.

    GET: ``edit_prompt``, ``model`` (default ``"gemini-2.0-pro"``) and
    ``stream`` (default ``"true"``) come from the query string. The notebook
    is read from ``session["current_notebook"]``, as stored by
    ``prepare_edit_notebook``.
    POST: ``edit_prompt``, ``notebook``, ``model`` and ``stream`` (default
    ``False``) come from the JSON body.

    Returns:
        401 without an API key; 400 when the edit prompt or notebook is missing.
        Otherwise whatever ``stream_notebook_edit`` returns when streaming, or a
        JSON payload with ``notebook``, ``name`` and ``description``. 500 if the
        edit raises.
    """
    api_key = get_api_key()
    if not api_key:
        return jsonify({"success": False, "message": "API key not set"}), 401
    # Always configure genai with the API key for each request.
    genai.configure(api_key=api_key)
    if request.method == "GET":
        edit_prompt = request.args.get("edit_prompt")
        model_name = request.args.get("model", "gemini-2.0-pro")
        stream = request.args.get("stream", "true").lower() == "true"
        # GET requests carry no body, so the notebook comes from the session.
        stored_notebook = session.get("current_notebook")
        notebook_json = json.loads(stored_notebook) if stored_notebook else None
    else:
        # silent=True plus the type check: a bad body leads to the 400s below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        edit_prompt = data.get("edit_prompt")
        notebook_json = data.get("notebook")
        model_name = data.get("model", "gemini-2.0-pro")
        stream = data.get("stream", False)
    if not edit_prompt:
        return jsonify({"success": False, "message": "Edit prompt is required"}), 400
    if not notebook_json:
        return jsonify({"success": False, "message": "No notebook available for editing. Please prepare the notebook first."}), 400
    api_model_name = get_api_model_name(model_name)
    try:
        if stream:
            return stream_notebook_edit(edit_prompt, notebook_json, api_model_name)
        # Non-streaming path (not used in current UI but kept for API completeness).
        edited_content = edit_notebook(edit_prompt, notebook_json, api_model_name)
        formatted_notebook = format_notebook(edited_content)
        notebook_info = extract_notebook_info(edited_content)
        return jsonify({
            "success": True,
            "notebook": formatted_notebook,
            "name": notebook_info["name"],
            "description": notebook_info["description"],
        })
    except Exception as e:
        # Module logger, consistent with the other routes (was app.logger).
        logger.exception(f"Error editing notebook: {str(e)}")
        return jsonify({"success": False, "message": str(e)}), 500
def download_notebook():
    """Return the posted notebook as a downloadable ``.ipynb`` attachment.

    Expects a JSON body with ``notebook`` (required) and an optional
    ``filename``. The filename is sanitized before it reaches the response
    header; a random ``notebook_<uuid4>.ipynb`` is used when it is absent
    or unusable.

    Returns:
        A response whose body is ``json.dumps(notebook, indent=2)``, or a 400
        JSON error when ``notebook`` is missing or falsy.
    """
    from flask import Response

    # silent=True: a missing/invalid body falls through to the 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    notebook_json = data.get("notebook")
    if not notebook_json:
        return jsonify({"success": False, "message": "Notebook content is required"}), 400
    filename = build_download_filename(data.get("filename"), f"notebook_{uuid.uuid4()}.ipynb")
    return Response(
        json.dumps(notebook_json, indent=2),
        mimetype="application/json",
        headers={"Content-Disposition": build_content_disposition(filename)},
    )


def build_download_filename(raw_name, default_name):
    """Derive a header-safe ``.ipynb`` filename from client-supplied input.

    Args:
        raw_name: The client's ``filename`` value; may be any JSON type.
        default_name: Used when ``raw_name`` is not a string or is empty after
            cleaning.

    Returns:
        The final path component of ``raw_name`` (``/`` or ``\\`` separators).
        Double quotes and non-printable characters (CR, LF, tabs, ...) are
        removed and surrounding whitespace is stripped. The result always ends
        with ``.ipynb``.
    """
    candidate = ""
    if isinstance(raw_name, str):
        # Keep only the last path component, whichever separator was used.
        candidate = raw_name.replace("\\", "/").rsplit("/", 1)[-1]
        # Drop characters that could terminate or corrupt the quoted header value.
        candidate = "".join(ch for ch in candidate if ch.isprintable() and ch != '"').strip()
    if not candidate:
        candidate = default_name
    if not candidate.endswith(".ipynb"):
        candidate += ".ipynb"
    return candidate


def build_content_disposition(filename):
    """Build an ``attachment`` Content-Disposition value for *filename*.

    Emits a quoted ASCII ``filename`` (non-ASCII characters replaced by ``_``)
    for legacy clients, plus an RFC 6266 ``filename*`` parameter with the UTF-8,
    percent-encoded name. Expects *filename* to be already cleaned by
    ``build_download_filename`` (no quotes or control characters).
    """
    from urllib.parse import quote

    ascii_fallback = "".join(ch if ch.isascii() else "_" for ch in filename)
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{encoded}"
# Session diagnostic endpoint.
def check_session():
    """Return a JSON report describing the session setup, for debugging.

    Reports whether an API key is in the session, the session's key names,
    whether the session file directory exists and is writable, and whether
    the ``SPACE_ID`` environment variable (Hugging Face Spaces) is present.
    NOTE(review): intended for debugging only, but nothing here disables it
    in production.
    """
    session_dir = app.config["SESSION_FILE_DIR"]
    report = {
        "has_api_key": "api_key" in session,
        "session_vars": list(session.keys()),
        "session_file_dir_exists": os.path.exists(session_dir),
        "session_file_dir_writable": os.access(session_dir, os.W_OK),
    }
    on_hf_space = "SPACE_ID" in os.environ
    report["is_huggingface_space"] = on_hf_space
    if on_hf_space:
        logger.info("Running on Hugging Face Spaces environment")
    return jsonify(report)
if __name__ == "__main__":
    # Direct execution: listen on all interfaces at $PORT (default 5000).
    listen_port = int(os.environ.get("PORT", 5000))
    # Debug mode only when FLASK_ENV is exactly "development".
    is_dev = os.environ.get("FLASK_ENV") == "development"
    app.run(host="0.0.0.0", port=listen_port, debug=is_dev)