File size: 2,582 Bytes
781570f
 
f48b8b0
781570f
 
 
5091b36
7f6bab7
d8e3b20
f48b8b0
80b3e90
 
 
 
 
 
e9cd1d1
7f6bab7
5a42448
781570f
5091b36
 
 
 
 
781570f
e9cd1d1
7f6bab7
 
 
 
781570f
 
 
 
 
 
5a42448
781570f
 
 
 
 
 
 
 
 
 
 
 
 
 
5091b36
 
 
 
 
 
 
 
 
 
 
 
 
 
781570f
 
5091b36
 
 
 
 
 
1ff8dbd
e9cd1d1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from diffusers import AutoPipelineForImage2Image
import torch
import os
import numpy as np
from PIL import Image
from diffusers.utils import load_image, make_image_grid
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import io

# Ask PyTorch's CUDA caching allocator to use expandable segments, which
# reduces memory fragmentation. The allocator reads this variable when it is
# first initialised, so it must be set before any tensor is placed on the GPU
# (here: before the pipeline's `.to("cuda")` below).
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Release cached-but-unused blocks held by PyTorch's CUDA allocator.
# NOTE(review): nothing in this module has touched the GPU yet at this point,
# so this call is most likely a no-op — confirm and consider removing.
torch.cuda.empty_cache()

app = Flask(__name__)
# Enable cross-origin requests so browser front-ends on other origins can call
# this API. NOTE(review): flask_cors defaults apply (all routes, any origin);
# restrict origins if that is broader than intended.
CORS(app)

print('loading models...')
# Load the image-to-image pipeline once, at import time, in float16 and move it
# to the GPU; every request handled by this process reuses this single `pipe`.
pipe = AutoPipelineForImage2Image.from_pretrained("RunDiffusion/Juggernaut-X-v10", torch_dtype=torch.float16).to("cuda")
# Memory-efficient attention via xformers. NOTE(review): requires the xformers
# package to be installed; presumably this raises at startup otherwise.
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_vae_tiling()  # Encode/decode in tiles to lower peak VRAM on large images
pipe.enable_vae_slicing()  # Decode batched images one at a time to lower peak VRAM
print('loaded models...')

@app.route('/')
def hello():
    """Liveness endpoint for the root path.

    Returns a fixed single-entry dict, which Flask serialises to a JSON
    response, so callers can check that the server is up.
    """
    payload = {"Goes Wrong": "Keeping it real"}
    return payload

def _parse_inference_params(data):
    """Validate a ``/run_inference`` JSON body and extract generation settings.

    Args:
        data: The decoded JSON request body, or ``None`` when the body was
            missing or not valid JSON.

    Returns:
        A dict with keys ``url`` (str), ``prompt``, ``negative_prompt``,
        ``guidance_scale`` (float), ``num_images`` (int) and ``seed`` (int).

    Raises:
        ValueError: If the body is not a JSON object, ``url`` is missing or is
            not an http(s) URL, or a numeric field is malformed or out of
            range. The message is intended to be returned to the client.
    """
    if not isinstance(data, dict):
        raise ValueError("Request body must be a JSON object")
    if 'url' not in data:
        raise ValueError("No imageurl provided")

    url = data['url']
    # load_image also accepts local file paths; only allow remote URLs so a
    # client cannot make the server read image files from its own disk.
    # NOTE(review): arbitrary http(s) URLs are still fetched server-side
    # (SSRF risk); consider a host allow-list if this service is public.
    if not isinstance(url, str) or not url.startswith(('http://', 'https://')):
        raise ValueError("url must be an http:// or https:// URL")

    try:
        guidance_scale = float(data.get('guidance_scale', 7))
        num_images = int(data.get('num_images', 2))
        # A fixed default seed keeps results reproducible across requests.
        seed = int(data.get('seed', 69))
    except (TypeError, ValueError) as err:
        raise ValueError(f"Invalid numeric parameter: {err}") from err

    if num_images < 1:
        raise ValueError("num_images must be at least 1")
    # Keep the seed inside the range torch.Generator.manual_seed accepts.
    if not 0 <= seed < 2 ** 64:
        raise ValueError("seed must be between 0 and 2**64 - 1")

    return {
        'url': url,
        'prompt': data.get('prompt', 'fleece hoodie, front zip, abstract pattern, GAP logo, high quality, photo'),
        'negative_prompt': data.get('negative_prompt', 'low quality, bad quality, sketches, hanger'),
        'guidance_scale': guidance_scale,
        'num_images': num_images,
        'seed': seed,
    }


@app.route('/run_inference', methods=['POST'])
def run_inference():
    """Generate images from a sketch URL and return them as one PNG grid.

    Expects a JSON object body with a required ``url`` (http/https) and the
    optional fields ``prompt``, ``negative_prompt``, ``guidance_scale``
    (default 7), ``num_images`` (default 2) and ``seed`` (default 69).

    Returns:
        An ``image/png`` response with the generated images laid out in a
        single row, or a JSON ``{"error": ...}`` body with status 400 when the
        request is invalid or the source image cannot be loaded.
    """
    # silent=True returns None (rejected by the parser with a 400) instead of
    # letting Flask raise its own error for a missing or non-JSON body.
    data = request.get_json(silent=True)
    try:
        params = _parse_inference_params(data)
    except ValueError as err:
        return jsonify({"error": str(err)}), 400

    url = params['url']
    try:
        sketch = load_image(url)
    except (ValueError, OSError) as err:
        # OSError covers network failures (requests' exceptions derive from
        # it) and downloaded payloads that PIL cannot decode.
        return jsonify({"error": f"Could not load image from url: {err}"}), 400
    print(f'Loaded image URL: {url}')

    # A dedicated, freshly seeded CPU generator affects only this call.
    # torch.manual_seed() would reseed the process-wide default generator,
    # which concurrent requests on the threaded server also draw from.
    generator = torch.Generator().manual_seed(params['seed'])

    with torch.inference_mode():
        images = pipe(
            prompt=params['prompt'],
            negative_prompt=params['negative_prompt'],
            image=sketch,
            num_inference_steps=35,
            guidance_scale=params['guidance_scale'],
            # img2img strength: higher values let the output drift further
            # from the input sketch.
            strength=0.5,
            generator=generator,
            num_images_per_prompt=params['num_images'],
        ).images

    grid = make_image_grid(images, rows=1, cols=params['num_images'])

    # Encode the grid in memory so no temporary file is written to disk.
    buffer = io.BytesIO()
    grid.save(buffer, format='PNG')
    buffer.seek(0)
    return send_file(buffer, mimetype='image/png')

if __name__ == '__main__':
    # The diffusion pipeline is loaded at import time. With debug=True,
    # Werkzeug's reloader would re-execute this module in a child process
    # while this parent process keeps its own loaded copy, putting the model
    # on the GPU twice. Keep the debugger but disable the reloader.
    # NOTE: the interactive debugger allows code execution; never expose this
    # development server beyond localhost.
    app.run(debug=True, use_reloader=False)