b
File size: 2,863 Bytes
6437aa7
e05fc54
 
 
 
 
 
 
f9b29bc
e05fc54
f9b29bc
6437aa7
 
 
 
 
be673e5
 
 
 
 
 
 
f9b29bc
 
e05fc54
f9b29bc
 
 
 
 
 
e05fc54
f9b29bc
 
 
 
 
 
be673e5
 
f9b29bc
 
 
be673e5
f9b29bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cdcff4
be673e5
f9b29bc
9cdcff4
f9b29bc
 
 
 
 
 
 
 
 
 
 
9cdcff4
 
f9b29bc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os

# Install Flask only when it is not importable yet. pip is invoked through the
# running interpreter (sys.executable -m pip) so the package is installed into
# the same environment that imports it below, instead of whatever `pip`
# happens to be first on PATH.
try:
    import flask  # noqa: F401  (probe only; the real Flask imports follow)
    return_code = 0
except ImportError:
    import importlib
    import subprocess
    import sys
    return_code = subprocess.call([sys.executable, '-m', 'pip', 'install', 'flask'])
    if return_code != 0:
        raise RuntimeError("Failed to install Flask")
    # Finders cache directory listings; refresh them so the freshly installed
    # package can be imported in this same process.
    importlib.invalidate_caches()

import gradio as gr
from random import randint
from all_models import models
from flask import Flask, request, send_file
from io import BytesIO
from PIL import Image

# WSGI application object; the routes defined in this module are registered on it.
app = Flask(__name__)

def load_model(model_name):
    """Return a callable Gradio interface for the hosted model *model_name*.

    The model is loaded from the ``models/{model_name}`` source. If loading
    fails for any reason, the error is printed and a placeholder interface
    (text in, image out) whose function always returns ``None`` is returned
    instead of raising.
    """
    try:
        # Gradio 4 removed ``gr.Interface.load`` in favour of ``gr.load``;
        # calling the old name there raises and every model silently became
        # the placeholder. Fall back to the old API only when ``gr.load`` is
        # unavailable (older Gradio releases).
        loader = getattr(gr, 'load', None) or gr.Interface.load
        return loader(f'models/{model_name}')
    except Exception as error:
        print(f"Error loading model {model_name}: {error}")
        return gr.Interface(lambda txt: None, ['text'], ['image'])

# Fixed number of model/image slots that the helpers below pad choices up to.
num_models = 6
# Default selection: the leading `num_models` entries of the project model list.
default_models = models[0:num_models]

def extend_choices(choices):
    """Return a new list: *choices* padded with ``'NA'`` up to ``num_models`` items.

    Lists that already have ``num_models`` or more entries come back as an
    equal copy; nothing is truncated.
    """
    shortfall = max(0, num_models - len(choices))
    return choices + ['NA'] * shortfall

def update_imgbox(choices):
    """Build one empty ``gr.Image`` per slot, labelled with the slot's model name.

    *choices* is first padded with ``'NA'`` via ``extend_choices``; the
    padding slots are created hidden.
    """
    def _slot(label):
        # Placeholder slots ('NA') exist but are not shown.
        return gr.Image(None, label=label, visible=(label != 'NA'))

    return [_slot(label) for label in extend_choices(choices)]

def gen_fn(model_str, prompt, negative_prompt=None):
    """Generate one image for *prompt* with the model named *model_str*.

    Args:
        model_str: Model name passed to ``load_model``. The placeholder
            ``'NA'`` returns ``None`` immediately without loading anything.
        prompt: Text prompt sent to the model.
        negative_prompt: Optional text appended to the prompt as
            ``" -<negative_prompt>"``. NOTE(review): this is ordinary prompt
            text, not a dedicated negative-prompt input; confirm the target
            models interpret it as intended.

    Returns:
        A fully loaded ``PIL.Image.Image``, or ``None`` when generation fails
        or the model output is neither an existing file path nor a PIL image.
        Errors are printed, never raised.
    """
    if model_str == 'NA':
        return None
    noise = ''  # Optional random suffix: str(randint(0, 99999999999))
    try:
        model = load_model(model_str)
        full_prompt = f'{prompt} {noise}'
        if negative_prompt:
            full_prompt += f' -{negative_prompt}'
        result = model(full_prompt)
        if isinstance(result, str):  # the model may return a path to an image file
            if not os.path.exists(result):
                print(f"File path not found: {result}")
                return None
            # Image.open is lazy and keeps the file handle open until the pixel
            # data is read. copy() loads the data into memory, so the with-block
            # can close the file deterministically instead of leaking the handle.
            with Image.open(result) as img:
                return img.copy()
        if isinstance(result, Image.Image):
            return result
        print("Result is not an image:", type(result))
        return None
    except Exception as e:
        # Best-effort boundary: any failure becomes None for the HTTP layer.
        print("Error generating image:", e)
        return None

@app.route('/', methods=['GET'])
def home():
    """Generate an image from the query string and return it as a PNG.

    Query parameters:
        prompt: Text prompt (required).
        model: Model name. Defaults to the first entry of ``default_models``,
            or ``'NA'`` when that list is empty.
        Nprompt: Optional negative prompt forwarded to ``gen_fn``.

    Returns:
        The PNG image (200). 400 when the model is empty or the ``'NA'``
        placeholder, or when no prompt is given. 500 when generation fails.
    """
    prompt = request.args.get('prompt', '')
    model = request.args.get('model', default_models[0] if default_models else 'NA')
    negative_prompt = request.args.get('Nprompt', None)

    # 'NA' is the empty-slot placeholder, not a real model. gen_fn returns None
    # for it, which previously surfaced as a misleading 500 server error.
    # NOTE(review): any other model name is forwarded to load_model unchecked,
    # so clients can make the server load arbitrary hosted models; consider
    # restricting to the known `models` list.
    if not model or model == 'NA':
        return f'Invalid model: {model}', 400

    if not prompt:
        return 'Please provide a "prompt" query parameter in the URL.', 400

    image = gen_fn(model, prompt, negative_prompt)
    if not isinstance(image, Image.Image):
        return 'Failed to generate image.', 500

    # Encode to PNG in memory and stream it back inline (not as a download).
    img_io = BytesIO()
    image.save(img_io, format='PNG')
    img_io.seek(0)
    return send_file(img_io, mimetype='image/png', as_attachment=False)

if __name__ == '__main__':
    # Start Flask's built-in server, listening on every interface at port 7860.
    app.run(port=7860, host='0.0.0.0')