Commit 08bd9a8 by p1atdev (1 parent: 680ded8)

chore: use HF_TOKEN

Files changed (2):
  1. app.py  (+10 -10)
  2. v2.py   (+17 -3)
app.py CHANGED

@@ -162,6 +162,15 @@ def main():
 
         gr.Examples(
             examples=[
+                [
+                    "original",
+                    "",
+                    "1girl, solo, upper body, :d",
+                    "general",
+                    "tall",
+                    "long",
+                    "none",
+                ],
                 [
                     "original",
                     "",
@@ -216,22 +225,13 @@ def main():
                     "long",
                     "lax",
                 ],
-                [
-                    "bocchi the rock!",
-                    "gotoh hitori, kita ikuyo, ijichi nijika, yamada ryo",
-                    "4girls, multiple girls",
-                    "sfw",
-                    "ultra_wide",
-                    "very_long",
-                    "lax",
-                ],
                 [
                     "chuunibyou demo koi ga shitai!",
                     "takanashi rikka",
                     "1girl, solo",
                     "sfw",
                     "ultra_tall",
-                    "long",
+                    "medium",
                     "lax",
                 ],
             ],
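The example rows above are purely positional: each inner list has to supply exactly one value per input component registered with gr.Examples, in the same order. A minimal sketch of that wiring follows; the component names and choice lists are hypothetical, since the surrounding app.py code is not part of this diff:

import gradio as gr

# Illustrative only: component names and choices are assumptions, not app.py's actual ones.
with gr.Blocks() as demo:
    copyright_tags = gr.Textbox(label="Copyright tags")
    character_tags = gr.Textbox(label="Character tags")
    general_tags = gr.Textbox(label="General tags")
    rating = gr.Radio(["general", "sfw", "nsfw"], label="Rating")
    aspect_ratio = gr.Radio(["ultra_wide", "tall", "ultra_tall"], label="Aspect ratio")
    length = gr.Radio(["medium", "long", "very_long"], label="Length")
    identity = gr.Radio(["none", "lax", "strict"], label="Identity")

    # Each row supplies one value per component in `inputs`, in the same order;
    # clicking a row in the UI fills the components with those values.
    gr.Examples(
        examples=[
            ["original", "", "1girl, solo, upper body, :d", "general", "tall", "long", "none"],
        ],
        inputs=[copyright_tags, character_tags, general_tags, rating, aspect_ratio, length, identity],
    )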
v2.py CHANGED

@@ -1,4 +1,5 @@
 import time
+import os
 
 import torch
 
@@ -26,6 +27,8 @@ except ImportError:
 from output import UpsamplingOutput
 from utils import ASPECT_RATIO_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
 
+HF_TOKEN = os.getenv("HF_TOKEN", None)
+
 ALL_MODELS = {
     "dart-v2-mixtral-160m-sft-6": {
         "repo": "p1atdev/dart-v2-mixtral-160m-sft-6",
@@ -42,8 +45,8 @@ ALL_MODELS = {
 
 def prepare_models(model_config: dict):
     model_name = model_config["repo"]
-    tokenizer = DartTokenizer.from_pretrained(model_name)
-    model = model_config["class"].from_pretrained(model_name)
+    tokenizer = DartTokenizer.from_pretrained(model_name, auth_token=HF_TOKEN)
+    model = model_config["class"].from_pretrained(model_name, auth_token=HF_TOKEN)
 
     return {
         "tokenizer": tokenizer,
@@ -74,6 +77,7 @@ def generate_tags(
     model: V2Model,
     tokenizer: DartTokenizer,
     prompt: str,
+    ban_token_ids: list[int],
 ):
     output = model.generate(
         get_generation_config(
@@ -83,6 +87,7 @@ def generate_tags(
             top_p=0.9,
             top_k=100,
             max_new_tokens=256,
+            ban_token_ids=ban_token_ids,
         ),
     )
 
@@ -107,7 +112,7 @@ class V2UI:
         aspect_ratio_option: str,
         length_option: str,
         identity_option: str,
-        # image_size: str, # this is from image generation config
+        ban_tags: str,
         *args,
     ) -> UpsamplingOutput:
         if self.model_name is None or self.model_name != model_name:
@@ -125,6 +130,7 @@ class V2UI:
         aspect_ratio_tag = ASPECT_RATIO_OPTIONS[aspect_ratio_option]
         length_tag = LENGTH_OPTIONS[length_option]
         identity_tag = IDENTITY_OPTIONS[identity_option]
+        ban_token_ids = self.tokenizer.encode(ban_tags.strip())
 
         prompt = compose_prompt(
             prompt=general_tags,
@@ -141,6 +147,7 @@ class V2UI:
             self.model,
             self.tokenizer,
             prompt,
+            ban_token_ids,
         )
         elapsed_time = time.time() - start
 
@@ -193,6 +200,12 @@ class V2UI:
                 value="none",
             )
 
+            with gr.Accordion(label="Advanced options", open=False):
+                input_ban_tags = gr.Textbox(
+                    label="Ban tags",
+                    placeholder="alternate costumen, ...",
+                )
+
             model_name = gr.Dropdown(
                 label="Model",
                 choices=list(ALL_MODELS.keys()),
@@ -210,6 +223,7 @@ class V2UI:
             input_aspect_ratio,
             input_length,
             input_identity,
+            input_ban_tags,
         ]
 
     def get_generate_btn(self) -> gr.Button:
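Taken together, the v2.py changes do two things: model and tokenizer downloads are now authenticated with an HF_TOKEN environment variable (typically a Space secret), and a new "Ban tags" textbox is tokenized into ban_token_ids that the generation config excludes from the output. The DartTokenizer/V2Model wrappers above handle this directly; as a rough point of comparison only, here is a sketch of the same two ideas with stock transformers, where the token= argument and bad_words_ids play the roles of auth_token= and ban_token_ids. The repo name is copied from ALL_MODELS above, but loading it through the Auto* classes and the example prompt are assumptions that may not match how this model actually has to be loaded:

import os

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Read the access token from the environment (e.g. a Space secret);
# None falls back to anonymous access for public repos.
HF_TOKEN = os.getenv("HF_TOKEN", None)

repo = "p1atdev/dart-v2-mixtral-160m-sft-6"  # from ALL_MODELS; loading via Auto* is an assumption
tokenizer = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(repo, token=HF_TOKEN)

# "Ban tags" as generation-time exclusions: transformers expresses this as
# bad_words_ids, a list of token-id sequences that must never be generated.
ban_tags = "alternate costume, simple background"
bad_words_ids = [
    tokenizer.encode(tag.strip(), add_special_tokens=False)
    for tag in ban_tags.split(",")
    if tag.strip()
]

config = GenerationConfig(
    do_sample=True,
    top_p=0.9,
    top_k=100,
    max_new_tokens=256,
    bad_words_ids=bad_words_ids,
)

inputs = tokenizer("1girl, solo", return_tensors="pt")
output = model.generate(**inputs, generation_config=config)
print(tokenizer.decode(output[0], skip_special_tokens=True))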