|
import os |
|
from pathlib import Path |
|
import json |
|
from flask import Flask, request, jsonify, g |
|
from flask_expects_json import expects_json |
|
from flask_cors import CORS |
|
from PIL import Image |
|
from huggingface_hub import Repository |
|
from flask_apscheduler import APScheduler |
|
import shutil |
|
import sqlite3 |
|
import subprocess |
|
from jsonschema import ValidationError |
|
|
|
# Runtime mode comes from FLASK_ENV; any value other than "development"
# is treated as production.
MODE = os.environ.get("FLASK_ENV", "production")
IS_DEV = MODE == "development"

app = Flask(__name__, static_url_path="/static")
# Disable pretty-printed jsonify() output.
app.config.update(JSONIFY_PRETTYPRINT_REGULAR=False)
|
|
|
# JSON schema for POST /new_palette bodies.
# Top level: exactly a "prompt" string and an "images" array.
# Each image: exactly a "colors" list of 5 strings and an "imgURL" string.
# min/maxProperties alone only limit the key *count*, so "required" is what
# actually forces the two named keys to be present (and, combined with
# maxProperties == 2, rules out any extra keys).
schema = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "images": {
            "type": "array",
            "items": {
                "type": "object",
                "minProperties": 2,
                "maxProperties": 2,
                "required": ["colors", "imgURL"],
                "properties": {
                    "colors": {
                        "type": "array",
                        "items": {"type": "string"},
                        "maxItems": 5,
                        "minItems": 5,
                    },
                    "imgURL": {"type": "string"},
                },
            },
        },
    },
    "minProperties": 2,
    "maxProperties": 2,
    "required": ["prompt", "images"],
}
|
|
|
# Enable flask_cors on the application (default options).
CORS(app=app)
|
|
|
# Local working copy of the palettes database used by the request handlers.
DB_FILE = Path("./data.db")

# Hub token: authenticates the dataset clone/push and guards /force_push.
# None when the environment variable is unset.
TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Clone (or reuse) the dataset repository in ./data and bring it up to date.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/color-palettes-sd",
    use_auth_token=TOKEN,
)
repo.git_pull()

# Seed the local database from the dataset repository's snapshot.
shutil.copyfile("./data/data.db", DB_FILE)

# Make sure the palettes table exists. The connection is released in
# `finally` so it is closed whether the table was found or had to be created.
db = sqlite3.connect(DB_FILE)
try:
    data = db.execute("SELECT * FROM palettes").fetchall()
    if IS_DEV:
        print(f"Loaded {len(data)} palettes from local db")
except sqlite3.OperationalError:
    # Presumably the table is missing (fresh database) — create it.
    db.execute(
        "CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)"
    )
    db.commit()
finally:
    db.close()
|
|
|
|
|
def get_db():
    """Return the SQLite connection for the current app context.

    The connection is opened lazily on first use, cached on ``flask.g`` as
    ``_database``, and configured to return ``sqlite3.Row`` rows so columns
    can be read by name.
    """
    connection = getattr(g, "_database", None)
    if connection is None:
        connection = sqlite3.connect(DB_FILE)
        g._database = connection
    connection.row_factory = sqlite3.Row
    return connection
|
|
|
|
|
@app.teardown_appcontext
def close_connection(exception):
    """Close the connection opened by get_db(), if one was opened.

    Registered as an app-context teardown hook; ``exception`` is unused.
    """
    connection = getattr(g, "_database", None)
    if connection is None:
        return
    connection.close()
|
|
|
|
|
def update_repository():
    """Publish the local palettes database to the dataset repository.

    Steps:
      1. Pull the latest state of the dataset checkout in ``./data``.
      2. Overwrite ``./data/data.db`` with the local ``DB_FILE``.
      3. Export every palette row to ``./data/data.json`` as compact JSON.
      4. Stage everything, amend the previous commit, and force-push it; then
         hand off to ``repo.push_to_hub(blocking=False)``.

    Step 4 now waits for each git command, so this call blocks until the
    force-push finishes (relevant for the /force_push request).
    """
    from contextlib import closing

    repo.git_pull()

    shutil.copyfile(DB_FILE, "./data/data.db")

    # A sqlite3 connection used as a context manager only commits/rolls back;
    # it does not close. closing() guarantees the handle is released.
    with closing(sqlite3.connect("./data/data.db")) as db:
        db.row_factory = sqlite3.Row
        palettes = db.execute("SELECT * FROM palettes").fetchall()

    data = [
        {
            "id": row["id"],
            "data": json.loads(row["data"]),
            "created_at": row["created_at"],
        }
        for row in palettes
    ]

    with open("./data/data.json", "w") as f:
        json.dump(data, f, separators=(",", ":"))

    print("Updating repository")
    # Run the git steps one after another and stop at the first failure
    # (same semantics as chaining with `&&`). They must be finished before
    # push_to_hub() runs its own git operations in this checkout; launching
    # them in the background let the two collide on the git index lock.
    for command in (
        ["git", "add", "."],
        ["git", "commit", "--amend", "-m", "update"],
        ["git", "push", "--force"],
    ):
        if subprocess.run(command, cwd="./data").returncode != 0:
            break
    repo.push_to_hub(blocking=False)
|
|
|
|
|
@app.route("/")
def index():
    """Serve the front-end entry page from the static folder."""
    page = "index.html"
    return app.send_static_file(page)
|
|
|
|
|
@app.route("/force_push")
def push():
    """Run update_repository() immediately when the caller presents TOKEN.

    The token is read from the ``token`` request header.

    Returns:
        ``{"success": true}`` as JSON after update_repository() completes, or
        the body ``"Error"`` with status 401 when the token is missing or wrong.
    """
    import hmac

    supplied = request.headers.get("token")
    # Refuse when the server has no (or an empty) token configured or the
    # header is absent; a missing header is an auth failure (401), not a
    # malformed request. compare_digest avoids leaking the token via timing.
    if TOKEN and supplied and hmac.compare_digest(supplied.encode(), TOKEN.encode()):
        update_repository()
        return jsonify({"success": True})
    return "Error", 401
|
|
|
|
|
def getAllData():
    """Return every stored palette as a JSON-serialisable dict.

    Each entry carries the row ``id``, the decoded ``data`` column (stored as
    JSON text), and ``created_at`` as returned by SQLite.
    """
    rows = get_db().execute("SELECT * FROM palettes").fetchall()
    result = []
    for row in rows:
        entry = {
            "id": row["id"],
            "data": json.loads(row["data"]),
            "created_at": row["created_at"],
        }
        result.append(entry)
    return result
|
|
|
|
|
@app.route("/data")
def getdata():
    """Return all stored palettes as a JSON response."""
    palettes = getAllData()
    return jsonify(palettes)
|
|
|
|
|
@app.route("/new_palette", methods=["POST"])
@expects_json(schema)
def create():
    """Insert a schema-validated palette submission and return all palettes.

    The request body is read from ``g.data`` (populated by the
    ``expects_json`` decorator after validation against ``schema``) and
    stored as JSON text in the ``data`` column.
    """
    payload = json.dumps(g.data)
    connection = get_db()
    connection.execute("INSERT INTO palettes(data) VALUES (?)", [payload])
    connection.commit()
    return jsonify(getAllData())
|
|
|
|
|
@app.errorhandler(400)
def bad_request(error):
    """Render schema-validation failures as ``{"error": <message>}`` with 400.

    Any other 400 error is passed back unchanged.
    """
    cause = error.description
    if not isinstance(cause, ValidationError):
        return error
    return jsonify({"error": cause.message}), 400
|
|
|
|
|
if __name__ == "__main__":
    # In production, publish the database to the dataset repo every hour.
    if not IS_DEV:
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        scheduler.add_job(
            id="Update Dataset Repository",
            func=update_repository,
            trigger="interval",
            hours=1,
        )
        scheduler.start()
    else:
        print("Not Starting scheduler -- Running Development")

    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        # The server listens on all interfaces, so the interactive debugger
        # (which allows arbitrary code execution) must only be on in dev.
        debug=IS_DEV,
        use_reloader=IS_DEV,
    )
|
|