File size: 2,473 Bytes
6a908d9
b47d425
 
 
 
6a908d9
 
b47d425
 
 
 
6a908d9
 
b47d425
 
6a908d9
b47d425
 
 
 
 
 
 
 
 
6a908d9
b47d425
6a908d9
b47d425
 
 
 
 
6a908d9
 
 
b47d425
6a908d9
 
 
 
b47d425
6a908d9
c144445
 
313a5e9
 
6a908d9
 
 
 
c6e8517
2300903
c144445
c1a29f2
 
 
 
2300903
c6e8517
6a908d9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import base64
import logging
from io import BytesIO

import gradio as gr
import requests
from PIL import Image


def decode_base64_image(image_string):
    """Turn a base64-encoded image payload into a PIL image.

    Args:
        image_string: Base64 text (``str`` or ``bytes``) containing an
            encoded image file.

    Returns:
        The ``PIL.Image.Image`` opened from an in-memory buffer holding the
        decoded bytes.
    """
    raw_bytes = base64.b64decode(image_string)
    # The BytesIO is handed straight to Image.open, which keeps a reference
    # to it, so no separate local is needed to keep the buffer alive.
    return Image.open(BytesIO(raw_bytes))


logger = logging.getLogger(__name__)

# HTTP route that receives the generation request. The "endpoint" field in the
# payload names the model endpoint the route presumably forwards to — confirm
# against the API Gateway configuration.
API_URL = 'https://a02q342s5b.execute-api.us-east-2.amazonaws.com/reinvent-demo-inf2-sm-20231114'
MODEL_ENDPOINT = "huggingface-pytorch-inference-neuronx-2023-11-14-21-22-10-388"
# (connect, read) timeouts in seconds for requests.post. Without them a stalled
# backend blocks the UI worker forever.
REQUEST_TIMEOUT = (10, 120)


def inference(prompt, guidance_scale, num_inference_steps):
    """Request one generated image for *prompt* from the remote inference API.

    Args:
        prompt: Text prompt sent to the model.
        guidance_scale: Value forwarded as ``parameters.guidance_scale``;
            coerced to ``float``.
        num_inference_steps: Value forwarded as
            ``parameters.num_inference_steps``; coerced to ``int`` because UI
            sliders may deliver floats (e.g. ``20.0``).

    Returns:
        The first image of the response's ``generated_images`` list as a PIL
        image, or ``None`` if the request fails, the API answers with a
        non-200 status, or the payload cannot be decoded. Failures are logged
        instead of raised, so the UI simply shows an empty output.
    """
    payload = {
        "prompt": prompt,
        "parameters": {
            "num_inference_steps": int(num_inference_steps),
            "guidance_scale": float(guidance_scale),
            # -1 is passed through as-is; its meaning (presumably "random
            # seed") is defined server-side — TODO confirm.
            "seed": -1,
        },
        "endpoint": MODEL_ENDPOINT,
    }

    try:
        response = requests.post(API_URL, json=payload, timeout=REQUEST_TIMEOUT)
    except requests.RequestException:
        # Covers connection errors and timeouts; keep the "None on failure"
        # contract rather than crashing the UI callback.
        logger.exception("Inference request to %s failed", API_URL)
        return None

    if response.status_code != 200:
        logger.error(
            "Inference API returned HTTP %s: %.500s",
            response.status_code,
            response.text,
        )
        return None

    try:
        return decode_base64_image(response.json()["generated_images"][0])
    except (ValueError, KeyError, IndexError, TypeError, OSError):
        # ValueError: bad JSON or bad base64; KeyError/IndexError/TypeError:
        # unexpected payload shape; OSError: image bytes not decodable.
        logger.exception("Could not decode image from inference response")
        return None


def app():
    """Build the Gradio interface around ``inference``.

    Returns:
        A ``gr.Interface`` with these parts:

        - Inputs: a prompt textbox, a guidance-scale slider (2-20) and an
          inference-steps slider (1-50).
        - Output: a 512x512 PIL image.
        - Flagging disabled and three example rows.
    """
    prompt_box = gr.Textbox(
        label="Prompt",
        info="Enter your prompt",
        lines=3,
        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    )
    guidance_slider = gr.Slider(2, 20, value=15, step=1, label="Guidance Scale")
    steps_slider = gr.Slider(1, 50, value=20, step=1, label="Inference steps")

    output_image = gr.Image(type="pil", height=512, width=512)

    # Each row is [prompt, guidance_scale, num_inference_steps], matching the
    # order of the input components above.
    example_rows = [
        ["A bustling metropolis skyline of towering skyscrapers, illuminated by the neon glow of futuristic advertisements and hover vehicles zipping through the airways, casting dynamic shadows on sleek, reflective surfaces below, 8k", 7, 20],
        ["Design an image capturing the essence of 'timeless wonder' in a mystical forest setting.", 7, 20],
        ["Visualize the emotions evoked by the words 'bittersweet symphony' in a unique artwork.", 15, 20],
    ]

    return gr.Interface(
        fn=inference,
        inputs=[prompt_box, guidance_slider, steps_slider],
        outputs=output_image,
        allow_flagging='never',
        title='Gen Image',
        examples=example_rows,
    )


if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    demo = app()
    demo.launch()