p1atdev committed
Commit 289c297
1 Parent(s): d80fbc9

Update README.md

Files changed (1): README.md (+24 -3)
README.md CHANGED
@@ -38,7 +38,7 @@ MODEL_NAME = "p1atdev/dart-v1-base"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) # trust_remote_code is required for tokenizer
 model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
 
-prompt = "<|bos|><rating>rating:sfw, rating:general</rating><copyright>original</copyright><character></character><general>1girl, "
+prompt = "<|bos|><rating>rating:sfw, rating:general</rating><copyright>original</copyright><character></character><general>1girl"
 inputs = tokenizer(prompt, return_tensors="pt").input_ids
 
 with torch.no_grad():
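
The snippets above and below reference a `generation_config` that is defined outside the hunks shown in this diff. As a hedged sketch (assuming the model repository ships a `generation_config.json`, which this diff does not show), it could be loaded like so:

```py
from transformers import GenerationConfig

# Hypothetical setup, not part of this commit: load the repository's
# generation defaults (assumes a generation_config.json exists in the repo).
generation_config = GenerationConfig.from_pretrained("p1atdev/dart-v1-base")
```
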
@@ -48,6 +48,23 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 # rating:sfw, rating:general, original, 1girl, ahoge, black hair, blue eyes, blush, closed mouth, ear piercing, earrings, jewelry, looking at viewer, mole, mole under eye, piercing, portrait, shirt, short hair, solo, white shirt
 ```
 
+You can use `tokenizer.apply_chat_template` to simplify constructing prompts:
+
+```py
+inputs = tokenizer.apply_chat_template({
+    "rating": "rating:sfw, rating:general",
+    "copyright": "original",
+    "character": "",
+    "general": "1girl"
+}, tokenize=True) # tokenize=False to preview the prompt
+# same as the input_ids of "<|bos|><rating>rating:sfw, rating:general</rating><copyright>original</copyright><character></character><general>1girl"
+
+with torch.no_grad():
+    outputs = model.generate(inputs, generation_config=generation_config)
+```
+
+See the [chat templating documentation](https://huggingface.co/docs/transformers/main/en/chat_templating) for more details about `apply_chat_template`.
+
 #### Flash attention (optional)
 
 Using flash attention can optimize computations, but it is currently only compatible with Linux.
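
For context, enabling flash attention in transformers typically looks like the sketch below. This is an illustrative example, not the README's own snippet (the flash attention hunk is not part of this diff), and it assumes the `flash-attn` package is installed and the architecture supports the `attn_implementation` switch:

```py
import torch
from transformers import AutoModelForCausalLM

# Illustrative only (assumption, not from this commit): flash attention
# requires fp16/bf16 weights and the flash-attn package (Linux only).
model = AutoModelForCausalLM.from_pretrained(
    "p1atdev/dart-v1-base",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```
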
@@ -86,8 +103,12 @@ ort_model = ORTModelForCausalLM.from_pretrained(MODEL_NAME)
 # quantized version
 # ort_model = ORTModelForCausalLM.from_pretrained(MODEL_NAME, file_name="model_quantized.onnx")
 
-prompt = "<|bos|><rating>rating:sfw, rating:general</rating><copyright>original</copyright><character></character><general>1girl, "
-inputs = tokenizer(prompt, return_tensors="pt").input_ids
+inputs = tokenizer.apply_chat_template({
+    "rating": "rating:sfw, rating:general",
+    "copyright": "original",
+    "character": "",
+    "general": "1girl"
+}, tokenize=True)
 
 with torch.no_grad():
     outputs = model.generate(inputs, generation_config=generation_config)
 
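
One note on the new `apply_chat_template` calls: in stock transformers, `tokenize=True` returns a plain list of token ids unless a tensor format is requested. If `model.generate` complains about the input type, a hedged variant (assuming the custom dart tokenizer forwards standard tokenization kwargs, which this diff does not confirm) is to pass `return_tensors="pt"`:

```py
# Sketch under assumptions, not part of this commit: request a PyTorch tensor
# directly from apply_chat_template so generate() receives input_ids as a tensor.
# Reuses tokenizer, model, and generation_config from the snippets above.
inputs = tokenizer.apply_chat_template({
    "rating": "rating:sfw, rating:general",
    "copyright": "original",
    "character": "",
    "general": "1girl"
}, tokenize=True, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(inputs, generation_config=generation_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```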