Spaces:
Runtime error
Runtime error
menghanxia
commited on
Commit
•
302d824
1
Parent(s):
b3640b9
modified app.py with checkpt downloading
Browse files
app.py
CHANGED
@@ -2,40 +2,12 @@ import gradio as gr
|
|
2 |
import os, requests
|
3 |
from inference import setup_model, colorize_grayscale, predict_anchors
|
4 |
|
5 |
-
|
6 |
-
def download_file_from_google_drive(id, destination):
|
7 |
-
def get_confirm_token(response):
|
8 |
-
for key, value in response.cookies.items():
|
9 |
-
if key.startswith('download_warning'):
|
10 |
-
return value
|
11 |
-
return None
|
12 |
-
|
13 |
-
def save_response_content(response, destination):
|
14 |
-
CHUNK_SIZE = 32768
|
15 |
-
with open(destination, "wb") as f:
|
16 |
-
for chunk in response.iter_content(CHUNK_SIZE):
|
17 |
-
if chunk: # filter out keep-alive new chunks
|
18 |
-
f.write(chunk)
|
19 |
-
|
20 |
-
URL = "https://docs.google.com/uc?export=download"
|
21 |
-
session = requests.Session()
|
22 |
-
response = session.get(URL, params = { 'id' : id }, stream = True)
|
23 |
-
token = get_confirm_token(response)
|
24 |
-
|
25 |
-
if token:
|
26 |
-
params = { 'id' : id, 'confirm' : token }
|
27 |
-
response = session.get(URL, params = params, stream = True)
|
28 |
-
save_response_content(response, destination)
|
29 |
-
|
30 |
-
id = "1J4vB6kG4xBLUUKpXr5IhnSSa4maXgRvQ"
|
31 |
-
destination = "disco-beta.pth.rar"
|
32 |
-
download_file_from_google_drive(id, destination)
|
33 |
os.rename("disco-beta.pth.tar", "./checkpoints/disco-beta.pth.tar")
|
34 |
|
35 |
## step 1: set up model
|
36 |
-
device = "
|
37 |
-
checkpt_path = "
|
38 |
-
assert os.path.exists(checkpt_path), "No checkpoint found!"
|
39 |
colorizer, colorLabeler = setup_model(checkpt_path, device=device)
|
40 |
|
41 |
def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
|
@@ -55,9 +27,11 @@ def switch_states(is_checked):
|
|
55 |
else:
|
56 |
return gr.Image.update(visible=False), gr.Button.update(visible=False)
|
57 |
|
58 |
-
demo = gr.Blocks(title="DISCO
|
59 |
with demo:
|
60 |
-
gr.Markdown(value="""
|
|
|
|
|
61 |
with gr.Row():
|
62 |
with gr.Column(scale=1):
|
63 |
Image_input = gr.Image(type="numpy", label="Input", interactive=True)
|
@@ -78,15 +52,17 @@ with demo:
|
|
78 |
Button_run.click(fn=click_colorize, inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Ckeckbox_editable], \
|
79 |
outputs=Image_output)
|
80 |
## guiline
|
81 |
-
gr.Markdown(value="""
|
82 |
-
**Guideline**
|
83 |
-
1.
|
84 |
2. Set up the arguments: "Num. of anchors" and "Colorization resolution";
|
85 |
-
3.
|
86 |
-
- **Editable**: check ""Show editable anchors" and click "Predict anchors". Then, modify the colors of the predicted anchors (anchor mask will be applied afterward). Finally, click "Colorize" to get the result.
|
87 |
- **Automatic**: click "Colorize" to get the automatically colorized output.
|
88 |
-
|
89 |
-
|
|
|
|
|
90 |
""")
|
91 |
|
92 |
-
demo.launch(server_name='9.134.253.83',server_port=7788)
|
|
|
|
2 |
import os, requests
|
3 |
from inference import setup_model, colorize_grayscale, predict_anchors
|
4 |
|
5 |
+
os.system("wget https://huggingface.co/menghanxia/disco/tree/main/disco-beta.pth.tar")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
os.rename("disco-beta.pth.tar", "./checkpoints/disco-beta.pth.tar")
|
7 |
|
8 |
## step 1: set up model
|
9 |
+
device = "cpu"
|
10 |
+
checkpt_path = "checkpoints/disco-beta.pth.rar"
|
|
|
11 |
colorizer, colorLabeler = setup_model(checkpt_path, device=device)
|
12 |
|
13 |
def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
|
|
|
27 |
else:
|
28 |
return gr.Image.update(visible=False), gr.Button.update(visible=False)
|
29 |
|
30 |
+
demo = gr.Blocks(title="DISCO")
|
31 |
with demo:
|
32 |
+
gr.Markdown(value="""
|
33 |
+
**Gradio demo for DISCO: Disentangled Image Colorization via Global Anchors. [Project Page](https://menghanxia.github.io/projects/disco.html)**.
|
34 |
+
""")
|
35 |
with gr.Row():
|
36 |
with gr.Column(scale=1):
|
37 |
Image_input = gr.Image(type="numpy", label="Input", interactive=True)
|
|
|
52 |
Button_run.click(fn=click_colorize, inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Ckeckbox_editable], \
|
53 |
outputs=Image_output)
|
54 |
## guiline
|
55 |
+
gr.Markdown(value="""
|
56 |
+
**Usage Guideline**
|
57 |
+
1. upload your image;
|
58 |
2. Set up the arguments: "Num. of anchors" and "Colorization resolution";
|
59 |
+
3. Run the colorization (two modes supported):
|
|
|
60 |
- **Automatic**: click "Colorize" to get the automatically colorized output.
|
61 |
+
- **Editable**: check ""Show editable anchors" and click "Predict anchors". Then, modify the colors of the predicted anchors (only anchor region will be used). Finally, click "Colorize" to get the result.
|
62 |
+
""")
|
63 |
+
gr.HTML(value="""
|
64 |
+
<p style='text-align: center'><a href='https://menghanxia.github.io/projects/disco.html' target='_blank'>DISCO Project Page</a> | <a href='https://github.com/MenghanXia/DisentangledColorization' target='_blank'>Github Repo</a></p>
|
65 |
""")
|
66 |
|
67 |
+
#demo.launch(server_name='9.134.253.83',server_port=7788)
|
68 |
+
demo.launch()
|