colt12 commited on
Commit
a31ed9b
·
verified ·
1 Parent(s): a1add26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -40
app.py CHANGED
@@ -1,50 +1,44 @@
1
- python
2
- import io
3
- import base64
 
4
  import torch
5
- from PIL import Image
6
- from flask import Flask, request, jsonify
7
- from diffusers import StableDiffusionPipeline
8
 
9
- app = Flask(__name__)
10
 
11
  # Load the model
12
  model_name = "colt12/maxcushion"
13
- try:
14
- print("Loading model...")
15
- pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
16
- pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
17
- print("Model loaded successfully.")
18
- except Exception as e:
19
- print(f"Error loading model: {str(e)}")
20
- raise
21
-
22
- def generate_image(prompt):
23
- with torch.no_grad():
24
- image = pipe(prompt).images[0]
25
-
26
- buffered = io.BytesIO()
27
- image.save(buffered, format="PNG")
28
- image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
29
-
30
- return image_base64
31
 
32
- @app.route('/', methods=['GET'])
33
- def home():
34
- return "Welcome to the Image Generation API. Use the /generate endpoint to generate images from prompts."
 
 
35
 
36
- @app.route('/generate', methods=['POST'])
37
- def run():
38
- if 'prompt' not in request.json:
39
- return jsonify({"error": "No prompt provided"}), 400
40
-
41
- prompt = request.json['prompt']
42
-
43
  try:
44
- result = generate_image(prompt)
45
- return jsonify({"image": result})
 
 
 
 
 
 
 
 
 
 
 
 
46
  except Exception as e:
47
- return jsonify({"error": str(e)}), 500
48
 
49
- if __name__ == "__main__":
50
- app.run(host='0.0.0.0', port=5000)
 
 
1
+ from typing import Dict, List
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ from diffusers import StableDiffusionXLPipeline
5
  import torch
6
+ import base64
7
+ from io import BytesIO
 
8
 
9
+ app = FastAPI()
10
 
11
  # Load the model
12
  model_name = "colt12/maxcushion"
13
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_name, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
14
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ class Item(BaseModel):
17
+ prompt: str
18
+ negative_prompt: str = ""
19
+ num_inference_steps: int = 30
20
+ guidance_scale: float = 7.5
21
 
22
+ @app.post("/generate")
23
+ async def generate(item: Item) -> Dict[str, str]:
 
 
 
 
 
24
  try:
25
+ # Generate the image
26
+ image = pipe(
27
+ prompt=item.prompt,
28
+ negative_prompt=item.negative_prompt,
29
+ num_inference_steps=item.num_inference_steps,
30
+ guidance_scale=item.guidance_scale
31
+ ).images[0]
32
+
33
+ # Convert to base64
34
+ buffered = BytesIO()
35
+ image.save(buffered, format="PNG")
36
+ image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
37
+
38
+ return {"image": image_base64}
39
  except Exception as e:
40
+ raise HTTPException(status_code=500, detail=str(e))
41
 
42
+ @app.get("/")
43
+ async def root():
44
+ return {"message": "SDXL Image Generation API"}