p1atdev committed
Commit b783233
1 Parent(s): 6994082

chore: add mistral model and bump dartrs version

Files changed (4):
  1. app.py +6 -6
  2. requirements.txt +2 -1
  3. utils.py +27 -27
  4. v2.py +25 -29
app.py CHANGED
@@ -14,12 +14,12 @@ from utils import (
 
 
 NORMALIZE_RATING_TAG = {
-    "<|rating:sfw|>": "",
-    "<|rating:general|>": "",
-    "<|rating:sensitive|>": "sensitive",
-    "<|rating:nsfw|>": "nsfw",
-    "<|rating:questionable|>": "nsfw",
-    "<|rating:explicit|>": "nsfw, explicit",
+    "sfw": "",
+    "general": "",
+    "sensitive": "sensitive",
+    "nsfw": "nsfw",
+    "questionable": "nsfw",
+    "explicit": "nsfw, explicit",
 }
 
 
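Note: the keys of NORMALIZE_RATING_TAG now use the plain RatingTag strings from dartrs 0.1.3 instead of the <|rating:...|> special-token spellings. A minimal sketch of how such a mapping can be applied; the helper below is illustrative and is not part of app.py in this commit:

    NORMALIZE_RATING_TAG = {
        "sfw": "",
        "general": "",
        "sensitive": "sensitive",
        "nsfw": "nsfw",
        "questionable": "nsfw",
        "explicit": "nsfw, explicit",
    }

    def append_rating(prompt: str, rating_tag: str) -> str:
        # Hypothetical helper: look up the normalized rating text and append
        # it only when the mapping is non-empty (sfw/general map to "").
        suffix = NORMALIZE_RATING_TAG.get(rating_tag, "")
        return f"{prompt}, {suffix}" if suffix else prompt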
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformers==4.38.2
 optimum[onnxruntime]==1.19.1
 diffusers==0.27.2
 spaces==0.26.2
-dartrs==0.1.2
+dartrs==0.1.3
+dotenv
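Note: dartrs 0.1.3 provides the MistralModel class and the plain-string tag types used in v2.py below, and dotenv is added so HF_TOKEN can be read from a local .env file during development. A sketch of that pattern, assuming the bare `dotenv` pin resolves to a package exposing load_dotenv (python-dotenv does):

    import os

    import dotenv

    # Load variables from a local .env file (e.g. an HF_TOKEN entry) if one
    # exists; inside the hosted Space the variable is already set in the
    # environment and load_dotenv() finds nothing to do.
    dotenv.load_dotenv()
    HF_TOKEN = os.getenv("HF_TOKEN", None)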
utils.py CHANGED
@@ -25,33 +25,33 @@ IMAGE_SIZES = {
     "640x1536": (640, 1536),
 }
 
-ASPECT_RATIO_OPTIONS: dict[str, AspectRatioTag] = {
-    "ultra_wide": "<|aspect_ratio:ultra_wide|>",
-    "wide": "<|aspect_ratio:wide|>",
-    "square": "<|aspect_ratio:square|>",
-    "tall": "<|aspect_ratio:tall|>",
-    "ultra_tall": "<|aspect_ratio:ultra_tall|>",
-}
-RATING_OPTIONS: dict[str, RatingTag] = {
-    "sfw": "<|rating:sfw|>",
-    "general": "<|rating:general|>",
-    "sensitive": "<|rating:sensitive|>",
-    "nsfw": "<|rating:nsfw|>",
-    "questionable": "<|rating:questionable|>",
-    "explicit": "<|rating:explicit|>",
-}
-LENGTH_OPTIONS: dict[str, LengthTag] = {
-    "very_short": "<|length:very_short|>",
-    "short": "<|length:short|>",
-    "medium": "<|length:medium|>",
-    "long": "<|length:long|>",
-    "very_long": "<|length:very_long|>",
-}
-IDENTITY_OPTIONS: dict[str, IdentityTag] = {
-    "none": "<|identity:none|>",
-    "lax": "<|identity:lax|>",
-    "strict": "<|identity:strict|>",
-}
+ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
+    "ultra_wide",
+    "wide",
+    "square",
+    "tall",
+    "ultra_tall",
+]
+RATING_OPTIONS: list[RatingTag] = [
+    "sfw",
+    "general",
+    "sensitive",
+    "nsfw",
+    "questionable",
+    "explicit",
+]
+LENGTH_OPTIONS: list[LengthTag] = [
+    "very_short",
+    "short",
+    "medium",
+    "long",
+    "very_long",
+]
+IDENTITY_OPTIONS: list[IdentityTag] = [
+    "none",
+    "lax",
+    "strict",
+]
 
 
 PEOPLE_TAGS = [
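Note: because dartrs 0.1.3 accepts the tag strings directly, the dicts that mapped UI labels to <|...|> special tokens collapse into flat lists, and the selected radio value can be passed through unchanged. A sketch of the pattern, mirroring the gr.Radio wiring shown in v2.py below:

    import gradio as gr

    RATING_OPTIONS = ["sfw", "general", "sensitive", "nsfw", "questionable", "explicit"]

    # The radio value ("general", "nsfw", ...) is already a valid RatingTag,
    # so no lookup table is needed before it reaches the prompt composer.
    input_rating = gr.Radio(
        label="Rating",
        choices=list(RATING_OPTIONS),
        value="general",
    )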
v2.py CHANGED
@@ -1,12 +1,19 @@
 import time
 import os
-
 import torch
+import dotenv
+
+dotenv.load_dotenv()
 
 from dartrs.v2 import (
     V2Model,
     MixtralModel,
+    MistralModel,
     compose_prompt,
+    LengthTag,
+    AspectRatioTag,
+    RatingTag,
+    IdentityTag,
 )
 from dartrs.dartrs import DartTokenizer
 from dartrs.utils import get_generation_config
@@ -30,11 +37,16 @@ from utils import ASPECT_RATIO_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY
 HF_TOKEN = os.getenv("HF_TOKEN", None)
 
 ALL_MODELS = {
-    "dart-v2-mixtral-160m-sft": {
-        "repo": "p1atdev/dart-v2-mixtral-160m-sft-8",
+    "dart-v2-moe-sft": {
+        "repo": "p1atdev/dart-v2-moe-sft",
         "type": "sft",
         "class": MixtralModel,
     },
+    "dart-v2-sft": {
+        "repo": "p1atdev/dart-v2-sft",
+        "type": "sft",
+        "class": MistralModel,
+    },
 }
 
 
@@ -49,21 +61,9 @@ def prepare_models(model_config: dict):
 }
 
 
-# def normalize_tags(tokenizer: PreTrainedTokenizerBase, tags: str):
-#     """Just remove unk tokens."""
-#     return ", ".join(
-#         tokenizer.batch_decode(
-#             [
-#                 token
-#                 for token in tokenizer.encode_plus(
-#                     tags.strip(),
-#                     return_tensors="pt",
-#                 ).input_ids[0]
-#                 if int(token) != tokenizer.unk_token_id
-#             ],
-#             skip_special_tokens=True,
-#         )
-#     )
+def normalize_tags(tokenizer: DartTokenizer, tags: str):
+    """Just remove unk tokens."""
+    return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
 
 
 @torch.no_grad()
@@ -102,10 +102,10 @@ class V2UI:
         copyright_tags: str,
         character_tags: str,
         general_tags: str,
-        rating_option: str,
-        aspect_ratio_option: str,
-        length_option: str,
-        identity_option: str,
+        rating_tag: RatingTag,
+        aspect_ratio_tag: AspectRatioTag,
+        length_tag: LengthTag,
+        identity_tag: IdentityTag,
         ban_tags: str,
         *args,
     ) -> UpsamplingOutput:
@@ -120,10 +120,6 @@
         # character_tags = normalize_tags(self.tokenizer, character_tags)
        # general_tags = normalize_tags(self.tokenizer, general_tags)
 
-        rating_tag = RATING_OPTIONS[rating_option]
-        aspect_ratio_tag = ASPECT_RATIO_OPTIONS[aspect_ratio_option]
-        length_tag = LENGTH_OPTIONS[length_option]
-        identity_tag = IDENTITY_OPTIONS[identity_option]
         ban_token_ids = self.tokenizer.encode(ban_tags.strip())
 
         prompt = compose_prompt(
@@ -175,7 +171,7 @@
 
         input_rating = gr.Radio(
            label="Rating",
-            choices=list(RATING_OPTIONS.keys()),
+            choices=list(RATING_OPTIONS),
            value="general",
        )
        input_aspect_ratio = gr.Radio(
@@ -187,13 +183,13 @@
         input_length = gr.Radio(
            label="Length",
            info="The total length of the tags.",
-            choices=list(LENGTH_OPTIONS.keys()),
+            choices=list(LENGTH_OPTIONS),
            value="long",
        )
        input_identity = gr.Radio(
            label="Keep identity",
            info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.",
-            choices=list(IDENTITY_OPTIONS.keys()),
+            choices=list(IDENTITY_OPTIONS),
            value="none",
        )
 
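Note: the commented-out transformers-based normalize_tags is replaced by a live DartTokenizer-based version that simply drops <|unk|> tokens. A usage sketch; the from_pretrained call and repo id below are assumptions for illustration, since the tokenizer is constructed elsewhere in v2.py and not shown in this diff:

    from dartrs.dartrs import DartTokenizer

    def normalize_tags(tokenizer: DartTokenizer, tags: str):
        """Just remove unk tokens."""
        return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])

    # Assumed setup (not part of this diff): a tokenizer matching the model repo.
    # tokenizer = DartTokenizer.from_pretrained("p1atdev/dart-v2-moe-sft")
    # normalize_tags(tokenizer, "1girl, blue sky, some unknown tag")
    # -> tags the tokenizer knows are kept; anything mapped to <|unk|> is removed.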