soiz commited on
Commit
ef72775
·
verified ·
1 Parent(s): e27808c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -80
app.py CHANGED
@@ -1,53 +1,80 @@
1
  from flask import Flask, request, jsonify, send_file, render_template_string, make_response
2
- from deep_translator import GoogleTranslator
3
- from PIL import Image
4
- import torch
5
- from diffusers import StableDiffusionPipeline
6
- import random
7
  import io
8
- import os
 
 
9
 
10
  app = Flask(__name__)
11
 
12
- MODEL_NAME = "Ojimi/anime-kawai-diffusion"
13
- MODEL_DIR = "./models/anime-kawai-diffusion" # Directory to store the model
14
 
15
- # Download and load the model at startup
16
- def load_model():
17
- if not os.path.exists(MODEL_DIR):
18
- print(f"Downloading the model {MODEL_NAME}...")
19
- pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
20
- pipeline.save_pretrained(MODEL_DIR)
21
- else:
22
- print(f"Loading model from {MODEL_DIR}...")
23
- pipeline = StableDiffusionPipeline.from_pretrained(MODEL_DIR, torch_dtype=torch.float16)
24
 
25
- if torch.cuda.is_available():
26
- pipeline.to("cuda")
27
- print("Model loaded on GPU")
28
- else:
29
- print("GPU not available. Running on CPU.")
30
 
31
- return pipeline
 
 
32
 
33
- # Load the model once during startup
34
- pipeline = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # HTML template for the index page
37
  index_html = """
38
  <!DOCTYPE html>
39
  <html lang="ja">
40
- <head>
41
- <title>Kawaii Diffusion</title>
42
- </head>
43
- <body>
44
- <h1>Kawaii Diffusion Image Generator</h1>
45
- <form action="/generate" method="get">
46
- <label for="prompt">Prompt:</label>
47
- <input type="text" id="prompt" name="prompt" required><br><br>
48
- <button type="submit">Generate Image</button>
49
- </form>
50
- </body>
51
  </html>
52
  """
53
 
@@ -55,51 +82,27 @@ index_html = """
55
  def index():
56
  return render_template_string(index_html)
57
 
58
- # Function to generate image locally
59
- def generate_image_locally(prompt, steps=35, cfg_scale=7, width=512, height=512, seed=-1):
60
- # Translate prompt from Russian to English
61
- prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
62
- print(f'Translated prompt: {prompt}')
63
-
64
- # Set a random seed if not provided
65
- generator = torch.manual_seed(seed if seed != -1 else random.randint(1, 1_000_000))
66
-
67
- # Generate the image using the loaded pipeline
68
- image = pipeline(prompt, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator).images[0]
69
- return image
70
-
71
  @app.route('/generate', methods=['GET'])
72
  def generate_image():
73
- try:
74
- prompt = request.args.get("prompt", "")
75
- steps = int(request.args.get("steps", 35))
76
- cfg_scale = float(request.args.get("cfgs", 7))
77
- width = int(request.args.get("width", 512))
78
- height = int(request.args.get("height", 512))
79
- seed = int(request.args.get("seed", -1))
80
-
81
- # Generate the image locally
82
- image = generate_image_locally(prompt, steps, cfg_scale, width, height, seed)
83
-
84
- # Save the image to a BytesIO object
85
- img_bytes = io.BytesIO()
86
- image.save(img_bytes, format='PNG')
87
- img_bytes.seek(0)
88
-
89
- return send_file(img_bytes, mimetype='image/png')
90
- except Exception as e:
91
- return jsonify({"error": str(e)}), 500
92
-
93
- # Content-Security-Policyヘッダーを設定する
94
- @app.after_request
95
- def add_security_headers(response):
96
- response.headers['Content-Security-Policy'] = (
97
- "default-src 'self'; "
98
- "img-src 'self' data:; "
99
- "style-src 'self' 'unsafe-inline'; "
100
- "script-src 'self' 'unsafe-inline'; "
101
- )
102
- return response
103
 
104
  if __name__ == "__main__":
105
- app.run(host='0.0.0.0', port=7860)
 
1
  from flask import Flask, request, jsonify, send_file, render_template_string, make_response
2
+ import requests
 
 
 
 
3
  import io
4
+ import random
5
+ from PIL import Image
6
+ from deep_translator import GoogleTranslator
7
 
8
  app = Flask(__name__)
9
 
10
+ API_URL = "https://api-inference.huggingface.co/models/Ojimi/anime-kawai-diffusion"
11
+ timeout = 3000 # タイムアウトを300秒に設定
12
 
13
+ # Function to query the API and return the generated image
14
+ def query(prompt, negative_prompt="", steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024):
15
+ if not prompt:
16
+ return None, "Prompt is required"
 
 
 
 
 
17
 
18
+ key = random.randint(0, 999)
 
 
 
 
19
 
20
+ # Translate the prompt from Russian to English if necessary
21
+ prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
22
+ print(f'Generation {key} translation: {prompt}')
23
 
24
+ # Add some extra flair to the prompt
25
+ prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
26
+ print(f'Generation {key}: {prompt}')
27
+
28
+ payload = {
29
+ "inputs": prompt,
30
+ "is_negative": False,
31
+ "steps": steps,
32
+ "cfg_scale": cfg_scale,
33
+ "seed": seed if seed != -1 else random.randint(1, 1000000000),
34
+ "strength": strength,
35
+ "parameters": {
36
+ "width": width,
37
+ "height": height
38
+ }
39
+ }
40
+
41
+ for attempt in range(3): # 最大3回の再試行
42
+ try:
43
+ # Authorization header is removed
44
+ response = requests.post(API_URL, json=payload, timeout=timeout)
45
+ if response.status_code != 200:
46
+ return None, f"Error: Failed to get image. Status code: {response.status_code}, Details: {response.text}"
47
+
48
+ image_bytes = response.content
49
+ image = Image.open(io.BytesIO(image_bytes))
50
+ return image, None
51
+ except requests.exceptions.Timeout:
52
+ if attempt < 2: # 最後の試行でない場合は再試行
53
+ print("Timeout occurred, retrying...")
54
+ continue
55
+ return None, "Error: The request timed out. Please try again."
56
+ except requests.exceptions.RequestException as e:
57
+ return None, f"Request Exception: {str(e)}"
58
+ except Exception as e:
59
+ return None, f"Error when trying to open the image: {e}"
60
+
61
+ # Content-Security-Policyヘッダーを設定するための関数
62
+ @app.after_request
63
+ def add_security_headers(response):
64
+ response.headers['Content-Security-Policy'] = (
65
+ "default-src 'self'; "
66
+ "connect-src 'self' ^https?:\/\/[\w.-]+\.[\w.-]+(\/[\w.-]*)*(\?[^\s]*)?$"
67
+ "img-src 'self' data:; "
68
+ "style-src 'self' 'unsafe-inline'; "
69
+ "script-src 'self' 'unsafe-inline'; "
70
+ )
71
+ return response
72
 
73
  # HTML template for the index page
74
  index_html = """
75
  <!DOCTYPE html>
76
  <html lang="ja">
77
+ kawai diffusion
 
 
 
 
 
 
 
 
 
 
78
  </html>
79
  """
80
 
 
82
  def index():
83
  return render_template_string(index_html)
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  @app.route('/generate', methods=['GET'])
86
  def generate_image():
87
+ prompt = request.args.get("prompt", "")
88
+ negative_prompt = request.args.get("negative_prompt", "")
89
+ steps = int(request.args.get("steps", 35))
90
+ cfg_scale = float(request.args.get("cfgs", 7))
91
+ sampler = request.args.get("sampler", "DPM++ 2M Karras")
92
+ strength = float(request.args.get("strength", 0.7))
93
+ seed = int(request.args.get("seed", -1))
94
+ width = int(request.args.get("width", 1024))
95
+ height = int(request.args.get("height", 1024))
96
+
97
+ image, error = query(prompt, negative_prompt, steps, cfg_scale, sampler, seed, strength, width, height)
98
+
99
+ if error:
100
+ return jsonify({"error": error}), 400
101
+
102
+ img_bytes = io.BytesIO()
103
+ image.save(img_bytes, format='PNG')
104
+ img_bytes.seek(0)
105
+ return send_file(img_bytes, mimetype='image/png')
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  if __name__ == "__main__":
108
+ app.run(host='0.0.0.0', port=7860)