"""Flask backend for the color-palettes app.

Palettes are stored in a local SQLite database (``./data.db``) that is seeded
from, and periodically pushed back to, the
``huggingface-projects/color-palettes-sd`` dataset repository cloned into
``./data``.
"""

import json
import os
import secrets
import shutil
import sqlite3
import subprocess
from contextlib import closing
from pathlib import Path

from flask import Flask, request, jsonify, g
from flask_apscheduler import APScheduler
from flask_cors import CORS
from flask_expects_json import expects_json
from huggingface_hub import Repository
from jsonschema import ValidationError
from PIL import Image

MODE = os.environ.get("FLASK_ENV", "production")
IS_DEV = MODE == "development"
app = Flask(__name__, static_url_path="/static")
app.config["JSONIFY_PRETTYPRINT_REGULAR"] = False

# JSON schema for POST /new_palette. The body must be an object with exactly
# two keys; "prompt" must be a string and "images" a list of objects with
# exactly two keys, where "colors" is a list of 5 strings and "imgURL" a string.
# NOTE(review): no "required" list is given, so the two keys are not forced to
# be "prompt"/"images" — confirm that is intended.
schema = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "images": {
            "type": "array",
            "items": {
                "type": "object",
                "minProperties": 2,
                "maxProperties": 2,
                "properties": {
                    "colors": {
                        "type": "array",
                        "items": {"type": "string"},
                        "maxItems": 5,
                        "minItems": 5,
                    },
                    "imgURL": {"type": "string"},
                },
            },
        },
    },
    "minProperties": 2,
    "maxProperties": 2,
}

CORS(app)

DB_FILE = Path("./data.db")

TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")

repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/color-palettes-sd",
    use_auth_token=TOKEN,
)
repo.git_pull()
# Seed the local working db from the copy stored in the dataset repo.
shutil.copyfile("./data/data.db", DB_FILE)

db = sqlite3.connect(DB_FILE)
try:
    data = db.execute("SELECT * FROM palettes").fetchall()
    if IS_DEV:
        print(f"Loaded {len(data)} palettes from local db")
except sqlite3.OperationalError:
    # Most likely the palettes table does not exist yet: create it.
    db.execute(
        "CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)"
    )
    db.commit()
finally:
    # Close on both paths so the bootstrap connection never leaks.
    db.close()


def get_db():
    """Return this app context's SQLite connection, opening it on first use.

    The connection is cached on ``flask.g`` and uses ``sqlite3.Row`` so columns
    can be read by name; ``close_connection`` closes it on context teardown.
    """
    db = getattr(g, "_database", None)
    if db is None:
        db = g._database = sqlite3.connect(DB_FILE)
        db.row_factory = sqlite3.Row
    return db


@app.teardown_appcontext
def close_connection(exception):
    """Close the per-context connection opened by ``get_db``, if any."""
    db = getattr(g, "_database", None)
    if db is not None:
        db.close()


def _rows_to_dicts(rows):
    """Convert ``palettes`` rows (``sqlite3.Row``) into JSON-serializable dicts.

    The ``data`` column holds a JSON string and is decoded here.
    """
    return [
        {
            "id": row["id"],
            "data": json.loads(row["data"]),
            "created_at": row["created_at"],
        }
        for row in rows
    ]


def update_repository():
    """Publish the local palette db to the dataset repository.

    Pulls the repo, overwrites ``./data/data.db`` with the local db, writes a
    compact JSON export to ``./data/data.json``, then commits and pushes.
    """
    repo.git_pull()
    # The local db is authoritative: overwrite the repo's copy with it.
    shutil.copyfile(DB_FILE, "./data/data.db")

    # sqlite3's connection context manager only commits/rolls back; it does
    # not close the connection. closing() avoids leaking a handle every run.
    with closing(sqlite3.connect("./data/data.db")) as db:
        db.row_factory = sqlite3.Row
        palettes = db.execute("SELECT * FROM palettes").fetchall()
    data = _rows_to_dicts(palettes)

    with open("./data/data.json", "w") as f:
        json.dump(data, f, separators=(",", ":"))

    print("Updating repository")
    # Amend the previous commit (instead of adding a new one) and force-push.
    # NOTE(review): Popen is not awaited, so this runs concurrently with
    # push_to_hub below and both operate on the same git working tree —
    # confirm they cannot race (e.g. on .git/index.lock).
    subprocess.Popen(
        "git add . && git commit --amend -m 'update' && git push --force",
        cwd="./data",
        shell=True,
    )
    repo.push_to_hub(blocking=False)


@app.route("/")
def index():
    """Serve the frontend's ``index.html`` from the static folder."""
    return app.send_static_file("index.html")


@app.route("/force_push")
def push():
    """Run ``update_repository`` when the ``token`` header matches TOKEN.

    Returns 401 when the header is missing or wrong, or when no TOKEN is
    configured.
    """
    provided = request.headers.get("token")
    # Constant-time comparison avoids leaking the token through timing, and an
    # unset/empty TOKEN must never authorize a push.
    if (
        TOKEN
        and provided is not None
        and secrets.compare_digest(provided.encode(), TOKEN.encode())
    ):
        update_repository()
        return jsonify({"success": True})
    return "Error", 401


def getAllData():
    """Return every stored palette as a list of dicts."""
    palettes = get_db().execute("SELECT * FROM palettes").fetchall()
    return _rows_to_dicts(palettes)


@app.route("/data")
def getdata():
    """Return all palettes as JSON."""
    return jsonify(getAllData())


@app.route("/new_palette", methods=["POST"])
@expects_json(schema)
def create():
    """Store the schema-validated request body and return all palettes."""
    data = g.data
    db = get_db()
    db.execute("INSERT INTO palettes(data) VALUES (?)", [json.dumps(data)])
    db.commit()
    return jsonify(getAllData())


@app.errorhandler(400)
def bad_request(error):
    """Render JSON-schema validation failures as ``{"error": message}``.

    Any other 400 error is returned unchanged.
    """
    if isinstance(error.description, ValidationError):
        original_error = error.description
        return jsonify({"error": original_error.message}), 400
    return error


if __name__ == "__main__":
    if not IS_DEV:
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        scheduler.add_job(
            id="Update Dataset Repository",
            func=update_repository,
            trigger="interval",
            hours=1,
        )
        scheduler.start()
    else:
        print("Not Starting scheduler -- Running Development")
    # Debug mode turns on the interactive Werkzeug debugger, which must not be
    # exposed in production; enable it (and the reloader) only in development.
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=IS_DEV,
        use_reloader=IS_DEV,
    )