# Test2 / app.py
# thefreeham's picture
# Update app.py
# c82c2e4
# raw history blame
# No virus
# 2.48 kB
import argparse
import base64
import os
from pathlib import Path
from io import BytesIO
import time
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
from consts import IMAGES_OUTPUT_DIR
from utils import parse_arg_boolean, parse_arg_dalle_version
from consts import ModelSize
import gradio as gr
def greet(name):
    """Return the greeting ``Hello <name>!!`` for the given name string."""
    prefix, suffix = "Hello ", "!!"
    return prefix + name + suffix
# Demo Gradio UI: a single text box wired to greet().
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
# By default launch() blocks the main thread when run as a script, which would
# keep the rest of this module (Flask setup, model loading, app.run) from
# executing until the Gradio server is stopped. prevent_thread_lock=True makes
# launch() return immediately while Gradio keeps serving in the background.
iface.launch(prevent_thread_lock=True)
# Flask application, with flask_cors applied to it.
app = Flask(__name__)
CORS(app)

print("--> Starting DALL-E Server. This might take up to two minutes.")

# Deliberately imported only after the startup banner has been printed.
from dalle_model import DalleModel

# Placeholder; rebound to a DalleModel instance further down in this module.
dalle_model = None

# Command-line options; parsed once, at module import time.
parser = argparse.ArgumentParser(
    description="A DALL-E app to turn your textual prompts into visionary delights",
)
parser.add_argument(
    "--port",
    type=int,
    default=8000,
    help="backend port",
)
parser.add_argument(
    "--model_version",
    type=parse_arg_dalle_version,
    default=ModelSize.MINI,
    help="Mini, Mega, or Mega_full",
)
parser.add_argument(
    "--save_to_disk",
    type=parse_arg_boolean,
    default=False,
    help="Should save generated images to disk",
)
args = parser.parse_args()
@app.route("/dalle", methods=["POST"])
@cross_origin()
def generate_images_api():
    """Generate images for a text prompt and return them as base64 JPEG strings.

    Expects a JSON object body with:
        text: the prompt string handed to ``dalle_model.generate_images``.
        num_images: how many images to generate (anything ``int()`` accepts).

    Returns:
        A JSON list of base64-encoded JPEG images, or an HTTP 400 response
        with an ``error`` message when the body is missing or malformed.

    When ``--save_to_disk`` is enabled, every image is also written to a
    per-request directory under ``IMAGES_OUTPUT_DIR``.
    """
    json_data = request.get_json(force=True)
    if not isinstance(json_data, dict):
        return jsonify(error="Request body must be a JSON object"), 400

    text_prompt = json_data.get("text")
    if not isinstance(text_prompt, str):
        return jsonify(error="'text' must be a string"), 400

    try:
        num_images = int(json_data["num_images"])
    except (KeyError, TypeError, ValueError):
        return jsonify(error="'num_images' must be an integer"), 400
    if num_images < 0:
        return jsonify(error="'num_images' must not be negative"), 400

    generated_imgs = dalle_model.generate_images(text_prompt, num_images)

    dir_name = None
    if args.save_to_disk:
        # The prompt is untrusted input: keep only characters that are safe in
        # a single path component (no separators, no dots) and cap the length,
        # so it cannot escape IMAGES_OUTPUT_DIR or exceed file-name limits.
        safe_prompt = "".join(
            ch if ch.isalnum() or ch in " -_" else "_" for ch in text_prompt
        )[:100]
        # Use '-' instead of ':' in the time part; ':' is not allowed in
        # Windows file names.
        timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
        dir_name = os.path.join(IMAGES_OUTPUT_DIR, f"{timestamp}_{safe_prompt}")
        Path(dir_name).mkdir(parents=True, exist_ok=True)

    generated_images = []
    for idx, img in enumerate(generated_imgs):
        # Encode once and reuse the bytes for both the disk copy and the
        # response. NOTE(review): assumes img exposes a PIL-style
        # save(fp, format=...) — confirm against DalleModel.
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        jpeg_bytes = buffered.getvalue()
        if dir_name is not None:
            Path(dir_name, f"{idx}.jpeg").write_bytes(jpeg_bytes)
        generated_images.append(base64.b64encode(jpeg_bytes).decode("utf-8"))

    print(f"Created {num_images} images from text prompt [{text_prompt}]")
    return jsonify(generated_images)
@app.route("/", methods=["GET"])
@cross_origin()
def health_check():
    """Health-check endpoint: responds with a JSON object whose ``success`` is true."""
    payload = {"success": True}
    return jsonify(payload)
# Build the model once at startup, inside a Flask application context, and
# run a single throwaway generation before announcing readiness.
with app.app_context():
    dalle_model = DalleModel(args.model_version)
    dalle_model.generate_images("warm-up", 1)
    print("--> DALL-E Server is up and running!")
    print(f"--> Model selected - DALL-E {args.model_version}")

# Start the Flask server only when this file is executed as a script.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=args.port, debug=False)