remiai3 committed
Commit c6498ec · verified · 1 Parent(s): 3eabe33

Update app.py

Files changed (1):
  app.py +92 -84
app.py CHANGED
@@ -1,6 +1,6 @@
 from flask import Flask, request, jsonify
 from flask_cors import CORS
-from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 import torch
 import os
 from PIL import Image
@@ -9,106 +9,114 @@ import time
 from accelerate import Accelerator
 import logging
 
-app = Flask(__name__)
+# Disable GPU detection
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+torch.set_default_device("cpu")
+
+app = Flask(__name__, static_folder='static')
 CORS(app)
 
-# Configure logging
+# Configure logging
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-# Initialize Accelerator for CPU
-accelerator = Accelerator(cpu=True)
+# Initialize Accelerator for CPU
+accelerator = Accelerator(device_placement=False)
 
-# Model cache
+# Model cache
 model_cache = {}
 model_paths = {
-    "ssd-1b": "remiai3/ssd-1b",
-    "sd-v1-5": "remiai3/stable-diffusion-v1-5"
-}
+    "ssd-1b": "remiai3/ssd-1b",
+    "sd-v1-5": "remiai3/stable-diffusion-v1-5"
+}
 
-# Image ratio to dimensions (optimized for CPU)
+# Image ratio to dimensions (optimized for CPU)
 ratio_to_dims = {
-    "1:1": (256, 256),
-    "3:4": (192, 256),
-    "16:9": (256, 144)
-}
+    "1:1": (256, 256),
+    "3:4": (192, 256),
+    "16:9": (256, 144)
+}
 
 def load_model(model_id):
-    if model_id not in model_cache:
-        logger.info(f"Loading model {model_id}...")
-        try:
-            pipe = StableDiffusionPipeline.from_pretrained(
-                model_paths[model_id],
-                torch_dtype=torch.float32,
-                use_auth_token=os.getenv("HF_TOKEN"),
-                use_safetensors=True,
-                low_cpu_mem_usage=True
-            )
-            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-            pipe = accelerator.prepare(pipe)
-            pipe.enable_attention_slicing()
-            pipe.enable_sequential_cpu_offload()
-            model_cache[model_id] = pipe
-            logger.info(f"Model {model_id} loaded successfully")
-        except Exception as e:
-            logger.error(f"Error loading model {model_id}: {str(e)}")
-            raise
-    return model_cache[model_id]
+    if model_id not in model_cache:
+        logger.info(f"Loading model {model_id}...")
+        try:
+            pipe = StableDiffusionPipeline.from_pretrained(
+                model_paths[model_id],
+                torch_dtype=torch.float32,
+                use_auth_token=os.getenv("HF_TOKEN"),
+                use_safetensors=True,
+                low_cpu_mem_usage=True,
+                device_map="cpu"
+            )
+            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+            pipe = accelerator.prepare(pipe)
+            pipe.enable_attention_slicing()
+            pipe.enable_sequential_cpu_offload()
+            pipe.to("cpu")
+            model_cache[model_id] = pipe
+            logger.info(f"Model {model_id} loaded successfully")
+        except Exception as e:
+            logger.error(f"Error loading model {model_id}: {str(e)}")
+            raise
+    return model_cache[model_id]
 
 @app.route('/')
 def index():
-    return app.send_static_file('index.html')
+    return app.send_static_file('index.html')
 
 @app.route('/generate', methods=['POST'])
 def generate():
-    try:
-        data = request.json
-        model_id = data.get('model', 'ssd-1b')
-        prompt = data.get('prompt', '')
-        ratio = data.get('ratio', '1:1')
-        num_images = min(int(data.get('num_images', 1)), 4)
-        guidance_scale = float(data.get('guidance_scale', 7.5))
-
-        if not prompt:
-            return jsonify({"error": "Prompt is required"}), 400
-
-        if model_id == 'ssd-1b' and num_images > 1:
-            return jsonify({"error": "SSD-1B allows only 1 image per generation"}), 400
-        if model_id == 'ssd-1b' and ratio != '1:1':
-            return jsonify({"error": "SSD-1B supports only 1:1 ratio"}), 400
-        if model_id == 'sd-v1-5' and len(prompt.split()) > 77:
-            return jsonify({"error": "Prompt exceeds 77 tokens for Stable Diffusion v1.5"}), 400
-
-        width, height = ratio_to_dims.get(ratio, (256, 256))
-        pipe = load_model(model_id)
-
-        images = []
-        for _ in range(num_images):
-            image = pipe(
-                prompt=prompt,
-                height=height,
-                width=width,
-                num_inference_steps=20,
-                guidance_scale=guidance_scale
-            ).images[0]
-            images.append(image)
-
-        output_dir = "outputs"
-        os.makedirs(output_dir, exist_ok=True)
-        image_urls = []
-        for i, img in enumerate(images):
-            img_path = os.path.join(output_dir, f"generated_{int(time.time())}_{i}.png")
-            img.save(img_path)
-            with open(img_path, "rb") as f:
-                img_data = base64.b64encode(f.read()).decode('utf-8')
-            image_urls.append(f"data:image/png;base64,{img_data}")
-            os.remove(img_path)
-
-        return jsonify({"images": image_urls})
-
-    except Exception as e:
-        logger.error(f"Image generation failed: {str(e)}")
-        return jsonify({"error": f"Image generation failed: {str(e)}"}), 500
+    try:
+        data = request.json
+        model_id = data.get('model', 'ssd-1b')
+        prompt = data.get('prompt', '')
+        ratio = data.get('ratio', '1:1')
+        num_images = min(int(data.get('num_images', 1)), 4)
+        guidance_scale = float(data.get('guidance_scale', 7.5))
+
+        if not prompt:
+            return jsonify({"error": "Prompt is required"}), 400
+
+        if model_id == 'ssd-1b' and num_images > 1:
+            return jsonify({"error": "SSD-1B allows only 1 image per generation"}), 400
+        if model_id == 'ssd-1b' and ratio != '1:1':
+            return jsonify({"error": "SSD-1B supports only 1:1 ratio"}), 400
+        if model_id == 'sd-v1-5' and len(prompt.split()) > 77:
+            return jsonify({"error": "Prompt exceeds 77 tokens for Stable Diffusion v1.5"}), 400
+
+        width, height = ratio_to_dims.get(ratio, (256, 256))
+        pipe = load_model(model_id)
+        pipe.to("cpu")
+
+        images = []
+        num_inference_steps = 20 if model_id == 'ssd-1b' else 30
+        for _ in range(num_images):
+            image = pipe(
+                prompt=prompt,
+                height=height,
+                width=width,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale
+            ).images[0]
+            images.append(image)
+
+        output_dir = "outputs"
+        os.makedirs(output_dir, exist_ok=True)
+        image_urls = []
+        for i, img in enumerate(images):
+            img_path = os.path.join(output_dir, f"generated_{int(time.time())}_{i}.png")
+            img.save(img_path)
+            with open(img_path, "rb") as f:
+                img_data = base64.b64encode(f.read()).decode('utf-8')
+            image_urls.append(f"data:image/png;base64,{img_data}")
+            os.remove(img_path)
+
+        return jsonify({"images": image_urls})
+
+    except Exception as e:
+        logger.error(f"Image generation failed: {str(e)}")
+        return jsonify({"error": f"Image generation failed: {str(e)}"}), 500
 
 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=7860)
+    app.run(host='0.0.0.0', port=7860)
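
For reference, a minimal client for the /generate endpoint this commit modifies. This is a sketch, not part of the repository: it assumes the app is reachable at http://localhost:7860 (the address app.run binds) and that the requests package is installed; the field names and the data-URL response shape come straight from the handler above.

    import base64
    import requests

    payload = {
        "model": "sd-v1-5",       # "ssd-1b" is limited to one image and a 1:1 ratio
        "prompt": "a watercolor lighthouse at dusk",
        "ratio": "3:4",           # one of "1:1", "3:4", "16:9"
        "num_images": 2,          # the server caps this at 4
        "guidance_scale": 7.5,
    }

    resp = requests.post("http://localhost:7860/generate", json=payload, timeout=600)
    resp.raise_for_status()

    # The handler returns data URLs of the form "data:image/png;base64,<data>".
    for i, data_url in enumerate(resp.json()["images"]):
        b64 = data_url.split(",", 1)[1]   # drop the "data:image/png;base64," prefix
        with open(f"image_{i}.png", "wb") as f:
            f.write(base64.b64decode(b64))

Expect multi-minute latencies on CPU-only hardware; the generous timeout is deliberate.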
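
One design note on the handler: it saves each PIL image to outputs/, re-reads the file to base64-encode it, then deletes it. An in-memory buffer would avoid the disk round-trip. A possible alternative, shown as a sketch (image_to_data_url is a hypothetical helper, not a function in this repo):

    import base64
    import io

    def image_to_data_url(img) -> str:
        # PIL's Image.save accepts any file object, so a BytesIO buffer
        # stands in for the temporary PNG file on disk.
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{b64}"

This would also avoid the filename collision that is possible when two requests reach generate() within the same second, since the current paths are keyed only by int(time.time()) and the image index.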