radames HF staff commited on
Commit
1d06c9e
β€’
1 Parent(s): e7c98f4
app.py CHANGED
@@ -12,10 +12,10 @@ import sqlite3
12
  import subprocess
13
  from jsonschema import ValidationError
14
 
15
- MODE = os.environ.get('FLASK_ENV', 'production')
16
- IS_DEV = MODE == 'development'
17
- app = Flask(__name__, static_url_path='/static')
18
- app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
19
 
20
  schema = {
21
  "type": "object",
@@ -30,29 +30,28 @@ schema = {
30
  "properties": {
31
  "colors": {
32
  "type": "array",
33
- "items": {
34
- "type": "string"
35
- },
36
  "maxItems": 5,
37
- "minItems": 5
38
  },
39
- "imgURL": {"type": "string"}}
40
- }
41
- }
 
42
  },
43
  "minProperties": 2,
44
- "maxProperties": 2
45
  }
46
 
47
  CORS(app)
48
 
49
  DB_FILE = Path("./data.db")
50
- TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN')
51
  repo = Repository(
52
  local_dir="data",
53
  repo_type="dataset",
54
  clone_from="huggingface-projects/color-palettes-sd",
55
- use_auth_token=TOKEN
56
  )
57
  repo.git_pull()
58
  # copy db on db to local path
@@ -66,12 +65,13 @@ try:
66
  db.close()
67
  except sqlite3.OperationalError:
68
  db.execute(
69
- 'CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)')
 
70
  db.commit()
71
 
72
 
73
  def get_db():
74
- db = getattr(g, '_database', None)
75
  if db is None:
76
  db = g._database = sqlite3.connect(DB_FILE)
77
  db.row_factory = sqlite3.Row
@@ -80,7 +80,7 @@ def get_db():
80
 
81
  @app.teardown_appcontext
82
  def close_connection(exception):
83
- db = getattr(g, '_database', None)
84
  if db is not None:
85
  db.close()
86
 
@@ -93,45 +93,60 @@ def update_repository():
93
  with sqlite3.connect("./data/data.db") as db:
94
  db.row_factory = sqlite3.Row
95
  palettes = db.execute("SELECT * FROM palettes").fetchall()
96
- data = [{'id': row['id'], 'data': json.loads(
97
- row['data']), 'created_at': row['created_at']} for row in palettes]
 
 
 
 
 
 
98
 
99
- with open('./data/data.json', 'w') as f:
100
- json.dump(data, f, separators=(',', ':'))
101
 
102
  print("Updating repository")
103
  subprocess.Popen(
104
- "git add . && git commit --amend -m 'update' && git push --force", cwd="./data", shell=True)
 
 
 
105
  repo.push_to_hub(blocking=False)
106
 
107
 
108
- @app.route('/')
109
  def index():
110
- return app.send_static_file('index.html')
111
 
112
 
113
- @app.route('/force_push')
114
  def push():
115
- if (request.headers['token'] == TOKEN):
116
  update_repository()
117
- return jsonify({'success': True})
118
  else:
119
  return "Error", 401
120
 
121
 
122
  def getAllData():
123
  palettes = get_db().execute("SELECT * FROM palettes").fetchall()
124
- data = [{'id': row['id'], 'data': json.loads(
125
- row['data']), 'created_at': row['created_at']} for row in palettes]
 
 
 
 
 
 
126
  return data
127
 
128
 
129
- @app.route('/data')
130
  def getdata():
131
  return jsonify(getAllData())
132
 
133
 
134
- @app.route('/new_palette', methods=['POST'])
135
  @expects_json(schema)
136
  def create():
137
  data = g.data
@@ -146,18 +161,26 @@ def create():
146
  def bad_request(error):
147
  if isinstance(error.description, ValidationError):
148
  original_error = error.description
149
- return jsonify({'error': original_error.message}), 400
150
  return error
151
 
152
 
153
- if __name__ == '__main__':
154
  if not IS_DEV:
155
  print("Starting scheduler -- Running Production")
156
  scheduler = APScheduler()
157
- scheduler.add_job(id='Update Dataset Repository',
158
- func=update_repository, trigger='interval', hours=1)
 
 
 
 
159
  scheduler.start()
160
  else:
161
  print("Not Starting scheduler -- Running Development")
162
- app.run(host='0.0.0.0', port=int(
163
- os.environ.get('PORT', 7860)), debug=True, use_reloader=IS_DEV)
 
 
 
 
 
12
  import subprocess
13
  from jsonschema import ValidationError
14
 
15
+ MODE = os.environ.get("FLASK_ENV", "production")
16
+ IS_DEV = MODE == "development"
17
+ app = Flask(__name__, static_url_path="/static")
18
+ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = False
19
 
20
  schema = {
21
  "type": "object",
 
30
  "properties": {
31
  "colors": {
32
  "type": "array",
33
+ "items": {"type": "string"},
 
 
34
  "maxItems": 5,
35
+ "minItems": 5,
36
  },
37
+ "imgURL": {"type": "string"},
38
+ },
39
+ },
40
+ },
41
  },
42
  "minProperties": 2,
43
+ "maxProperties": 2,
44
  }
45
 
46
  CORS(app)
47
 
48
  DB_FILE = Path("./data.db")
49
+ TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")
50
  repo = Repository(
51
  local_dir="data",
52
  repo_type="dataset",
53
  clone_from="huggingface-projects/color-palettes-sd",
54
+ use_auth_token=TOKEN,
55
  )
56
  repo.git_pull()
57
  # copy db on db to local path
 
65
  db.close()
66
  except sqlite3.OperationalError:
67
  db.execute(
68
+ "CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)"
69
+ )
70
  db.commit()
71
 
72
 
73
  def get_db():
74
+ db = getattr(g, "_database", None)
75
  if db is None:
76
  db = g._database = sqlite3.connect(DB_FILE)
77
  db.row_factory = sqlite3.Row
 
80
 
81
  @app.teardown_appcontext
82
  def close_connection(exception):
83
+ db = getattr(g, "_database", None)
84
  if db is not None:
85
  db.close()
86
 
 
93
  with sqlite3.connect("./data/data.db") as db:
94
  db.row_factory = sqlite3.Row
95
  palettes = db.execute("SELECT * FROM palettes").fetchall()
96
+ data = [
97
+ {
98
+ "id": row["id"],
99
+ "data": json.loads(row["data"]),
100
+ "created_at": row["created_at"],
101
+ }
102
+ for row in palettes
103
+ ]
104
 
105
+ with open("./data/data.json", "w") as f:
106
+ json.dump(data, f, separators=(",", ":"))
107
 
108
  print("Updating repository")
109
  subprocess.Popen(
110
+ "git add . && git commit --amend -m 'update' && git push --force",
111
+ cwd="./data",
112
+ shell=True,
113
+ )
114
  repo.push_to_hub(blocking=False)
115
 
116
 
117
+ @app.route("/")
118
  def index():
119
+ return app.send_static_file("index.html")
120
 
121
 
122
+ @app.route("/force_push")
123
  def push():
124
+ if request.headers["token"] == TOKEN:
125
  update_repository()
126
+ return jsonify({"success": True})
127
  else:
128
  return "Error", 401
129
 
130
 
131
  def getAllData():
132
  palettes = get_db().execute("SELECT * FROM palettes").fetchall()
133
+ data = [
134
+ {
135
+ "id": row["id"],
136
+ "data": json.loads(row["data"]),
137
+ "created_at": row["created_at"],
138
+ }
139
+ for row in palettes
140
+ ]
141
  return data
142
 
143
 
144
+ @app.route("/data")
145
  def getdata():
146
  return jsonify(getAllData())
147
 
148
 
149
+ @app.route("/new_palette", methods=["POST"])
150
  @expects_json(schema)
151
  def create():
152
  data = g.data
 
161
  def bad_request(error):
162
  if isinstance(error.description, ValidationError):
163
  original_error = error.description
164
+ return jsonify({"error": original_error.message}), 400
165
  return error
166
 
167
 
168
+ if __name__ == "__main__":
169
  if not IS_DEV:
170
  print("Starting scheduler -- Running Production")
171
  scheduler = APScheduler()
172
+ scheduler.add_job(
173
+ id="Update Dataset Repository",
174
+ func=update_repository,
175
+ trigger="interval",
176
+ hours=1,
177
+ )
178
  scheduler.start()
179
  else:
180
  print("Not Starting scheduler -- Running Development")
181
+ app.run(
182
+ host="0.0.0.0",
183
+ port=int(os.environ.get("PORT", 7860)),
184
+ debug=True,
185
+ use_reloader=IS_DEV,
186
+ )
frontend/src/lib/Palette.svelte CHANGED
@@ -12,7 +12,17 @@
12
 
13
  $: prompt = promptData?.prompt;
14
  $: colors = promptData?.images[seletecdImage]?.colors.map((e) => d3.rgb(e)) || [];
15
- $: imageSrc = promptData?.images[seletecdImage]?.imgURL;
 
 
 
 
 
 
 
 
 
 
16
  let isCopying = false;
17
 
18
  async function copyStringToClipboard(text: string) {
 
12
 
13
  $: prompt = promptData?.prompt;
14
  $: colors = promptData?.images[seletecdImage]?.colors.map((e) => d3.rgb(e)) || [];
15
+ $: imageSrc = fixLink(promptData?.images[seletecdImage]?.imgURL);
16
+
17
+ function fixLink(link: string) {
18
+ if (link.includes('s3.amazonaws.com')) {
19
+ return link.replace(
20
+ 's3.amazonaws.com/moonup/production/uploads/noauth',
21
+ 'cdn-uploads.huggingface.co/production/uploads/noauth'
22
+ );
23
+ }
24
+ return link;
25
+ }
26
  let isCopying = false;
27
 
28
  async function copyStringToClipboard(text: string) {
frontend/src/routes/+page.svelte CHANGED
@@ -3,7 +3,7 @@
3
  import type { ColorsPrompt, ColorsImage } from '$lib/types';
4
  import { randomSeed, extractPalette, uploadImage } from '$lib/utils';
5
  import { isLoading, loadingState } from '$lib/store';
6
- import { PUBLIC_WS_ENDPOINT, PUBLIC_API} from '$env/static/public';
7
  import Pallette from '$lib/Palette.svelte';
8
  import ArrowRight from '$lib/ArrowRight.svelte';
9
  import ArrowLeft from '$lib/ArrowLeft.svelte';
@@ -81,12 +81,12 @@
81
  const sessionHash = crypto.randomUUID();
82
 
83
  const hashpayload = {
84
- fn_index: 2,
85
  session_hash: sessionHash
86
  };
87
 
88
  const datapayload = {
89
- data: [_prompt]
90
  };
91
 
92
  const websocket = new WebSocket(PUBLIC_WS_ENDPOINT);
 
3
  import type { ColorsPrompt, ColorsImage } from '$lib/types';
4
  import { randomSeed, extractPalette, uploadImage } from '$lib/utils';
5
  import { isLoading, loadingState } from '$lib/store';
6
+ import { PUBLIC_WS_ENDPOINT, PUBLIC_API } from '$env/static/public';
7
  import Pallette from '$lib/Palette.svelte';
8
  import ArrowRight from '$lib/ArrowRight.svelte';
9
  import ArrowLeft from '$lib/ArrowLeft.svelte';
 
81
  const sessionHash = crypto.randomUUID();
82
 
83
  const hashpayload = {
84
+ fn_index: 3,
85
  session_hash: sessionHash
86
  };
87
 
88
  const datapayload = {
89
+ data: [_prompt, '', 9]
90
  };
91
 
92
  const websocket = new WebSocket(PUBLIC_WS_ENDPOINT);