p1atdev committed on
Commit
2c29184
1 Parent(s): 59bc547

Upload tokenization_dart.py

Files changed (1)
  1. tokenization_dart.py +10 -23
tokenization_dart.py CHANGED
@@ -1,7 +1,5 @@
 import logging
-import json
-from typing import Dict, List
-from pydantic.dataclasses import dataclass
+from typing import List

 from transformers import PreTrainedTokenizerFast
 from tokenizers.decoders import Decoder
@@ -39,35 +37,24 @@ PROMPT_TEMPLATE = (
     "{{ '</character>' }}"

     "{{ '<general>' }}"
+    # length token
+    "{% if 'length' not in messages or messages['length'] is none %}"
+    "{{ '<|long|>' }}"
+    "{% else %}"
+    "{{ messages['length'] }}"
+    "{% endif %}"
+
+    # general token
     "{% if 'general' not in messages or messages['general'] is none %}"
     "{{ '' }}"
     "{% else %}"
     "{{ messages['general'] }}"
     "{% endif %}"
+    "{{ '<|input_end|>' }}"
 ).strip()
 # fmt: on


-@dataclass
-class Category:
-    name: str
-    bos_token_id: int
-    eos_token_id: int
-
-
-@dataclass
-class TagCategoryConfig:
-    categories: Dict[str, Category]
-    category_to_token_ids: Dict[str, List[int]]
-
-
-def load_tag_category_config(config_json: str):
-    with open(config_json, "rb") as file:
-        config: TagCategoryConfig = TagCategoryConfig(**json.loads(file.read()))
-
-    return config
-
-
 class DartDecoder:
     def __init__(self, special_tokens: List[str]):
         self.special_tokens = list(special_tokens)
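For reference, a minimal sketch (not part of this commit) of how the changed template fragment behaves: it renders only the lines touched above with plain jinja2, so you can see the `<|long|>` fallback and the trailing `<|input_end|>`. The `messages` values and the `<|short|>` token below are illustrative stand-ins, not taken from the repository.

# Minimal sketch, not part of the commit: renders only the template fragment
# changed above to show the new length fallback and <|input_end|> suffix.
from jinja2 import Template

fragment = (
    "{{ '<general>' }}"
    "{% if 'length' not in messages or messages['length'] is none %}"
    "{{ '<|long|>' }}"
    "{% else %}"
    "{{ messages['length'] }}"
    "{% endif %}"
    "{% if 'general' not in messages or messages['general'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['general'] }}"
    "{% endif %}"
    "{{ '<|input_end|>' }}"
)
template = Template(fragment)

# No 'length' key: the template falls back to the <|long|> token.
print(template.render(messages={"general": "1girl, solo"}))
# <general><|long|>1girl, solo<|input_end|>

# An explicit length token (value here is only a stand-in) overrides the default.
print(template.render(messages={"length": "<|short|>", "general": "1girl, solo"}))
# <general><|short|>1girl, solo<|input_end|>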