Spaces:
Running
on
Zero
Running
on
Zero
chore: use HF_TOKEN
Browse files
app.py
CHANGED
@@ -162,6 +162,15 @@ def main():
|
|
162 |
|
163 |
gr.Examples(
|
164 |
examples=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
[
|
166 |
"original",
|
167 |
"",
|
@@ -216,22 +225,13 @@ def main():
|
|
216 |
"long",
|
217 |
"lax",
|
218 |
],
|
219 |
-
[
|
220 |
-
"bocchi the rock!",
|
221 |
-
"gotoh hitori, kita ikuyo, ijichi nijika, yamada ryo",
|
222 |
-
"4girls, multiple girls",
|
223 |
-
"sfw",
|
224 |
-
"ultra_wide",
|
225 |
-
"very_long",
|
226 |
-
"lax",
|
227 |
-
],
|
228 |
[
|
229 |
"chuunibyou demo koi ga shitai!",
|
230 |
"takanashi rikka",
|
231 |
"1girl, solo",
|
232 |
"sfw",
|
233 |
"ultra_tall",
|
234 |
-
"
|
235 |
"lax",
|
236 |
],
|
237 |
],
|
|
|
162 |
|
163 |
gr.Examples(
|
164 |
examples=[
|
165 |
+
[
|
166 |
+
"original",
|
167 |
+
"",
|
168 |
+
"1girl, solo, upper body, :d",
|
169 |
+
"general",
|
170 |
+
"tall",
|
171 |
+
"long",
|
172 |
+
"none",
|
173 |
+
],
|
174 |
[
|
175 |
"original",
|
176 |
"",
|
|
|
225 |
"long",
|
226 |
"lax",
|
227 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
[
|
229 |
"chuunibyou demo koi ga shitai!",
|
230 |
"takanashi rikka",
|
231 |
"1girl, solo",
|
232 |
"sfw",
|
233 |
"ultra_tall",
|
234 |
+
"medium",
|
235 |
"lax",
|
236 |
],
|
237 |
],
|
v2.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import time
|
|
|
2 |
|
3 |
import torch
|
4 |
|
@@ -26,6 +27,8 @@ except ImportError:
|
|
26 |
from output import UpsamplingOutput
|
27 |
from utils import ASPECT_RATIO_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
|
28 |
|
|
|
|
|
29 |
ALL_MODELS = {
|
30 |
"dart-v2-mixtral-160m-sft-6": {
|
31 |
"repo": "p1atdev/dart-v2-mixtral-160m-sft-6",
|
@@ -42,8 +45,8 @@ ALL_MODELS = {
|
|
42 |
|
43 |
def prepare_models(model_config: dict):
|
44 |
model_name = model_config["repo"]
|
45 |
-
tokenizer = DartTokenizer.from_pretrained(model_name)
|
46 |
-
model = model_config["class"].from_pretrained(model_name)
|
47 |
|
48 |
return {
|
49 |
"tokenizer": tokenizer,
|
@@ -74,6 +77,7 @@ def generate_tags(
|
|
74 |
model: V2Model,
|
75 |
tokenizer: DartTokenizer,
|
76 |
prompt: str,
|
|
|
77 |
):
|
78 |
output = model.generate(
|
79 |
get_generation_config(
|
@@ -83,6 +87,7 @@ def generate_tags(
|
|
83 |
top_p=0.9,
|
84 |
top_k=100,
|
85 |
max_new_tokens=256,
|
|
|
86 |
),
|
87 |
)
|
88 |
|
@@ -107,7 +112,7 @@ class V2UI:
|
|
107 |
aspect_ratio_option: str,
|
108 |
length_option: str,
|
109 |
identity_option: str,
|
110 |
-
|
111 |
*args,
|
112 |
) -> UpsamplingOutput:
|
113 |
if self.model_name is None or self.model_name != model_name:
|
@@ -125,6 +130,7 @@ class V2UI:
|
|
125 |
aspect_ratio_tag = ASPECT_RATIO_OPTIONS[aspect_ratio_option]
|
126 |
length_tag = LENGTH_OPTIONS[length_option]
|
127 |
identity_tag = IDENTITY_OPTIONS[identity_option]
|
|
|
128 |
|
129 |
prompt = compose_prompt(
|
130 |
prompt=general_tags,
|
@@ -141,6 +147,7 @@ class V2UI:
|
|
141 |
self.model,
|
142 |
self.tokenizer,
|
143 |
prompt,
|
|
|
144 |
)
|
145 |
elapsed_time = time.time() - start
|
146 |
|
@@ -193,6 +200,12 @@ class V2UI:
|
|
193 |
value="none",
|
194 |
)
|
195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
model_name = gr.Dropdown(
|
197 |
label="Model",
|
198 |
choices=list(ALL_MODELS.keys()),
|
@@ -210,6 +223,7 @@ class V2UI:
|
|
210 |
input_aspect_ratio,
|
211 |
input_length,
|
212 |
input_identity,
|
|
|
213 |
]
|
214 |
|
215 |
def get_generate_btn(self) -> gr.Button:
|
|
|
1 |
import time
|
2 |
+
import os
|
3 |
|
4 |
import torch
|
5 |
|
|
|
27 |
from output import UpsamplingOutput
|
28 |
from utils import ASPECT_RATIO_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
|
29 |
|
30 |
+
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
31 |
+
|
32 |
ALL_MODELS = {
|
33 |
"dart-v2-mixtral-160m-sft-6": {
|
34 |
"repo": "p1atdev/dart-v2-mixtral-160m-sft-6",
|
|
|
45 |
|
46 |
def prepare_models(model_config: dict):
|
47 |
model_name = model_config["repo"]
|
48 |
+
tokenizer = DartTokenizer.from_pretrained(model_name, auth_token=HF_TOKEN)
|
49 |
+
model = model_config["class"].from_pretrained(model_name, auth_token=HF_TOKEN)
|
50 |
|
51 |
return {
|
52 |
"tokenizer": tokenizer,
|
|
|
77 |
model: V2Model,
|
78 |
tokenizer: DartTokenizer,
|
79 |
prompt: str,
|
80 |
+
ban_token_ids: list[int],
|
81 |
):
|
82 |
output = model.generate(
|
83 |
get_generation_config(
|
|
|
87 |
top_p=0.9,
|
88 |
top_k=100,
|
89 |
max_new_tokens=256,
|
90 |
+
ban_token_ids=ban_token_ids,
|
91 |
),
|
92 |
)
|
93 |
|
|
|
112 |
aspect_ratio_option: str,
|
113 |
length_option: str,
|
114 |
identity_option: str,
|
115 |
+
ban_tags: str,
|
116 |
*args,
|
117 |
) -> UpsamplingOutput:
|
118 |
if self.model_name is None or self.model_name != model_name:
|
|
|
130 |
aspect_ratio_tag = ASPECT_RATIO_OPTIONS[aspect_ratio_option]
|
131 |
length_tag = LENGTH_OPTIONS[length_option]
|
132 |
identity_tag = IDENTITY_OPTIONS[identity_option]
|
133 |
+
ban_token_ids = self.tokenizer.encode(ban_tags.strip())
|
134 |
|
135 |
prompt = compose_prompt(
|
136 |
prompt=general_tags,
|
|
|
147 |
self.model,
|
148 |
self.tokenizer,
|
149 |
prompt,
|
150 |
+
ban_token_ids,
|
151 |
)
|
152 |
elapsed_time = time.time() - start
|
153 |
|
|
|
200 |
value="none",
|
201 |
)
|
202 |
|
203 |
+
with gr.Accordion(label="Advanced options", open=False):
|
204 |
+
input_ban_tags = gr.Textbox(
|
205 |
+
label="Ban tags",
|
206 |
+
placeholder="alternate costumen, ...",
|
207 |
+
)
|
208 |
+
|
209 |
model_name = gr.Dropdown(
|
210 |
label="Model",
|
211 |
choices=list(ALL_MODELS.keys()),
|
|
|
223 |
input_aspect_ratio,
|
224 |
input_length,
|
225 |
input_identity,
|
226 |
+
input_ban_tags,
|
227 |
]
|
228 |
|
229 |
def get_generate_btn(self) -> gr.Button:
|