anzorq commited on
Commit
8014209
1 Parent(s): 3508df0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -48
app.py CHANGED
@@ -1,55 +1,85 @@
 
1
  import subprocess
2
- from huggingface_hub import HfApi, hf_hub_download
3
  import gradio as gr
 
4
 
5
  subprocess.run(["git", "clone", "https://github.com/huggingface/diffusers.git", "diffs"])
6
 
7
  def error_str(error, title="Error"):
8
  return f"""#### {title}
9
- {error}"""
 
 
 
 
 
 
 
10
 
11
  def url_to_model_id(model_id_str):
12
  return model_id_str.split("/")[-2] + "/" + model_id_str.split("/")[-1] if model_id_str.startswith("https://huggingface.co/") else model_id_str
13
 
14
- def get_ckpt_names(model_id = "nitrosocke/mo-di-diffusion"):
15
 
16
- if model_id == "":
17
- return error_str("Please enter a model name.", title="Invalid input"), None, None
 
 
18
 
19
  try:
20
- api = HfApi()
21
- ckpt_files = [f for f in api.list_repo_files(url_to_model_id(model_id)) if f.endswith(".ckpt")]
22
 
23
- if len(ckpt_files) == 0:
24
- return error_str("No checkpoint files found in the model repo."), None, None
25
 
26
- return None, gr.update(choices=ckpt_files, visible=True), gr.update(visible=True)
27
 
28
  except Exception as e:
29
- return error_str(e), None, None
30
 
31
- def convert(model_id, ckpt_name, token = "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"):
32
 
33
- model_id = url_to_model_id(model_id)
34
-
35
- # 1. Download the checkpoint file
36
- ckpt_path = hf_hub_download(repo_id=model_id, filename=ckpt_name)
37
-
38
- # 2. Run the conversion script
39
- subprocess.run(
40
- [
41
- "python3",
42
- "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
43
- "--checkpoint_path",
44
- ckpt_path,
45
- "--dump_path" ,
46
- model_id,
47
- ]
48
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- import os
51
- return f"""files in current directory:
52
- {[f for f in os.listdir(".")]}"""
53
 
54
 
55
  with gr.Blocks() as demo:
@@ -57,40 +87,61 @@ with gr.Blocks() as demo:
57
  with gr.Row():
58
 
59
  with gr.Column(scale=11):
60
- with gr.Group():
61
  gr.Markdown("## 1. Load model info")
62
  input_token = gr.Textbox(
63
  max_lines=1,
64
- label="Hugging Face token",
65
  placeholder="hf_...",
66
  )
67
- gr.Markdown("Get your token [here](https://huggingface.co/settings/tokens).")
68
- input_model = gr.Textbox(
69
- max_lines=1,
70
- label="Model name or URL",
71
- placeholder="username/model_name",
72
- )
 
 
 
73
 
74
  btn_get_ckpts = gr.Button("Load")
75
 
76
- with gr.Column(scale=10, visible=False) as col_convert:
77
- gr.Markdown("## 2. Convert to Diffusers🧨")
78
- radio_ckpts = gr.Radio(label="Choose a checkpoint to convert", visible=False)
79
- btn_convert = gr.Button("Convert")
 
 
80
 
81
  error_output = gr.Markdown(label="Output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  btn_get_ckpts.click(
83
  fn=get_ckpt_names,
84
- inputs=[input_model],
85
- outputs=[error_output, radio_ckpts, col_convert],
86
  scroll_to_output=True
87
  )
88
 
89
  btn_convert.click(
90
- fn=convert,
91
- inputs=[input_model, radio_ckpts, input_token],
92
  outputs=error_output,
93
  scroll_to_output=True
94
  )
95
 
96
- demo.launch()
 
 
1
+ import os
2
  import subprocess
3
+ from huggingface_hub import HfApi, upload_folder
4
  import gradio as gr
5
+ import hf_utils
6
 
7
  subprocess.run(["git", "clone", "https://github.com/huggingface/diffusers.git", "diffs"])
8
 
9
  def error_str(error, title="Error"):
10
  return f"""#### {title}
11
+ {error}""" if error else ""
12
+
13
+ def on_token_change(token):
14
+ model_names, error = hf_utils.get_my_model_names(token)
15
+ if model_names:
16
+ model_names.append("Other")
17
+
18
+ return gr.update(visible=bool(model_names)), gr.update(choices=model_names, value=model_names[0] if model_names else None), gr.update(value=error_str(error))
19
 
20
  def url_to_model_id(model_id_str):
21
  return model_id_str.split("/")[-2] + "/" + model_id_str.split("/")[-1] if model_id_str.startswith("https://huggingface.co/") else model_id_str
22
 
23
+ def get_ckpt_names(token, radio_model_names, input_model):
24
 
25
+ model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
26
+
27
+ if token == "" or model_id == "":
28
+ return error_str("Please enter both a token and a model name.", title="Invalid input"), gr.update(choices=[]), gr.update(visible=False)
29
 
30
  try:
31
+ api = HfApi(token=token)
32
+ ckpt_files = [f for f in api.list_repo_files(repo_id=model_id) if f.endswith(".ckpt")]
33
 
34
+ if not ckpt_files:
35
+ return error_str("No checkpoint files found in the model repo."), gr.update(choices=[]), gr.update(visible=False)
36
 
37
+ return None, gr.update(choices=ckpt_files, value=ckpt_files[0], visible=True), gr.update(visible=True)
38
 
39
  except Exception as e:
40
+ return error_str(e), gr.update(choices=[]), None
41
 
42
+ def convert_and_push(radio_model_names, input_model, ckpt_name, token):
43
 
44
+ model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
45
+
46
+ try:
47
+ model_id = url_to_model_id(model_id)
48
+
49
+ # 1. Download the checkpoint file
50
+ ckpt_path, revision = hf_utils.download_file(repo_id=model_id, filename=ckpt_name, token=token)
51
+
52
+ # 2. Run the conversion script
53
+ subprocess.run(
54
+ [
55
+ "python3",
56
+ "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
57
+ "--checkpoint_path",
58
+ ckpt_path,
59
+ "--dump_path" ,
60
+ model_id,
61
+ ]
62
+ )
63
+
64
+ # 3. Push to the model repo
65
+ upload_folder(
66
+ folder_path=model_id,
67
+ repo_id=model_id,
68
+ token=token,
69
+ create_pr=True,
70
+ )
71
+
72
+ # # 4. Delete the downloaded checkpoint file, yaml files, and the converted model folder
73
+ hf_utils.delete_file(revision)
74
+ subprocess.run(["rm", "-rf", model_id.split('/')[0]])
75
+ import glob
76
+ for f in glob.glob("*.yaml*"):
77
+ subprocess.run(["rm", "-rf", f])
78
+
79
+ return "Success"
80
 
81
+ except Exception as e:
82
+ return error_str(e)
 
83
 
84
 
85
  with gr.Blocks() as demo:
 
87
  with gr.Row():
88
 
89
  with gr.Column(scale=11):
90
+ with gr.Column():
91
  gr.Markdown("## 1. Load model info")
92
  input_token = gr.Textbox(
93
  max_lines=1,
94
+ label="Enter your Hugging Face token",
95
  placeholder="hf_...",
96
  )
97
+ gr.Markdown("You can get a token [here](https://huggingface.co/settings/tokens).")
98
+ with gr.Group(visible=False) as group_model:
99
+ radio_model_names = gr.Radio(label="Choose a model")
100
+ input_model = gr.Textbox(
101
+ max_lines=1,
102
+ label="Model name or URL",
103
+ placeholder="username/model_name",
104
+ visible=False,
105
+ )
106
 
107
  btn_get_ckpts = gr.Button("Load")
108
 
109
+ with gr.Column(scale=10):
110
+ with gr.Column(visible=False) as group_convert:
111
+ gr.Markdown("## 2. Convert to Diffusers🧨")
112
+ radio_ckpts = gr.Radio(label="Choose the checkpoint to convert", visible=False)
113
+ gr.Markdown("Conversion may take a few minutes.")
114
+ btn_convert = gr.Button("Convert & Push")
115
 
116
  error_output = gr.Markdown(label="Output")
117
+
118
+ input_token.change(
119
+ fn=on_token_change,
120
+ inputs=input_token,
121
+ outputs=[group_model, radio_model_names, error_output],
122
+ queue=False,
123
+ scroll_to_output=True)
124
+
125
+ radio_model_names.change(
126
+ lambda x: gr.update(visible=x == "Other"),
127
+ inputs=radio_model_names,
128
+ outputs=input_model,
129
+ queue=False,
130
+ scroll_to_output=True)
131
+
132
  btn_get_ckpts.click(
133
  fn=get_ckpt_names,
134
+ inputs=[input_token, radio_model_names, input_model],
135
+ outputs=[error_output, radio_ckpts, group_convert],
136
  scroll_to_output=True
137
  )
138
 
139
  btn_convert.click(
140
+ fn=convert_and_push,
141
+ inputs=[radio_model_names, input_model, radio_ckpts, input_token],
142
  outputs=error_output,
143
  scroll_to_output=True
144
  )
145
 
146
+ demo.queue()
147
+ demo.launch()