File size: 1,254 Bytes
696db06 d798c5a 2eb1363 0250d76 8647971 2eb1363 967efaf 0250d76 ee6e9e2 efbaaff 8647971 029e32c b9f4a2a 2eb1363 efbaaff 2eb1363 fac22d0 fce8087 7fb8c56 696db06 fce8087 fac22d0 b9f4a2a 696db06 efbaaff 14c2d1c 696db06 14c2d1c 696db06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import os

import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from flask import url_for

from auth_token import auth_token
app = FastAPI()

# CORS policy applied to every route: any origin, method and header is
# accepted, and credentials (cookies / auth headers) are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; no route in this file reads cookies or auth headers, so confirm
# whether credentials are actually needed.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Stable Diffusion v1.5 checkpoint, loaded with the token from auth_token and
# kept on the CPU (no GPU is used anywhere in this module).
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=auth_token,
).to("cpu")
# Compute attention in slices: lower peak memory in exchange for some speed.
pipe.enable_attention_slicing()
def dummy(images, **kwargs):
    """No-op replacement for the pipeline's safety checker.

    Hands every image back unchanged and reports none of them as flagged.

    Args:
        images: The batch of decoded images; only ``len(images)`` is used, so
            a list, NumPy array or tensor whose first axis is the batch works.
        **kwargs: Any extra arguments the pipeline passes (presumably
            ``clip_input``); ignored.

    Returns:
        ``(images, flags)`` where ``images`` is the unchanged input and
        ``flags`` is a list with one ``False`` per image.
    """
    # The pipeline iterates the second element (one flag per image), so a bare
    # ``False`` raises "TypeError: 'bool' object is not iterable".
    return images, [False] * len(images)
# Swap the pipeline's safety checker for the no-op `dummy`, so generated images
# are returned exactly as decoded instead of being screened.
pipe.safety_checker = dummy
@app.get("/")
def hello():
    """Root endpoint: reply with a fixed greeting string."""
    greeting = "Hello, I'm Linlada"
    return greeting
@app.get("/gen/{prompt}")
def generate_image(prompt: str):
    """Generate one image for *prompt*, save it to disk and return it.

    Args:
        prompt: Text prompt taken from the URL path segment.

    Returns:
        A dict with:
          * ``image_data``: the image's raw pixel bytes (``tobytes()``),
            hex-encoded. There is no file header, size or mode, so a client
            must already know the dimensions to decode it.
          * ``image_url``: the URL path of the saved copy, ``/static/image.png``.
    """
    image = pipe(
        prompt,
        guidance_scale=8.5,  # how strictly to follow the prompt
    ).images[0]

    # Every request overwrites the same file. Create its folder first;
    # otherwise save() raises FileNotFoundError on a fresh checkout.
    os.makedirs("static", exist_ok=True)
    image.save('static/image.png')

    image_data = image.tobytes().hex()
    # Flask's url_for() needs a Flask application/request context. In a FastAPI
    # handler it raises RuntimeError, so every request failed after generation.
    # This is the path it would have produced for the default "static" endpoint.
    # NOTE(review): nothing in this module mounts /static, so the URL is only
    # reachable if static/ is served by something else — confirm deployment.
    image_url = '/static/image.png'
    return {'image_data': image_data, 'image_url': image_url}
|