koziev ilya commited on
Commit
5af1322
1 Parent(s): 82f47f0

Добавлен пример использования модели

Browse files
Files changed (1) hide show
  1. README.md +39 -2
README.md CHANGED
@@ -3,10 +3,47 @@ license: unlicense
3
  ---
4
 
5
 
6
- ## Incomplete Utterance Restoration
7
-
8
 
9
  Генеративная модель на основе [sberbank-ai/rugpt3large_based_on_gpt2](https://huggingface.co/sberbank-ai/rugpt3large_based_on_gpt2)
10
  для восстановления полного текста реплик в диалоге из контекста.
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
4
 
5
 
6
+ ## Задача Incomplete Utterance Restoration
 
7
 
8
  Генеративная модель на основе [sberbank-ai/rugpt3large_based_on_gpt2](https://huggingface.co/sberbank-ai/rugpt3large_based_on_gpt2)
9
  для восстановления полного текста реплик в диалоге из контекста.
10
 
11
 
12
+ ## Пример использования
13
+
14
+ ```
15
+ import torch
16
+ from transformers import AutoTokenizer, AutoModelForCausalLM
17
+
18
+
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
22
+ model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
23
+ model.to(device)
24
+
25
+ # На вход модели подаем последние 2-3 реплики диалога. Каждая реплика на отдельной строке, начинается с символа "-"
26
+ # В конце добавляем символ "#"
27
+ input_text = """<s>- Как тебя зовут?
28
+ - Джульетта Мао #"""
29
+ encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
30
+ encoded_prompt = encoded_prompt.to(device)
31
+
32
+ output_sequences = model.generate(
33
+ input_ids=encoded_prompt,
34
+ max_length=100,
35
+ temperature=1.0,
36
+ top_k=30,
37
+ top_p=0.85,
38
+ repetition_penalty=1.2,
39
+ do_sample=True,
40
+ num_return_sequences=1,
41
+ pad_token_id=0
42
+ )
43
+
44
+ generated_sequence = output_sequences[0].tolist()
45
+ text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
46
+ text = text[: text.find('</s>')]
47
+ text = text[text.find('#')+1:].strip() # Результат генерации содержит входную строку, поэтому отрезаем ее до символа "#".
48
+ print(text)
49
+ ```