multimodalart (HF staff) committed
Commit 4ef6d61
1 Parent(s): 322db57

Fix some issues

Files changed (1)
  1. app.py +33 -13
app.py CHANGED
@@ -5,6 +5,9 @@ from PIL import Image
 import json
 import os
 import logging
+import math
+from tqdm import tqdm
+import time
 
 logging.basicConfig(level=logging.DEBUG)
 
@@ -31,27 +34,39 @@ def run_lora(prompt, selected_state, progress=gr.Progress(track_tqdm=True)):
     selected_lora = loras[selected_lora_index]
     api_url = f"https://api-inference.huggingface.co/models/{selected_lora['repo']}"
     trigger_word = selected_lora["trigger_word"]
-    token = os.getenv("API_TOKEN")
+    #token = os.getenv("API_TOKEN")
     payload = {"inputs": f"{prompt} {trigger_word}"}
 
-    headers = {"Authorization": f"Bearer {token}"}
+    #headers = {"Authorization": f"Bearer {token}"}
 
     # Add a print statement to display the API request
     print(f"API Request: {api_url}")
-    print(f"API Headers: {headers}")
+    #print(f"API Headers: {headers}")
     print(f"API Payload: {payload}")
-
-    response = requests.post(api_url, headers=headers, json=payload)
-    if response.status_code == 200:
-        return Image.open(io.BytesIO(response.content))
-    else:
-        logging.error(f"API Error: {response.status_code}")
-        raise gr.Error("API Error: Unable to fetch the image.")  # Raise a Gradio error here
+
+    error_count = 0
+    pbar = tqdm(total=None, desc="Loading model")
+    while(True):
+        response = requests.post(api_url, json=payload)
+        if response.status_code == 200:
+            return Image.open(io.BytesIO(response.content))
+        elif response.status_code == 503:
+            # 503 is returned while the model is doing a cold boot. The response also gives an estimate of when the model will be loaded, but it is not very precise.
+            time.sleep(1)
+            pbar.update(1)
+        elif response.status_code == 500 and error_count < 5:
+            print(response.content)
+            time.sleep(1)
+            error_count += 1
+            continue
+        else:
+            logging.error(f"API Error: {response.status_code}")
+            raise gr.Error("API Error: Unable to fetch the image.")  # Raise a Gradio error here
 
 
 
 with gr.Blocks(css="custom.css") as app:
-    title = gr.HTML("<h1>LoRA the Explorer</h1>")
+    title = gr.Markdown("# artificialguybr LoRA portfolio")
     selected_state = gr.State()
     with gr.Row():
         gallery = gr.Gallery(
@@ -71,11 +86,16 @@ with gr.Blocks(css="custom.css") as app:
         update_selection,
         outputs=[prompt, selected_state]
     )
+    prompt.submit(
+        fn=run_lora,
+        inputs=[prompt, selected_state],
+        outputs=[result]
+    )
     button.click(
         fn=run_lora,
         inputs=[prompt, selected_state],
         outputs=[result]
     )
 
-app.queue(max_size=20)
-app.launch()
+app.queue(max_size=20, concurrency_count=5)
+app.launch()
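
For reference, the retry loop added in this commit follows the usual polling pattern for the Hugging Face Inference API: a 503 response means the model is still cold-booting, so the client waits and retries until it gets image bytes back. Below is a minimal standalone sketch of that pattern, independent of this Space; the function name, model id, and max_retries cap are illustrative placeholders.

import io
import time
import requests
from PIL import Image

def query_image(model_id: str, prompt: str, max_retries: int = 60):
    """Poll the Inference API until the model is warm and an image is returned."""
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    for _ in range(max_retries):
        response = requests.post(api_url, json={"inputs": prompt})
        if response.status_code == 200:
            # Success: the response body is the raw image bytes.
            return Image.open(io.BytesIO(response.content))
        if response.status_code == 503:
            # Model is still cold-booting; wait a bit and poll again.
            time.sleep(1)
            continue
        # Any other status code is treated as a hard failure.
        raise RuntimeError(f"API error {response.status_code}: {response.text}")
    raise TimeoutError(f"Model {model_id} did not load after {max_retries} retries")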
 
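On the UI side, the commit binds the textbox's Enter key (prompt.submit) to the same handler as the button and raises the queue's concurrency. Below is a minimal self-contained sketch of that wiring, assuming the Gradio 3.x queue API where concurrency_count sets how many jobs run in parallel; the handler and components are placeholders, not the Space's actual run_lora.

import gradio as gr

def run(text):
    # Placeholder handler standing in for run_lora.
    return f"Generated for: {text}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    button = gr.Button("Run")
    result = gr.Textbox(label="Result")
    # Pressing Enter in the textbox and clicking the button both trigger the same function.
    prompt.submit(fn=run, inputs=[prompt], outputs=[result])
    button.click(fn=run, inputs=[prompt], outputs=[result])

demo.queue(max_size=20, concurrency_count=5)  # Gradio 3.x: up to 5 requests processed in parallel
demo.launch()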