File size: 1,254 Bytes
696db06
d798c5a
2eb1363
0250d76
 
8647971
2eb1363
967efaf
 
0250d76
 
 
 
 
 
 
 
ee6e9e2
efbaaff
8647971
029e32c
b9f4a2a
2eb1363
 
 
 
efbaaff
2eb1363
fac22d0
fce8087
7fb8c56
696db06
fce8087
fac22d0
 
b9f4a2a
 
 
696db06
 
 
efbaaff
14c2d1c
696db06
14c2d1c
696db06
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from flask import url_for
from diffusers import StableDiffusionPipeline
import torch
from fastapi import FastAPI, Response  
from fastapi.middleware.cors import CORSMiddleware
from auth_token import auth_token

app = FastAPI()

# Wide-open CORS policy: browsers may call this API from any origin, using any
# HTTP method and any request header, and may include credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)


model_id = "runwayml/stable-diffusion-v1-5"

# Load the pretrained pipeline for ``model_id`` (authenticating with the
# project's ``auth_token``), place it on the CPU, and turn on attention
# slicing, which diffusers offers as a way to lower memory use.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=auth_token,
).to("cpu")
pipe.enable_attention_slicing()

def dummy(images, **kwargs):
    """Stand-in safety checker that never flags any image.

    It is installed as the pipeline's ``safety_checker``. It leaves the images
    untouched and reports every one of them as safe.

    Args:
        images: the batch of generated images; must support ``len()``.
        **kwargs: any extra arguments the pipeline passes; they are ignored.

    Returns:
        A tuple ``(images, flags)``. ``images`` is the input object unchanged,
        and ``flags`` is a list holding one ``False`` per image.
    """
    # Return a per-image list rather than a bare ``False``: the diffusers
    # pipeline iterates over this value (one NSFW flag per image), and
    # iterating a bool raises TypeError.
    return images, [False] * len(images)

# Replace the pipeline's safety checker with the no-op ``dummy`` so that
# generated images are passed through unchanged and never reported as NSFW.
pipe.safety_checker = dummy

@app.get("/")
def hello():
    """Answer requests to the service root with a fixed greeting string."""
    greeting = "Hello, I'm Linlada"
    return greeting


@app.get("/gen/{prompt}")
def generate_image(prompt: str):
    """Generate an image for *prompt*, save it, and return it to the caller.

    The first image produced by the pipeline is written to
    ``static/image.png``, overwriting any previous result.

    Args:
        prompt: the text prompt, taken from the URL path.

    Returns:
        A dict with two keys:

        - ``image_data``: hex string of ``image.tobytes()``. These are the raw
          pixel bytes only; no width, height, mode or file-format header is
          included.
        - ``image_url``: the site-relative path ``/static/image.png`` of the
          saved file.
    """
    import os  # local import keeps this fix self-contained

    image = pipe(
        prompt,
        guidance_scale=8.5,  # how strictly to follow the prompt
    ).images[0]

    # Saving into a missing directory fails, so make sure it exists first.
    os.makedirs("static", exist_ok=True)
    # NOTE(review): every request writes the same file, so concurrent requests
    # can overwrite each other's output.
    image.save("static/image.png")

    image_data = image.tobytes().hex()
    # flask.url_for needs an active Flask application context, which never
    # exists inside a FastAPI handler (it raises RuntimeError there). Return
    # the path Flask's default static route would have produced instead.
    # NOTE(review): nothing visible here mounts ``static/`` on the app, so this
    # URL is only reachable if that directory is served elsewhere.
    image_url = "/static/image.png"

    return {"image_data": image_data, "image_url": image_url}