File size: 3,727 Bytes
47cfe3f
 
 
 
 
 
 
e42e2d6
47cfe3f
e42e2d6
cd4aa38
fc59b67
e42e2d6
 
47cfe3f
 
 
 
 
 
e42e2d6
cd4aa38
e42e2d6
 
 
 
 
 
 
 
cd4aa38
 
e42e2d6
fc59b67
e42e2d6
 
cd4aa38
47cfe3f
 
e42e2d6
 
 
 
 
 
 
 
47cfe3f
 
 
 
 
 
fc59b67
 
 
47cfe3f
 
 
 
 
 
 
 
e42e2d6
cd4aa38
 
 
 
 
e42e2d6
 
 
 
 
 
 
 
 
 
 
47cfe3f
 
 
 
fc59b67
cd4aa38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213e233
47cfe3f
 
 
 
 
e42e2d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd1b7e7
c7d9f1b
cd1b7e7
 
213e233
 
cd1b7e7
 
213e233
cd1b7e7
e42e2d6
213e233
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
from datasets import load_dataset
import json
import uuid
from pathlib import Path
import json
from datasets import load_dataset
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_apscheduler import APScheduler
import shutil
from PIL import Image
import sqlite3
from huggingface_hub import Repository

# Flask app; static assets (index.html, data.json, thumbnails) are served
# from the static folder under the /static URL prefix.
app = Flask(__name__, static_url_path='/static')

# Enable CORS on the app (flask_cors defaults; NOTE(review): confirm the
# allowed origins are acceptable for this deployment).
CORS(app)

# Hub auth token; None when the `dataset_token` env var is unset. Also used
# as the shared secret checked by /force_push.
TOKEN = os.environ.get('dataset_token')

# Local working copy of the guesses database that /prompt writes to. The
# copy inside the cloned repo (./data/prompts.db) is what gets pushed.
DB_FILE = Path("./prompts.db")

# Clone (or reuse) the guesses dataset repo into ./data and pull the latest.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/wordalle_guesses",
    use_auth_token=TOKEN
)
repo.git_pull()
# Seed the local working database from the copy in the freshly pulled repo.
shutil.copyfile("./data/prompts.db", DB_FILE)

# Prompt/image dataset used to build the game's thumbnails and index.
dataset = load_dataset(
    "huggingface-projects/wordalle_prompts",
    use_auth_token=TOKEN)

# Output directory for the generated thumbnails.
Path("static/images").mkdir(parents=True, exist_ok=True)

# Ensure the local working database has the `prompts` table.
# IF NOT EXISTS makes this idempotent without a probe query, and the
# try/finally guarantees the connection is closed on every path (a
# module-level connection would otherwise stay open for the process lifetime).
db = sqlite3.connect(DB_FILE)
try:
    db.execute('CREATE TABLE IF NOT EXISTS prompts (guess TEXT, correct TEXT)')
    db.commit()
finally:
    db.close()

def _export_thumbnails(split, out_dir='static/images', size=(136, 136)):
    """Save a JPEG thumbnail for every row of *split* and group them by prompt.

    Args:
        split: A dataset split (called here with ``dataset['train']``) whose
            rows carry an ``image`` and an integer ``label``; the ``label``
            feature must provide ``int2str`` to recover the prompt text.
        out_dir: Existing directory the thumbnails are written into.
        size: ``(width, height)`` of each thumbnail.

    Returns:
        A dict mapping each prompt string to the list of thumbnail paths
        (as strings) generated for it, in row order.
    """
    # The label feature is the same for every row, so look it up once.
    label_feature = split.features['label']
    images_by_prompt = {}
    for row in split:
        prompt = label_feature.int2str(row['label'])
        image = row['image']
        # JPEG cannot store alpha/palette modes; save() would raise on them.
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        # uuid4 hex gives each thumbnail a unique file name.
        image_file = Path(out_dir) / f'{uuid.uuid4().hex}.jpg'
        thumbnail = image.resize(size, Image.Resampling.LANCZOS)
        thumbnail.save(image_file, optimize=True, quality=95)
        images_by_prompt.setdefault(prompt, []).append(str(image_file))
    return images_by_prompt


# Extract images and prompts from the dataset, save thumbnails to disk and
# write the prompt -> thumbnail-paths index that is served at /data.
data = _export_thumbnails(dataset['train'])

with open('static/data.json', 'w', encoding='utf-8') as f:
    json.dump(data, f)


def update_repository():
    """Sync the local prompts DB into the dataset clone and push it to the Hub.

    Pulls ./data, copies DB_FILE over ./data/prompts.db, dumps every row of
    the ``prompts`` table to ./data/data.json as compact JSON (a list of
    column-name -> value dicts), then starts a non-blocking push.
    """
    repo.git_pull()
    # Overwrite the repo's copy of the DB with the local working copy.
    # NOTE(review): this runs on the scheduler thread while /prompt may be
    # writing DB_FILE; a raw file copy could capture a half-written state —
    # consider sqlite3.Connection.backup for a consistent snapshot.
    shutil.copyfile(DB_FILE, "./data/prompts.db")

    # `with sqlite3.connect(...)` only commits/rolls back, it never closes
    # the connection, so close it explicitly instead of relying on GC.
    db = sqlite3.connect("./data/prompts.db")
    try:
        db.row_factory = sqlite3.Row
        rows = [dict(row) for row in db.execute("SELECT * FROM prompts")]
    finally:
        db.close()

    with open('./data/data.json', 'w', encoding='utf-8') as f:
        json.dump(rows, f, separators=(',', ':'))

    print("Updating repository")
    repo.push_to_hub(blocking=False)


@app.route('/')
def index():
    """Serve the front-end entry page, index.html, from the static folder."""
    entry_page = 'index.html'
    return app.send_static_file(entry_page)


@app.route('/force_push')
def push():
    """Force-push the local prompts database to the dataset repository.

    The request must carry a ``token`` header equal to TOKEN (the
    ``dataset_token`` environment variable). When authorised, DB_FILE is
    copied into the clone at ./data, then ``git add``/``commit``/``push
    --force`` are run there on a best-effort basis (exit codes are ignored,
    as before).

    Returns:
        ("Success", 200) when authorised, otherwise ("Error", 401).
    """
    import hmac
    import subprocess

    supplied = request.headers.get('token')
    # Reject when no server token is configured: with .get() a request that
    # omits the header would otherwise compare None == None and pass.
    # compare_digest avoids leaking the token through comparison timing.
    if (TOKEN is None or supplied is None
            or not hmac.compare_digest(supplied.encode(), TOKEN.encode())):
        return "Error", 401

    print("Force Push repository")
    shutil.copyfile(DB_FILE, "./data/prompts.db")
    # Run git inside the clone via cwd= rather than os.chdir(): chdir changes
    # the working directory of every thread, so concurrent requests and the
    # scheduler would resolve their relative paths inside ./data meanwhile.
    for command in (
        ["git", "add", "."],
        ["git", "commit", "-m", "force push"],
        ["git", "push", "--force"],
    ):
        subprocess.run(command, cwd="./data", check=False)
    return "Success", 200


@app.route('/data')
def getdata():
    """Serve data.json (the prompt -> thumbnail-paths index) from the static folder."""
    index_file = 'data.json'
    return app.send_static_file(index_file)


@app.route('/prompt', methods=['POST', 'GET'])
def create():
    """Record one guess/answer pair submitted as a JSON POST body.

    Expected body: ``{"guess": "<text>", "correct": "<text>"}``; both fields
    must be strings (the table columns are TEXT).

    Returns:
        ('OK', 200) on success; ('Missing guess or correct', 400) when the
        body is not a JSON object or either field is missing/not a string;
        ('Method not allowed', 405) for GET, which the route accepts but
        which has no handler.
    """
    if request.method != 'POST':
        return 'Method not allowed', 405

    # silent=True returns None for a missing or malformed JSON body instead
    # of raising, so it is reported as a 400 below.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return 'Missing guess or correct', 400
    guess = payload.get('guess')
    correct = payload.get('correct')
    if not isinstance(guess, str) or not isinstance(correct, str):
        return 'Missing guess or correct', 400

    # Database errors are no longer disguised as client errors: they
    # propagate and surface as a 500.
    db = sqlite3.connect(DB_FILE)
    try:
        with db:  # commits on success, rolls back on error
            db.execute(
                'INSERT INTO prompts (guess, correct) VALUES (?, ?)', (guess, correct))
    finally:
        # The connection's context manager does not close it.
        db.close()
    return 'OK', 200


if __name__ == '__main__':
    mode = os.environ.get('FLASK_ENV', 'production')
    print(mode)
    dev = mode == 'development'
    if not dev:
        # Outside development, push the local DB to the Hub every 300 s.
        scheduler = APScheduler()
        scheduler.add_job(id='Update Dataset Repository',
                          func=update_repository, trigger='interval', seconds=300)
        scheduler.start()
    # Enable debug mode (and its interactive Werkzeug debugger, which can
    # execute arbitrary Python) only in development: the server binds to
    # 0.0.0.0, so it must not be on in production.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)),
            debug=dev, use_reloader=dev)