Spaces:
Running
on
Zero
Running
on
Zero
chore: add mistral model and bump dartrs version
Browse files
app.py
CHANGED
@@ -14,12 +14,12 @@ from utils import (
|
|
14 |
|
15 |
|
16 |
NORMALIZE_RATING_TAG = {
|
17 |
-
"
|
18 |
-
"
|
19 |
-
"
|
20 |
-
"
|
21 |
-
"
|
22 |
-
"
|
23 |
}
|
24 |
|
25 |
|
|
|
14 |
|
15 |
|
16 |
NORMALIZE_RATING_TAG = {
|
17 |
+
"sfw": "",
|
18 |
+
"general": "",
|
19 |
+
"sensitive": "sensitive",
|
20 |
+
"nsfw": "nsfw",
|
21 |
+
"questionable": "nsfw",
|
22 |
+
"explicit": "nsfw, explicit",
|
23 |
}
|
24 |
|
25 |
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ transformers==4.38.2
|
|
4 |
optimum[onnxruntime]==1.19.1
|
5 |
diffusers==0.27.2
|
6 |
spaces==0.26.2
|
7 |
-
dartrs==0.1.
|
|
|
|
4 |
optimum[onnxruntime]==1.19.1
|
5 |
diffusers==0.27.2
|
6 |
spaces==0.26.2
|
7 |
+
dartrs==0.1.3
|
8 |
+
dotenv
|
utils.py
CHANGED
@@ -25,33 +25,33 @@ IMAGE_SIZES = {
|
|
25 |
"640x1536": (640, 1536),
|
26 |
}
|
27 |
|
28 |
-
ASPECT_RATIO_OPTIONS:
|
29 |
-
"ultra_wide"
|
30 |
-
"wide"
|
31 |
-
"square"
|
32 |
-
"tall"
|
33 |
-
"ultra_tall"
|
34 |
-
|
35 |
-
RATING_OPTIONS:
|
36 |
-
"sfw"
|
37 |
-
"general"
|
38 |
-
"sensitive"
|
39 |
-
"nsfw"
|
40 |
-
"questionable"
|
41 |
-
"explicit"
|
42 |
-
|
43 |
-
LENGTH_OPTIONS:
|
44 |
-
"very_short"
|
45 |
-
"short"
|
46 |
-
"medium"
|
47 |
-
"long"
|
48 |
-
"very_long"
|
49 |
-
|
50 |
-
IDENTITY_OPTIONS:
|
51 |
-
"none"
|
52 |
-
"lax"
|
53 |
-
"strict"
|
54 |
-
|
55 |
|
56 |
|
57 |
PEOPLE_TAGS = [
|
|
|
25 |
"640x1536": (640, 1536),
|
26 |
}
|
27 |
|
28 |
+
ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
|
29 |
+
"ultra_wide",
|
30 |
+
"wide",
|
31 |
+
"square",
|
32 |
+
"tall",
|
33 |
+
"ultra_tall",
|
34 |
+
]
|
35 |
+
RATING_OPTIONS: list[RatingTag] = [
|
36 |
+
"sfw",
|
37 |
+
"general",
|
38 |
+
"sensitive",
|
39 |
+
"nsfw",
|
40 |
+
"questionable",
|
41 |
+
"explicit",
|
42 |
+
]
|
43 |
+
LENGTH_OPTIONS: list[LengthTag] = [
|
44 |
+
"very_short",
|
45 |
+
"short",
|
46 |
+
"medium",
|
47 |
+
"long",
|
48 |
+
"very_long",
|
49 |
+
]
|
50 |
+
IDENTITY_OPTIONS: list[IdentityTag] = [
|
51 |
+
"none",
|
52 |
+
"lax",
|
53 |
+
"strict",
|
54 |
+
]
|
55 |
|
56 |
|
57 |
PEOPLE_TAGS = [
|
v2.py
CHANGED
@@ -1,12 +1,19 @@
|
|
1 |
import time
|
2 |
import os
|
3 |
-
|
4 |
import torch
|
|
|
|
|
|
|
5 |
|
6 |
from dartrs.v2 import (
|
7 |
V2Model,
|
8 |
MixtralModel,
|
|
|
9 |
compose_prompt,
|
|
|
|
|
|
|
|
|
10 |
)
|
11 |
from dartrs.dartrs import DartTokenizer
|
12 |
from dartrs.utils import get_generation_config
|
@@ -30,11 +37,16 @@ from utils import ASPECT_RATIO_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY
|
|
30 |
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
31 |
|
32 |
ALL_MODELS = {
|
33 |
-
"dart-v2-
|
34 |
-
"repo": "p1atdev/dart-v2-
|
35 |
"type": "sft",
|
36 |
"class": MixtralModel,
|
37 |
},
|
|
|
|
|
|
|
|
|
|
|
38 |
}
|
39 |
|
40 |
|
@@ -49,21 +61,9 @@ def prepare_models(model_config: dict):
|
|
49 |
}
|
50 |
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
# tokenizer.batch_decode(
|
56 |
-
# [
|
57 |
-
# token
|
58 |
-
# for token in tokenizer.encode_plus(
|
59 |
-
# tags.strip(),
|
60 |
-
# return_tensors="pt",
|
61 |
-
# ).input_ids[0]
|
62 |
-
# if int(token) != tokenizer.unk_token_id
|
63 |
-
# ],
|
64 |
-
# skip_special_tokens=True,
|
65 |
-
# )
|
66 |
-
# )
|
67 |
|
68 |
|
69 |
@torch.no_grad()
|
@@ -102,10 +102,10 @@ class V2UI:
|
|
102 |
copyright_tags: str,
|
103 |
character_tags: str,
|
104 |
general_tags: str,
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
ban_tags: str,
|
110 |
*args,
|
111 |
) -> UpsamplingOutput:
|
@@ -120,10 +120,6 @@ class V2UI:
|
|
120 |
# character_tags = normalize_tags(self.tokenizer, character_tags)
|
121 |
# general_tags = normalize_tags(self.tokenizer, general_tags)
|
122 |
|
123 |
-
rating_tag = RATING_OPTIONS[rating_option]
|
124 |
-
aspect_ratio_tag = ASPECT_RATIO_OPTIONS[aspect_ratio_option]
|
125 |
-
length_tag = LENGTH_OPTIONS[length_option]
|
126 |
-
identity_tag = IDENTITY_OPTIONS[identity_option]
|
127 |
ban_token_ids = self.tokenizer.encode(ban_tags.strip())
|
128 |
|
129 |
prompt = compose_prompt(
|
@@ -175,7 +171,7 @@ class V2UI:
|
|
175 |
|
176 |
input_rating = gr.Radio(
|
177 |
label="Rating",
|
178 |
-
choices=list(RATING_OPTIONS
|
179 |
value="general",
|
180 |
)
|
181 |
input_aspect_ratio = gr.Radio(
|
@@ -187,13 +183,13 @@ class V2UI:
|
|
187 |
input_length = gr.Radio(
|
188 |
label="Length",
|
189 |
info="The total length of the tags.",
|
190 |
-
choices=list(LENGTH_OPTIONS
|
191 |
value="long",
|
192 |
)
|
193 |
input_identity = gr.Radio(
|
194 |
label="Keep identity",
|
195 |
info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.",
|
196 |
-
choices=list(IDENTITY_OPTIONS
|
197 |
value="none",
|
198 |
)
|
199 |
|
|
|
1 |
import time
|
2 |
import os
|
|
|
3 |
import torch
|
4 |
+
import dotenv
|
5 |
+
|
6 |
+
dotenv.load_dotenv()
|
7 |
|
8 |
from dartrs.v2 import (
|
9 |
V2Model,
|
10 |
MixtralModel,
|
11 |
+
MistralModel,
|
12 |
compose_prompt,
|
13 |
+
LengthTag,
|
14 |
+
AspectRatioTag,
|
15 |
+
RatingTag,
|
16 |
+
IdentityTag,
|
17 |
)
|
18 |
from dartrs.dartrs import DartTokenizer
|
19 |
from dartrs.utils import get_generation_config
|
|
|
37 |
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
38 |
|
39 |
ALL_MODELS = {
|
40 |
+
"dart-v2-moe-sft": {
|
41 |
+
"repo": "p1atdev/dart-v2-moe-sft",
|
42 |
"type": "sft",
|
43 |
"class": MixtralModel,
|
44 |
},
|
45 |
+
"dart-v2-sft": {
|
46 |
+
"repo": "p1atdev/dart-v2-sft",
|
47 |
+
"type": "sft",
|
48 |
+
"class": MistralModel,
|
49 |
+
},
|
50 |
}
|
51 |
|
52 |
|
|
|
61 |
}
|
62 |
|
63 |
|
64 |
+
def normalize_tags(tokenizer: DartTokenizer, tags: str):
|
65 |
+
"""Just remove unk tokens."""
|
66 |
+
return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
|
69 |
@torch.no_grad()
|
|
|
102 |
copyright_tags: str,
|
103 |
character_tags: str,
|
104 |
general_tags: str,
|
105 |
+
rating_tag: RatingTag,
|
106 |
+
aspect_ratio_tag: AspectRatioTag,
|
107 |
+
length_tag: LengthTag,
|
108 |
+
identity_tag: IdentityTag,
|
109 |
ban_tags: str,
|
110 |
*args,
|
111 |
) -> UpsamplingOutput:
|
|
|
120 |
# character_tags = normalize_tags(self.tokenizer, character_tags)
|
121 |
# general_tags = normalize_tags(self.tokenizer, general_tags)
|
122 |
|
|
|
|
|
|
|
|
|
123 |
ban_token_ids = self.tokenizer.encode(ban_tags.strip())
|
124 |
|
125 |
prompt = compose_prompt(
|
|
|
171 |
|
172 |
input_rating = gr.Radio(
|
173 |
label="Rating",
|
174 |
+
choices=list(RATING_OPTIONS),
|
175 |
value="general",
|
176 |
)
|
177 |
input_aspect_ratio = gr.Radio(
|
|
|
183 |
input_length = gr.Radio(
|
184 |
label="Length",
|
185 |
info="The total length of the tags.",
|
186 |
+
choices=list(LENGTH_OPTIONS),
|
187 |
value="long",
|
188 |
)
|
189 |
input_identity = gr.Radio(
|
190 |
label="Keep identity",
|
191 |
info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.",
|
192 |
+
choices=list(IDENTITY_OPTIONS),
|
193 |
value="none",
|
194 |
)
|
195 |
|