|
import hmac
import json
import os
import shutil
import sqlite3
import subprocess
from pathlib import Path

from flask import Flask, request, jsonify, g
from flask_apscheduler import APScheduler
from flask_cors import CORS
from flask_expects_json import expects_json
from huggingface_hub import Repository
from jsonschema import ValidationError
from PIL import Image
|
|
|
# FLASK_ENV picks the runtime mode; it gates the reloader and the
# background sync scheduler further down in this file.
MODE = os.environ.get('FLASK_ENV', 'production')

IS_DEV = MODE == 'development'

app = Flask(__name__, static_url_path='/static')

# Emit compact (non-indented) JSON from jsonify to keep responses small.
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
|
|
|
# JSON Schema validated by @expects_json on POST /new_palette.
# Payload must be exactly {"prompt": str, "images": [...]}, where each
# image entry is exactly {"colors": [5 strings], "imgURL": str}.
schema = {

    "type": "object",

    "properties": {

        "prompt": {"type": "string"},

        "images": {

            "type": "array",

            "items": {

                "type": "object",

                # each image object must carry exactly these two keys
                "minProperties": 2,

                "maxProperties": 2,

                "properties": {

                    "colors": {

                        "type": "array",

                        "items": {

                            "type": "string"

                        },

                        # a palette is always exactly five colors
                        "maxItems": 5,

                        "minItems": 5

                    },

                    "imgURL": {"type": "string"}}

            }

        }

    },

    # top level must carry exactly "prompt" and "images"
    "minProperties": 2,

    "maxProperties": 2

}
|
|
|
CORS(app)


# Local working copy of the SQLite database; the canonical copy lives in
# the Hugging Face dataset repo cloned below.
DB_FILE = Path("./data.db")

TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN')

# Clone (or reuse) the dataset repo into ./data, authenticated with the
# hub token, then fast-forward to the latest upstream state.
repo = Repository(

    local_dir="data",

    repo_type="dataset",

    clone_from="huggingface-projects/color-palettes-sd",

    use_auth_token=TOKEN

)

repo.git_pull()


# Work against a copy outside the checkout so writes don't dirty the repo
# until update_repository() explicitly syncs them back.
shutil.copyfile("./data/data.db", DB_FILE)
|
|
|
# One-time startup check: ensure the palettes table exists in the working db.
db = sqlite3.connect(DB_FILE)
try:
    try:
        data = db.execute("SELECT * FROM palettes").fetchall()
        if IS_DEV:
            print(f"Loaded {len(data)} palettes from local db")
    except sqlite3.OperationalError:
        # Fresh/empty database: the SELECT fails because the table is missing.
        db.execute(
            'CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)')
        db.commit()
finally:
    # Bug fix: the original closed the connection only on the successful
    # SELECT path, leaking it whenever the table had to be created.
    db.close()
|
|
|
|
|
def get_db():
    """Return the request-scoped SQLite connection, opening it on first use.

    The connection is cached on flask's ``g`` so each request reuses one
    handle; rows come back as :class:`sqlite3.Row` for name-based access.
    """
    connection = getattr(g, '_database', None)
    if connection is None:
        connection = g._database = sqlite3.connect(DB_FILE)
        connection.row_factory = sqlite3.Row
    return connection
|
|
|
|
|
@app.teardown_appcontext
def close_connection(exception):
    """Close the request-scoped db connection, if this request opened one."""
    connection = getattr(g, '_database', None)
    if connection is not None:
        connection.close()
|
|
|
|
|
def update_repository():
    """Sync the local SQLite data back into the dataset repo and push it.

    Pulls the latest repo state, overwrites the repo's db with the local
    working copy, regenerates ``data.json`` from the table, then amends and
    force-pushes a single 'update' commit (keeps repo history at one commit).
    Runs from the hourly scheduler and from the /force_push endpoint.
    """
    repo.git_pull()

    shutil.copyfile(DB_FILE, "./data/data.db")

    # Bug fix: sqlite3's context manager only scopes the transaction, not
    # the connection, so the original leaked one connection per sync.
    db = sqlite3.connect("./data/data.db")
    try:
        db.row_factory = sqlite3.Row
        palettes = db.execute("SELECT * FROM palettes").fetchall()
        data = [{'id': row['id'],
                 'data': json.loads(row['data']),
                 'created_at': row['created_at']} for row in palettes]
    finally:
        db.close()

    # Compact separators keep the exported JSON file small.
    with open('./data/data.json', 'w') as f:
        json.dump(data, f, separators=(',', ':'))

    print("Updating repository")
    # NOTE(review): shell=True is acceptable here only because the command
    # string is fixed (no user input); amend + --force keeps history flat.
    subprocess.Popen(
        "git add . && git commit --amend -m 'update' && git push --force", cwd="./data", shell=True)
    repo.push_to_hub(blocking=False)
|
|
|
|
|
@app.route('/')
def index():
    """Serve the front-end entry page from the static folder."""
    page = app.send_static_file('index.html')
    return page
|
|
|
|
|
@app.route('/force_push')
def push():
    """Manually trigger a dataset-repo sync; requires the hub token header.

    Returns 401 unless the request carries a ``token`` header matching
    HUGGING_FACE_HUB_TOKEN.
    """
    # Bug fix: request.headers['token'] raised on a missing header (yielding
    # a 400) instead of the intended 401; .get() handles absence cleanly.
    supplied = request.headers.get('token')
    # compare_digest gives a constant-time comparison for the secret; the
    # TOKEN guard also refuses access when the env var is unset.
    if TOKEN and supplied and hmac.compare_digest(supplied, TOKEN):
        update_repository()
        return jsonify({'success': True})
    return "Error", 401
|
|
|
|
|
def getAllData():
    """Return every stored palette as a list of plain dicts.

    Each entry has ``id``, the JSON-decoded ``data`` payload, and the
    ``created_at`` timestamp from the palettes table.
    """
    rows = get_db().execute("SELECT * FROM palettes").fetchall()
    return [
        {
            'id': row['id'],
            'data': json.loads(row['data']),
            'created_at': row['created_at'],
        }
        for row in rows
    ]
|
|
|
|
|
@app.route('/data')
def getdata():
    """Expose the full palette collection as JSON."""
    payload = getAllData()
    return jsonify(payload)
|
|
|
|
|
@app.route('/new_palette', methods=['POST'])
@expects_json(schema)
def create():
    """Persist a schema-validated palette and return the full collection.

    ``g.data`` holds the JSON body already validated by @expects_json; it
    is stored verbatim as a JSON string in the ``data`` column.
    """
    payload = json.dumps(g.data)
    connection = get_db()
    connection.cursor().execute(
        "INSERT INTO palettes(data) VALUES (?)", [payload])
    connection.commit()
    return jsonify(getAllData())
|
|
|
|
|
@app.errorhandler(400)
def bad_request(error):
    """Turn jsonschema validation failures into a JSON error response.

    Other 400s pass through unchanged.
    """
    description = error.description
    if not isinstance(description, ValidationError):
        return error
    return jsonify({'error': description.message}), 400
|
|
|
|
|
if __name__ == '__main__':
    # In production, sync the dataset repo every hour in the background.
    if IS_DEV:
        print("Not Starting scheduler -- Running Development")
    else:
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        scheduler.add_job(id='Update Dataset Repository',
                          func=update_repository, trigger='interval', hours=1)
        scheduler.start()
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=True, use_reloader=IS_DEV)
|
|