"""Flask backend that stores colour palettes in SQLite and mirrors them to a
Hugging Face dataset repository (``huggingface-projects/color-palettes-sd``).
"""
import hmac
import json
import os
import shutil
import sqlite3
import subprocess
from contextlib import closing
from pathlib import Path

from flask import Flask, request, jsonify, g
from flask_expects_json import expects_json
from flask_cors import CORS
from PIL import Image
from huggingface_hub import Repository
from flask_apscheduler import APScheduler
from jsonschema import ValidationError

MODE = os.environ.get('FLASK_ENV', 'production')
IS_DEV = MODE == 'development'
app = Flask(__name__, static_url_path='/static')
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False

# Request body for /new_palette: exactly a prompt string plus a list of
# images, each with exactly five colour strings and an image URL.
# `required` is what actually enforces the key names; min/maxProperties alone
# would accept any two arbitrary keys.
schema = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "images": {
            "type": "array",
            "items": {
                "type": "object",
                "minProperties": 2,
                "maxProperties": 2,
                "required": ["colors", "imgURL"],
                "properties": {
                    "colors": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "maxItems": 5,
                        "minItems": 5
                    },
                    "imgURL": {"type": "string"}}
            }
        }
    },
    "minProperties": 2,
    "maxProperties": 2,
    "required": ["prompt", "images"]
}

CORS(app)

DB_FILE = Path("./data.db")
TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN')

# Clone (or reuse) the dataset repo into ./data and bring it up to date.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/color-palettes-sd",
    use_auth_token=TOKEN
)
repo.git_pull()
# Seed the local working DB from the copy stored in the dataset repo.
shutil.copyfile("./data/data.db", DB_FILE)

# Make sure the palettes table exists. closing() guarantees the connection is
# released on both the success and the create-table path.
with closing(sqlite3.connect(DB_FILE)) as db:
    try:
        data = db.execute("SELECT * FROM palettes").fetchall()
        if IS_DEV:
            print(f"Loaded {len(data)} palettes from local db")
    except sqlite3.OperationalError:
        db.execute(
            'CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)')
        db.commit()


def get_db():
    """Return the request-scoped SQLite connection, opening it on first use.

    Rows are returned as ``sqlite3.Row`` so columns can be read by name.
    The connection is closed by ``close_connection`` on app-context teardown.
    """
    db = getattr(g, '_database', None)
    if db is None:
        db = g._database = sqlite3.connect(DB_FILE)
        db.row_factory = sqlite3.Row
    return db


@app.teardown_appcontext
def close_connection(exception):
    """Close the request-scoped connection opened by ``get_db``, if any."""
    db = getattr(g, '_database', None)
    if db is not None:
        db.close()


def _row_to_dict(row):
    """Convert a ``palettes`` row into a JSON-serialisable dict.

    The ``data`` column holds JSON text and is decoded here.
    """
    return {'id': row['id'],
            'data': json.loads(row['data']),
            'created_at': row['created_at']}


def update_repository():
    """Publish the local palettes DB to the Hugging Face dataset repo.

    Pulls the repo, overwrites ``./data/data.db`` with the local ``DB_FILE``,
    exports every row to ``./data/data.json`` (compact JSON), then pushes.
    """
    repo.git_pull()
    # The local DB is authoritative here: overwrite the repo's copy with it.
    shutil.copyfile(DB_FILE, "./data/data.db")
    # sqlite3's own context manager only commits/rolls back and never closes;
    # closing() releases the connection before git touches the file.
    with closing(sqlite3.connect("./data/data.db")) as db:
        db.row_factory = sqlite3.Row
        palettes = db.execute("SELECT * FROM palettes").fetchall()
    data = [_row_to_dict(row) for row in palettes]
    with open('./data/data.json', 'w', encoding='utf-8') as f:
        json.dump(data, f, separators=(',', ':'))
    print("Updating repository")
    # Amend + force-push keeps the dataset history to a single commit.
    # NOTE(review): this runs concurrently with repo.push_to_hub below; both
    # operate on the same git checkout and may contend for the index lock.
    # Confirm whether both pushes are really needed.
    subprocess.Popen(
        "git add . && git commit --amend -m 'update' && git push --force",
        cwd="./data", shell=True)
    repo.push_to_hub(blocking=False)


@app.route('/')
def index():
    """Serve the single-page frontend."""
    return app.send_static_file('index.html')


@app.route('/force_push')
def push():
    """Trigger ``update_repository`` immediately.

    Requires a ``token`` header equal to ``HUGGING_FACE_HUB_TOKEN``. A missing
    header, a wrong token, or an unset server token all yield 401.
    """
    provided = request.headers.get('token')
    # Constant-time comparison avoids leaking the token through timing.
    if TOKEN and provided and hmac.compare_digest(provided.encode(),
                                                  TOKEN.encode()):
        update_repository()
        return jsonify({'success': True})
    return "Error", 401


def getAllData():
    """Return every stored palette as a list of dicts (see ``_row_to_dict``)."""
    palettes = get_db().execute("SELECT * FROM palettes").fetchall()
    return [_row_to_dict(row) for row in palettes]


@app.route('/data')
def getdata():
    """Return all palettes as JSON."""
    return jsonify(getAllData())


@app.route('/new_palette', methods=['POST'])
@expects_json(schema)
def create():
    """Store a schema-validated palette and return the full palette list."""
    data = g.data  # validated request body, populated by @expects_json
    db = get_db()
    cursor = db.cursor()
    cursor.execute("INSERT INTO palettes(data) VALUES (?)", [json.dumps(data)])
    db.commit()
    return jsonify(getAllData())


@app.errorhandler(400)
def bad_request(error):
    """Render JSON-schema validation failures as ``{"error": message}``.

    Any other 400 is passed through unchanged.
    """
    if isinstance(error.description, ValidationError):
        original_error = error.description
        return jsonify({'error': original_error.message}), 400
    return error


if __name__ == '__main__':
    if not IS_DEV:
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        scheduler.add_job(id='Update Dataset Repository',
                          func=update_repository, trigger='interval', hours=1)
        scheduler.start()
    else:
        print("Not Starting scheduler -- Running Development")
    # Never enable the Werkzeug debugger outside development: it exposes an
    # interactive console, and the server binds to every interface.
    app.run(host='0.0.0.0', port=int(
        os.environ.get('PORT', 7860)), debug=IS_DEV, use_reloader=IS_DEV)