File size: 4,523 Bytes
968109b
605659c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10d14be
 
 
 
605659c
 
 
 
 
 
968109b
 
 
 
 
 
 
 
 
 
 
605659c
968109b
 
 
 
 
 
 
 
 
 
 
227ac6e
e6a3d12
605659c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
946624d
605659c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from huggingface_hub import HfApi
import gradio as gr
from urllib.parse import urlparse
import requests
import time
import os

from utils.gradio_helpers import parse_outputs, process_outputs

# Input components, in the exact order Gradio passes their values to ``predict``.
# This list must stay aligned position-by-position with ``names`` below, which
# gives the Cog input key for each component.
inputs = [
    gr.Textbox(label="Model Image", info="Clear picture of the model"),
    gr.Textbox(label="Garment Image", info="Clear picture of upper body garment"),
    gr.Textbox(label="Person Mask", info="Mask of the person's upper body"),
    gr.Slider(
        label="Steps", info="Inference steps", value=20,
        minimum=1, maximum=40, step=1,
    ),
    gr.Slider(
        label="Guidance Scale", info="Guidance scale", value=2,
        minimum=1, maximum=5,
    ),
    # precision=0 makes Gradio round to the nearest integer and return an int,
    # so the seed is sent to the Cog predictor as an integer instead of 0.0.
    gr.Number(label="Seed", info="Seed", value=0, precision=0),
    gr.Slider(
        label="Num Samples", info="Number of samples", value=1,
        minimum=1, maximum=4, step=1,
    ),
]

# Cog input key for each entry of ``inputs`` (same order).
names = ['model_image', 'garment_image', 'person_mask', 'steps', 'guidance_scale', 'seed', 'num_samples']

# Four image slots; ``predict`` hides the slots that a run does not fill.
outputs = [gr.Image() for _ in range(4)]

expected_outputs = len(outputs)
def _extract_bearer_token(request):
    """Return the token from the request's ``Authorization: Bearer <token>`` header.

    Raises:
        gr.Error: if there is no request, no Authorization header, or the header
            is not of the form ``Bearer <token>``.
    """
    if not request:
        raise gr.Error("The submission failed!")
    # Never log request.headers: it contains the caller's access token.
    try:
        authorization = request.headers["Authorization"]
    except KeyError:
        raise gr.Error("Missing authorization in the headers") from None
    scheme, _, token = authorization.strip().partition(" ")
    token = token.strip()
    # Only the Bearer scheme is accepted; the token itself must be one word.
    if scheme.lower() != "bearer" or not token or " " in token:
        raise gr.Error("Invalid format for Authorization header. It should be 'Bearer <token>'")
    return token


def _check_token(token):
    """Raise ``gr.Error`` unless the Hugging Face API accepts *token*."""
    try:
        user_info = HfApi(token=token).whoami(token)
    except Exception as err:
        raise gr.Error("The provider API key is invalid!") from err
    if not user_info:
        raise gr.Error("The provider API key is invalid!")


def _build_prediction_input(args, base_url="http://0.0.0.0:7860"):
    """Map positional UI values to Cog input keys, dropping ``None`` and ``""``.

    Values that name an existing local file are rewritten to a URL served by
    this Gradio app (``<base_url>/file=<path>``) so the Cog server can fetch them.
    """
    prediction_input = {}
    for key, value in zip(names, args):
        if value and os.path.exists(str(value)):
            value = f"{base_url}/file={value}"
        if value is not None and value != "":
            prediction_input[key] = value
    return prediction_input


def _wait_for_prediction(follow_up_url, headers, poll_interval=1.0):
    """Poll an async Cog prediction until it reaches a terminal state.

    Returns the final response when the prediction succeeds, or the first
    non-200 poll response so the caller can report its status code.

    Raises:
        gr.Error: if the prediction ends as ``failed`` or ``canceled``.
    """
    response = requests.get(follow_up_url, headers=headers)
    while True:
        if response.status_code != 200:
            return response
        status = response.json()["status"]
        if status == "succeeded":
            return response
        # Both are terminal: without the "canceled" check the loop never ends.
        if status in ("failed", "canceled"):
            raise gr.Error("The submission failed!")
        time.sleep(poll_interval)
        response = requests.get(follow_up_url, headers=headers)


def _format_outputs(json_response):
    """Turn a successful Cog response into values for the Gradio output components."""
    # If the output component is JSON, return the entire output response.
    if outputs[0].get_config()["name"] == "json":
        return json_response["output"]
    processed_outputs = list(process_outputs(parse_outputs(json_response["output"])))
    missing = expected_outputs - len(processed_outputs)
    if missing > 0:
        # Fewer outputs than components: hide the unused ones.
        processed_outputs.extend([gr.update(visible=False)] * missing)
    elif missing < 0:
        # More outputs than components: keep only as many as can be shown.
        processed_outputs = processed_outputs[:expected_outputs]
    return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]


def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
    """Authenticate the caller, run a prediction on the local Cog server, and return its outputs.

    Args:
        request: Gradio request; must carry an ``Authorization: Bearer <token>``
            header with a token accepted by the Hugging Face API.
        *args: values of the ``inputs`` components, in the order of ``names``.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        One value per output component (a single value when there is only one).

    Raises:
        gr.Error: on authentication failure, a failed/canceled prediction,
            or a non-success HTTP status from the Cog server.
    """
    token = _extract_bearer_token(request)
    _check_token(token)

    headers = {'Content-Type': 'application/json'}
    payload = {"input": _build_prediction_input(args)}

    response = requests.post("http://0.0.0.0:5000/predictions", headers=headers, json=payload)

    # 201 means the prediction was accepted asynchronously: poll until it finishes.
    if response.status_code == 201:
        response = _wait_for_prediction(response.json()["urls"]["get"], headers)

    if response.status_code == 200:
        return _format_outputs(response.json())
    if response.status_code == 409:
        raise gr.Error("Sorry, the Cog image is still processing. Try again in a bit.")
    raise gr.Error(f"The submission failed! Error: {response.status_code}")

# Description text shown under the page title.
model_description = "Don't mind me :) this is just a fork of viktorfa/oot_diffusion_with_mask"
# Page title shown at the top of the generated UI.
title = "Demo for oot_diffusion_with_mask cog image by jbilcke"

# Bind the prediction handler to the declared components (flagging disabled),
# then start the web server.
app = gr.Interface(
    title=title,
    description=model_description,
    allow_flagging="never",
    fn=predict,
    inputs=inputs,
    outputs=outputs,
)
app.launch()