squarelike
committed on
Commit
•
2aa4a17
1
Parent(s):
db7b33e
Update README.md
Browse files
README.md
CHANGED
@@ -36,7 +36,7 @@ I trained with 1x A6000 GPUs for 90 hours.
|
|
36 |
```python
|
37 |
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
|
38 |
import torch
|
39 |
-
repo = "squarelike/Gugugo-koen-7B-V1.1
|
40 |
model = AutoModelForCausalLM.from_pretrained(
|
41 |
repo,
|
42 |
load_in_4bit=True
|
@@ -56,7 +56,7 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|
56 |
|
57 |
return False
|
58 |
|
59 |
-
stop_words_ids = torch.tensor([[829, 45107, 29958], [1533, 45107, 29958], [829, 45107, 29958], [21106, 45107, 29958]])
|
60 |
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
61 |
|
62 |
def gen(lan="en", x=""):
|
@@ -69,13 +69,10 @@ def gen(lan="en", x=""):
|
|
69 |
prompt,
|
70 |
return_tensors='pt',
|
71 |
return_token_type_ids=False
|
72 |
-
),
|
73 |
-
max_new_tokens=
|
74 |
temperature=0.1,
|
75 |
-
no_repeat_ngram_size=10,
|
76 |
-
early_stopping=True,
|
77 |
do_sample=True,
|
78 |
-
eos_token_id=2,
|
79 |
stopping_criteria=stopping_criteria
|
80 |
)
|
81 |
return tokenizer.decode(gened[0][1:]).replace(prompt+" ", "").replace("</끝>", "")
|
|
|
36 |
```python
|
37 |
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
|
38 |
import torch
|
39 |
+
repo = "squarelike/Gugugo-koen-7B-V1.1"
|
40 |
model = AutoModelForCausalLM.from_pretrained(
|
41 |
repo,
|
42 |
load_in_4bit=True
|
|
|
56 |
|
57 |
return False
|
58 |
|
59 |
+
stop_words_ids = torch.tensor([[829, 45107, 29958], [1533, 45107, 29958], [829, 45107, 29958], [21106, 45107, 29958]]).to("cuda")
|
60 |
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
61 |
|
62 |
def gen(lan="en", x=""):
|
|
|
69 |
prompt,
|
70 |
return_tensors='pt',
|
71 |
return_token_type_ids=False
|
72 |
+
).to("cuda"),
|
73 |
+
max_new_tokens=2000,
|
74 |
temperature=0.1,
|
|
|
|
|
75 |
do_sample=True,
|
|
|
76 |
stopping_criteria=stopping_criteria
|
77 |
)
|
78 |
return tokenizer.decode(gened[0][1:]).replace(prompt+" ", "").replace("</끝>", "")
|