# testaus3 / app.py (author: ismot)
# Duplicated from huggingface-projects/color-palette-generator-sd at commit 4e0a699.
import hmac
import json
import os
import shutil
import sqlite3
import subprocess
from contextlib import closing
from pathlib import Path

from flask import Flask, request, jsonify, g
from flask_apscheduler import APScheduler
from flask_cors import CORS
from flask_expects_json import expects_json
from huggingface_hub import Repository
from jsonschema import ValidationError
from PIL import Image

# Runtime mode, taken from FLASK_ENV; anything other than "development"
# is treated as production.
MODE = os.getenv('FLASK_ENV', 'production')
IS_DEV = (MODE == 'development')
# The Flask application; static files are served under /static.
app = Flask(__name__, static_url_path='/static')
# Ask jsonify() not to indent/pretty-print regular responses.
app.config.update(JSONIFY_PRETTYPRINT_REGULAR=False)
# JSON Schema for the body of POST /new_palette:
#   {"prompt": <str>, "images": [{"colors": [5 x <str>], "imgURL": <str>}, ...]}
# min/maxProperties alone only cap the key *count*, so any two keys would
# pass; "required" pins them to the keys described under "properties".
schema = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "images": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "colors": {
                        "type": "array",
                        "items": {"type": "string"},
                        "minItems": 5,
                        "maxItems": 5,
                    },
                    "imgURL": {"type": "string"},
                },
                "required": ["colors", "imgURL"],
                "minProperties": 2,
                "maxProperties": 2,
            },
        },
    },
    "required": ["prompt", "images"],
    "minProperties": 2,
    "maxProperties": 2,
}
# Enable flask_cors on the app with its default settings.
CORS(app)

# Working copy of the SQLite database that the request handlers use.
DB_FILE = Path("./data.db")
# Hugging Face Hub token: used to authenticate the dataset repo and checked
# against the `token` header by /force_push.
TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')

# Local checkout (./data) of the dataset repository that stores the palettes.
repo = Repository(
    local_dir="data",
    clone_from="huggingface-projects/color-palettes-sd",
    repo_type="dataset",
    use_auth_token=TOKEN,
)
repo.git_pull()
# Copy the repository's database (./data/data.db) into the working path
# DB_FILE that the request handlers read and write.
shutil.copyfile("./data/data.db", DB_FILE)
db = sqlite3.connect(DB_FILE)
try:
    data = db.execute("SELECT * FROM palettes").fetchall()
    if IS_DEV:
        print(f"Loaded {len(data)} palettes from local db")
except sqlite3.OperationalError:
    # Presumably the palettes table does not exist yet (fresh database):
    # create it.
    db.execute(
        'CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)')
    db.commit()
finally:
    # Close on both paths; previously the create-table path leaked the
    # connection.
    db.close()
def get_db():
    """Return the SQLite connection cached on ``flask.g``.

    The first call within an app context opens a connection to DB_FILE,
    configures it to return ``sqlite3.Row`` rows, and caches it as
    ``g._database``. Later calls reuse that cached connection.
    """
    conn = getattr(g, '_database', None)
    if conn is not None:
        return conn
    conn = sqlite3.connect(DB_FILE)
    g._database = conn
    conn.row_factory = sqlite3.Row
    return conn
@app.teardown_appcontext
def close_connection(exception):
    """Close the connection cached by get_db() when the app context ends.

    ``exception`` is supplied by Flask's teardown machinery and is unused.
    """
    conn = getattr(g, '_database', None)
    if conn is None:
        return
    conn.close()
def update_repository():
    """Publish the working database to the dataset repository.

    The function runs these steps in order:

    1. Pull the repo.
    2. Copy the working DB_FILE over ./data/data.db.
    3. Export every palette row to ./data/data.json.
    4. Fold the change into the previous commit with ``git commit --amend``
       and force-push it.
    5. Call ``repo.push_to_hub``.
    """
    repo.git_pull()
    # Copy the working database into the repository checkout (the reverse
    # direction of the copy done at startup).
    shutil.copyfile(DB_FILE, "./data/data.db")
    # sqlite3's connection context manager only commits/rolls back; it does
    # not close. closing() actually releases the connection.
    with closing(sqlite3.connect("./data/data.db")) as db:
        db.row_factory = sqlite3.Row
        palettes = db.execute("SELECT * FROM palettes").fetchall()
    data = [{'id': row['id'], 'data': json.loads(
        row['data']), 'created_at': row['created_at']} for row in palettes]
    with open('./data/data.json', 'w') as f:
        json.dump(data, f, separators=(',', ':'))
    print("Updating repository")
    # Wait for the amend + force-push to finish before push_to_hub runs.
    # Launching it in the background (Popen) let both operate on the same
    # working tree/index at the same time.
    subprocess.run(
        "git add . && git commit --amend -m 'update' && git push --force", cwd="./data", shell=True)
    repo.push_to_hub(blocking=False)
@app.route('/')
def index():
    """Serve index.html from the app's static folder."""
    page = 'index.html'
    return app.send_static_file(page)
@app.route('/force_push')
def push():
    """Run update_repository() on demand, guarded by the hub token.

    The caller must send TOKEN in a ``token`` request header.

    Returns:
        ``{"success": true}`` after the update runs, or ``"Error"`` with
        HTTP 401 when the header is missing, wrong, or no TOKEN is
        configured.
    """
    supplied = request.headers.get('token')
    # Refuse outright when no token is configured or none was sent. Without
    # this check, an empty TOKEN would match an empty header. The comparison
    # is constant-time over bytes, so it works for any header text and does
    # not leak how much of the token matched.
    if TOKEN and supplied and hmac.compare_digest(
            supplied.encode('utf-8'), TOKEN.encode('utf-8')):
        update_repository()
        return jsonify({'success': True})
    return "Error", 401
def getAllData():
    """Return every row of the palettes table as a list of plain dicts.

    Each dict has ``id``, ``data`` (the stored JSON text decoded with
    ``json.loads``) and ``created_at``.
    """
    rows = get_db().execute("SELECT * FROM palettes").fetchall()
    result = []
    for record in rows:
        result.append({
            'id': record['id'],
            'data': json.loads(record['data']),
            'created_at': record['created_at'],
        })
    return result
@app.route('/data')
def getdata():
    """Return all stored palettes as a JSON array."""
    everything = getAllData()
    return jsonify(everything)
@app.route('/new_palette', methods=['POST'])
@expects_json(schema)
def create():
    """Insert a submitted palette and return the full palette list.

    The request body is checked against ``schema`` by ``expects_json`` and
    read back from ``g.data``. It is stored as JSON text in the ``data``
    column.
    """
    payload = json.dumps(g.data)
    conn = get_db()
    conn.cursor().execute("INSERT INTO palettes(data) VALUES (?)", [payload])
    conn.commit()
    return jsonify(getAllData())
@app.errorhandler(400)
def bad_request(error):
    """Report JSON-schema validation failures as ``{"error": <message>}``.

    Any other 400 error is returned unchanged.
    """
    cause = error.description
    if not isinstance(cause, ValidationError):
        return error
    return jsonify({'error': cause.message}), 400
if __name__ == '__main__':
    # In production, schedule an hourly push of the database to the dataset
    # repository. Development runs skip the scheduler.
    if not IS_DEV:
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        scheduler.add_job(id='Update Dataset Repository',
                          func=update_repository, trigger='interval', hours=1)
        scheduler.start()
    else:
        print("Not Starting scheduler -- Running Development")
    # Debug mode (interactive debugger, tracebacks in responses) and the
    # reloader are development-only. Hard-coding debug=True exposed the
    # debugger in production and also overrode the compact-JSON setting.
    app.run(host='0.0.0.0', port=int(
        os.environ.get('PORT', 7860)), debug=IS_DEV, use_reloader=IS_DEV)