koziev ilya
commited on
Commit
·
e3f5def
1
Parent(s):
a31c007
немного причесал код, убрал лишние манипуляции с выдачей gpt
Browse files
README.md
CHANGED
@@ -44,6 +44,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
45 |
|
46 |
tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
|
|
|
47 |
model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
|
48 |
model.to(device)
|
49 |
|
@@ -51,8 +52,10 @@ model.to(device)
|
|
51 |
# В конце добавляем символ "#"
|
52 |
input_text = """<s>- Как тебя зовут?
|
53 |
- Джульетта Мао #"""
|
54 |
-
|
55 |
-
|
|
|
|
|
56 |
|
57 |
output_sequences = model.generate(
|
58 |
input_ids=encoded_prompt,
|
@@ -63,12 +66,10 @@ output_sequences = model.generate(
|
|
63 |
repetition_penalty=1.2,
|
64 |
do_sample=True,
|
65 |
num_return_sequences=1,
|
66 |
-
pad_token_id=
|
67 |
)
|
68 |
|
69 |
-
|
70 |
-
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
|
71 |
text = text[: text.find('</s>')]
|
72 |
-
text = text[text.find('#')+1:].strip() # Результат генерации содержит входную строку, поэтому отрезаем ее до символа "#".
|
73 |
print(text)
|
74 |
```
|
|
|
44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
45 |
|
46 |
tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
|
47 |
+
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
|
48 |
model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
|
49 |
model.to(device)
|
50 |
|
|
|
52 |
# В конце добавляем символ "#"
|
53 |
input_text = """<s>- Как тебя зовут?
|
54 |
- Джульетта Мао #"""
|
55 |
+
#input_text = """<s>- Что Предтечи забрали у Предшественников?
|
56 |
+
#- Они узурпировали у них Мантию — защиту всего живого в галактике #"""
|
57 |
+
|
58 |
+
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device)
|
59 |
|
60 |
output_sequences = model.generate(
|
61 |
input_ids=encoded_prompt,
|
|
|
66 |
repetition_penalty=1.2,
|
67 |
do_sample=True,
|
68 |
num_return_sequences=1,
|
69 |
+
pad_token_id=tokenizer.pad_token_id,
|
70 |
)
|
71 |
|
72 |
+
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:]
|
|
|
73 |
text = text[: text.find('</s>')]
|
|
|
74 |
print(text)
|
75 |
```
|