File size: 2,963 Bytes
47cfe3f
 
 
 
 
 
 
e42e2d6
47cfe3f
e42e2d6
 
fc59b67
e42e2d6
 
47cfe3f
 
 
 
 
 
e42e2d6
 
 
 
 
 
 
 
 
 
 
fc59b67
e42e2d6
 
47cfe3f
 
e42e2d6
 
 
 
 
 
 
 
 
47cfe3f
 
 
 
 
 
fc59b67
 
 
47cfe3f
 
 
 
 
 
 
 
e42e2d6
 
 
 
 
 
 
 
 
 
 
 
 
 
47cfe3f
 
 
 
fc59b67
47cfe3f
 
 
 
 
e42e2d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7d9f1b
e42e2d6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from datasets import load_dataset
import json
import uuid
from pathlib import Path
import json
from datasets import load_dataset
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_apscheduler import APScheduler

from PIL import Image
import sqlite3
from huggingface_hub import Repository

# Flask application; static assets are exposed under the /static URL prefix.
app = Flask(__name__, static_url_path='/static')

# Enable cross-origin requests for the whole app (flask_cors defaults).
CORS(app)

# Hub access token, read from the `dataset_token` environment variable
# (None when the variable is not set).
TOKEN = os.environ.get('dataset_token')

# SQLite file holding submitted guesses. It lives inside ./data, the same
# directory used as `local_dir` for the dataset repository below.
DB_FILE = Path("./data/prompts.db")

# Local working copy of the guesses dataset repository, kept in ./data.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/wordalle_guesses",
    use_auth_token=TOKEN
)
# Bring the working copy up to date before the database inside it is opened.
repo.git_pull()

# Prompt/image dataset; its 'train' split is converted to thumbnails below.
dataset = load_dataset(
    "huggingface-projects/wordalle_prompts",
    use_auth_token=TOKEN)
# Make sure the thumbnail output directory exists before images are written.
Path("static/images").mkdir(parents=True, exist_ok=True)

# Open the guesses database once at startup to make sure the `prompts`
# table exists, printing its current contents when it does.
db = sqlite3.connect(DB_FILE)
try:
    data = db.execute("SELECT * FROM prompts").fetchall()
    print("DB DATA", data)
except sqlite3.OperationalError:
    # The SELECT failed, presumably because the table does not exist yet
    # (fresh database file), so create it.
    db.execute('CREATE TABLE prompts (guess TEXT, correct TEXT)')
    db.commit()
finally:
    # Release the connection on both paths; the request handlers and the
    # scheduled job each open their own connection.
    db.close()

# Extract images and prompts from the dataset and save them to disk.
# `data` maps each prompt string to the list of thumbnail paths written for it.
train_split = dataset['train']
# The label feature is the same for every row, so look it up once.
label_feature = train_split.features['label']
data = {}
for row in train_split:
    prompt = label_feature.int2str(row['label'])
    image = row['image']
    # Random file name (named `image_id` so the builtin `hash` stays usable).
    image_id = uuid.uuid4().hex
    image_file = Path(f'static/images/{image_id}.jpg')
    # Downscale to a 136x136 thumbnail before saving.
    image_compress = image.resize((136, 136), Image.Resampling.LANCZOS)
    image_compress.save(image_file, optimize=True, quality=95)
    data.setdefault(prompt, []).append(str(image_file))

# Index of prompt -> thumbnail paths, written into the static folder.
with open('static/data.json', 'w') as f:
    json.dump(data, f)


def update_repository():
    """Export all stored guesses to JSON and push the data repository.

    Reads every row of the ``prompts`` table in ``DB_FILE``, writes the rows
    as a compact JSON list of ``{column: value}`` objects to
    ``./data/data.json``, then pulls and pushes ``repo`` (the push is
    started with ``blocking=False``).
    """
    # A sqlite3 connection used as a context manager only commits or rolls
    # back; it does not close. This function runs on a schedule, so close
    # the connection explicitly instead of leaking one per run.
    db = sqlite3.connect(DB_FILE)
    try:
        # Row objects convert cleanly to dicts keyed by column name.
        db.row_factory = sqlite3.Row
        result = db.execute("SELECT * FROM prompts").fetchall()
        data = [dict(row) for row in result]
    finally:
        db.close()

    with open('./data/data.json', 'w') as f:
        json.dump(data, f, separators=(',', ':'))

    print("Updating repository")
    repo.git_pull()
    repo.push_to_hub(blocking=False)


@app.route('/')
def index():
    """Serve the front-end entry page, ``index.html``, as a static file."""
    landing_page = 'index.html'
    return app.send_static_file(landing_page)


@app.route('/data')
def getdata():
    """Serve ``data.json`` (the prompt -> image paths index) as a static file."""
    index_file = 'data.json'
    return app.send_static_file(index_file)


@app.route('/prompt', methods=['POST', 'GET'])
def create():
    """Store one submitted guess in the ``prompts`` table.

    Expects a POST whose JSON body is an object with ``guess`` and
    ``correct`` keys.

    Returns:
        ``('OK', 200)`` when the row was inserted.
        ``('Missing guess or correct', 400)`` when the body is not valid
        JSON or lacks either key.
        ``('Method not allowed', 405)`` for non-POST requests.
        ``('Database error', 500)`` when the insert fails.
    """
    # The route also accepts GET; answer explicitly instead of falling
    # through and returning None, which Flask reports as a server error.
    if request.method != 'POST':
        return 'Method not allowed', 405

    # silent=True yields None for an unparsable body instead of raising, so
    # malformed requests still get the same 400 response as missing keys.
    payload = request.get_json(silent=True)
    try:
        guess = payload['guess']
        correct = payload['correct']
    except (KeyError, TypeError):
        # KeyError: key absent; TypeError: body is None or not an object.
        return 'Missing guess or correct', 400

    try:
        db = sqlite3.connect(DB_FILE)
        try:
            # `with db` commits on success and rolls back on error.
            with db:
                db.execute(
                    'INSERT INTO prompts (guess, correct) VALUES (?, ?)',
                    (guess, correct))
        finally:
            # The context manager does not close the connection.
            db.close()
    except sqlite3.Error:
        # Storage failures are server-side problems, not a bad request.
        app.logger.exception('Failed to store guess')
        return 'Database error', 500

    return 'OK', 200

if __name__ == '__main__':
    # Export the guesses and push the data repository every 60 seconds.
    scheduler = APScheduler()
    scheduler.add_job(
        id='Update Dataset Repository',
        func=update_repository,
        trigger='interval',
        seconds=60,
    )
    scheduler.start()

    # NOTE(review): debug=True exposes the interactive Werkzeug debugger on
    # all interfaces (host 0.0.0.0); confirm this is intended for deployment.
    # The reloader is disabled because it re-executes this module in a child
    # process, which would start a second scheduler and push twice.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 7860)),
        debug=True,
        use_reloader=False,
    )