thefreeham commited on
Commit
5a6f45f
1 Parent(s): c652d97

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import os
4
+ from pathlib import Path
5
+ from io import BytesIO
6
+ import time
7
+
8
+ from flask import Flask, request, jsonify
9
+ from flask_cors import CORS, cross_origin
10
+ from consts import IMAGES_OUTPUT_DIR
11
+ from utils import parse_arg_boolean, parse_arg_dalle_version
12
+ from consts import ModelSize
13
+
14
+ app = Flask(__name__)
15
+ CORS(app)
16
+ print("--> Starting DALL-E Server. This might take up to two minutes.")
17
+
18
+ from dalle_model import DalleModel
19
+ dalle_model = None
20
+
21
+ parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
22
+ parser.add_argument("--port", type=int, default=8000, help = "backend port")
23
+ parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
24
+ parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
25
+ args = parser.parse_args()
26
+
27
+ @app.route("/dalle", methods=["POST"])
28
+ @cross_origin()
29
+ def generate_images_api():
30
+ json_data = request.get_json(force=True)
31
+ text_prompt = json_data["text"]
32
+ num_images = json_data["num_images"]
33
+ generated_imgs = dalle_model.generate_images(text_prompt, num_images)
34
+
35
+ generated_images = []
36
+ if args.save_to_disk:
37
+ dir_name = os.path.join(IMAGES_OUTPUT_DIR,f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
38
+ Path(dir_name).mkdir(parents=True, exist_ok=True)
39
+
40
+ for idx, img in enumerate(generated_imgs):
41
+ if args.save_to_disk:
42
+ img.save(os.path.join(dir_name, f'{idx}.jpeg'), format="JPEG")
43
+
44
+ buffered = BytesIO()
45
+ img.save(buffered, format="JPEG")
46
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
47
+ generated_images.append(img_str)
48
+
49
+ print(f"Created {num_images} images from text prompt [{text_prompt}]")
50
+ return jsonify(generated_images)
51
+
52
+
53
+ @app.route("/", methods=["GET"])
54
+ @cross_origin()
55
+ def health_check():
56
+ return jsonify(success=True)
57
+
58
+
59
+ with app.app_context():
60
+ dalle_model = DalleModel(args.model_version)
61
+ dalle_model.generate_images("warm-up", 1)
62
+ print("--> DALL-E Server is up and running!")
63
+ print(f"--> Model selected - DALL-E {args.model_version}")
64
+
65
+
66
+ if __name__ == "__main__":
67
+ app.run(host="0.0.0.0", port=args.port, debug=False)