radames's picture
radames HF staff
fixes
1d06c9e
import hmac
import json
import os
import shutil
import sqlite3
import subprocess
from pathlib import Path

from flask import Flask, request, jsonify, g
from flask_apscheduler import APScheduler
from flask_cors import CORS
from flask_expects_json import expects_json
from huggingface_hub import Repository
from jsonschema import ValidationError
from PIL import Image
# Runtime mode comes from FLASK_ENV; any value other than "development"
# (including an unset variable) is treated as production.
MODE = os.getenv("FLASK_ENV", "production")
IS_DEV = (MODE == "development")

app = Flask(__name__, static_url_path="/static")
# jsonify() output is not pretty-printed.
app.config.update(JSONIFY_PRETTYPRINT_REGULAR=False)
# JSON body accepted by POST /new_palette: exactly a "prompt" string and an
# "images" array whose items each hold exactly five colour strings ("colors")
# and an image URL string ("imgURL").
# min/maxProperties on their own only bound the number of keys, so without
# "required" any two arbitrary keys (e.g. {"a": 1, "b": 2}) would validate.
schema = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "images": {
            "type": "array",
            "items": {
                "type": "object",
                "minProperties": 2,
                "maxProperties": 2,
                "properties": {
                    "colors": {
                        "type": "array",
                        "items": {"type": "string"},
                        "maxItems": 5,
                        "minItems": 5,
                    },
                    "imgURL": {"type": "string"},
                },
                "required": ["colors", "imgURL"],
            },
        },
    },
    "minProperties": 2,
    "maxProperties": 2,
    "required": ["prompt", "images"],
}
CORS(app)

# Local working copy of the palette database; get_db() connects to this file.
DB_FILE = Path("./data.db")
TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Check out the dataset repository (into ./data) that persists the palettes.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/color-palettes-sd",
    use_auth_token=TOKEN,
)
repo.git_pull()
# copy the db from the repo checkout to the local working path
shutil.copyfile("./data/data.db", DB_FILE)

# Make sure the palettes table exists: if querying it fails (e.g. the table
# is missing from the copied db), create it.
db = sqlite3.connect(DB_FILE)
try:
    data = db.execute("SELECT * FROM palettes").fetchall()
    if IS_DEV:
        print(f"Loaded {len(data)} palettes from local db")
except sqlite3.OperationalError:
    db.execute(
        "CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)"
    )
    db.commit()
finally:
    # Close on both the load path and the table-creation path.
    db.close()
def get_db():
    """Return the SQLite connection cached on ``flask.g``, opening it if needed.

    The connection is stored as ``g._database`` so repeated calls within one
    app context share it. Rows are returned as ``sqlite3.Row`` so columns
    can be read by name.
    """
    if getattr(g, "_database", None) is None:
        g._database = sqlite3.connect(DB_FILE)
    conn = g._database
    conn.row_factory = sqlite3.Row
    return conn
@app.teardown_appcontext
def close_connection(exception):
    """Close the connection that get_db cached on ``g``, if one was opened.

    ``exception`` is passed in by Flask's teardown hook and is not used here.
    """
    conn = getattr(g, "_database", None)
    if conn is None:
        return
    conn.close()
def update_repository():
repo.git_pull()
# copy db on db to local path
shutil.copyfile(DB_FILE, "./data/data.db")
with sqlite3.connect("./data/data.db") as db:
db.row_factory = sqlite3.Row
palettes = db.execute("SELECT * FROM palettes").fetchall()
data = [
{
"id": row["id"],
"data": json.loads(row["data"]),
"created_at": row["created_at"],
}
for row in palettes
]
with open("./data/data.json", "w") as f:
json.dump(data, f, separators=(",", ":"))
print("Updating repository")
subprocess.Popen(
"git add . && git commit --amend -m 'update' && git push --force",
cwd="./data",
shell=True,
)
repo.push_to_hub(blocking=False)
@app.route("/")
def index():
    """Serve the front-end entry page from the app's static folder."""
    page = "index.html"
    return app.send_static_file(page)
@app.route("/force_push")
def push():
    """Run update_repository() on demand, guarded by the Hub token.

    The request must carry a ``token`` header equal to
    ``HUGGING_FACE_HUB_TOKEN``. On a match this returns ``{"success": true}``
    after the sync. Otherwise, including when the header is missing or no
    token is configured, it returns ``"Error"`` with status 401.
    """
    supplied = request.headers.get("token")
    # Refuse when no server token is configured or none was sent. Compare in
    # constant time so the secret can't be probed via response timing.
    if (
        TOKEN
        and supplied is not None
        and hmac.compare_digest(supplied.encode(), TOKEN.encode())
    ):
        update_repository()
        return jsonify({"success": True})
    return "Error", 401
def getAllData():
    """Return every stored palette as a list of plain dicts.

    Each entry carries ``id``, ``data`` (the stored JSON column decoded back
    into Python objects) and ``created_at``, in the order SQLite returns the
    rows.
    """
    rows = get_db().execute("SELECT * FROM palettes").fetchall()
    result = []
    for record in rows:
        result.append(
            {
                "id": record["id"],
                "data": json.loads(record["data"]),
                "created_at": record["created_at"],
            }
        )
    return result
@app.route("/data")
def getdata():
    """Serve all stored palettes as a JSON array."""
    palettes = getAllData()
    return jsonify(palettes)
@app.route("/new_palette", methods=["POST"])
@expects_json(schema)
def create():
    """Store a submitted palette and return the full, updated palette list.

    The request body is checked against ``schema`` by ``expects_json`` and
    read from ``g.data``. It is serialized back to a JSON string for the
    ``data`` column.
    """
    payload = json.dumps(g.data)
    conn = get_db()
    conn.cursor().execute("INSERT INTO palettes(data) VALUES (?)", [payload])
    conn.commit()
    return jsonify(getAllData())
@app.errorhandler(400)
def bad_request(error):
    """Render schema-validation 400s as JSON; pass any other 400 through.

    When the error's description is a jsonschema ``ValidationError``, the
    response is ``{"error": <validation message>}`` with status 400.
    Otherwise the original error object is returned unchanged.
    """
    description = error.description
    if not isinstance(description, ValidationError):
        return error
    return jsonify({"error": description.message}), 400
if __name__ == "__main__":
    if not IS_DEV:
        # Production: sync the dataset repository every hour.
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        scheduler.add_job(
            id="Update Dataset Repository",
            func=update_repository,
            trigger="interval",
            hours=1,
        )
        scheduler.start()
    else:
        print("Not Starting scheduler -- Running Development")
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        # Only enable debug mode (Werkzeug interactive debugger) in
        # development. With host 0.0.0.0 it would otherwise be exposed to
        # the network in production.
        debug=IS_DEV,
        use_reloader=IS_DEV,
    )