Ponimash commited on
Commit
bc628f9
1 Parent(s): 9ad4448

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +154 -1
README.md CHANGED
@@ -1,3 +1,156 @@
1
  ---
2
- license: apache-2.0
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ license: gpl-3.0
3
+ datasets:
4
+ - wikimedia/wikipedia
5
  ---
6
+
7
+ # FractalGPT/EmbedderDecoder
8
+
9
+ * **Оригинальная модель**
10
+ [[ai-forever/rugpt3small_based_on_gpt2](https://huggingface.co/ai-forever/rugpt3small_based_on_gpt2)]
11
+
12
+ * **Код генерации взят частично отсюда**
13
+ [[vector2text](https://github.com/Koziev/vector2text)]
14
+
15
+ * Заменен эмбеддер
16
+ * Вместо нулей вектор дополняется квадратами чисел (далее можно кубами и т.д.)
17
+ * Создан класс для генератора
18
+ * Добавлен ранжировщик
19
+ * Заменена модель вместо large — small
20
+ * Убран top_p
21
+
22
+ * **Пример использования**
23
+
24
+
25
+ ```python
26
+ import torch
27
+ import numpy as np
28
+ from torch.nn import functional as F
29
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
30
+
31
+ def top_filtering(logits, top_k):
32
+ """
33
+ Фильтрация top-k, в фильтрации top-p в этой задаче особо смысла нет
34
+ """
35
+ assert logits.dim() == 1
36
+ top_k = min(top_k, logits.size(-1))
37
+ if top_k > 0:
38
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
39
+ logits[indices_to_remove] = -float('Inf')
40
+
41
+ return logits
42
+
43
+
44
+ class TextEmbdGenerator:
45
+ def __init__(self, name_or_path, sbert, device = None):
46
+ """
47
+ Инициализация генератора текста с моделью и токенизатором.
48
+ name_or_path: путь до модели токенизатора или ее имя для загрузки из Hugging Face.
49
+ sbert: модель для ранжирования (такая же что и создает эбеддинги)
50
+ """
51
+ self.device = device
52
+
53
+ if self.device == None:
54
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+
56
+
57
+ self.tokenizer = GPT2Tokenizer.from_pretrained(name_or_path)
58
+ self.model = GPT2LMHeadModel.from_pretrained(name_or_path).to(self.device)
59
+ self.sbert = sbert
60
+
61
+
62
+ def generate_embedding(self, embd, prompt = '', temperature=0.26, top_k=4, max_len=100):
63
+ """
64
+ Генерация текста на основе начального эмбеддинга и заданного начального текста.
65
+ """
66
+ vector = np.concatenate([embd,embd**2])
67
+ current_output_ids = self.tokenizer.encode(prompt)
68
+
69
+ embedding = torch.FloatTensor([list(vector)]).to(self.device)
70
+
71
+ while len(current_output_ids) < max_len:
72
+ with torch.no_grad():
73
+ token_embeddings = self.model.base_model.wte(torch.LongTensor(current_output_ids).to(self.device))
74
+ input_vectors = torch.vstack((embedding, token_embeddings)).unsqueeze(dim=0)
75
+ output_model = self.model(inputs_embeds=input_vectors)
76
+
77
+ logits = output_model.logits
78
+ if isinstance(logits, tuple):
79
+ logits = logits[0]
80
+ logits = logits[0, -1, :]
81
+ logits /= temperature
82
+ logits = top_filtering(logits, top_k)
83
+ probs = F.softmax(logits, dim=-1)
84
+
85
+ prev = torch.multinomial(probs, 1)
86
+ if prev.item() == self.tokenizer.eos_token_id:
87
+ break
88
+ current_output_ids.append(prev.item())
89
+
90
+ output_text = self.tokenizer.decode(current_output_ids)
91
+ return output_text.split('\n')[0]
92
+
93
+
94
+ def cosine_similarity(self, x, y):
95
+ """Вычисление косинусного сходства."""
96
+ return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
97
+
98
+ def generate_with_ranker(self, embd, prompt = '', seq=10, temperature=0.6, top_k=10, max_len=100):
99
+ """Генерация и ранжирование текста. Поумолчанию создаются 10 текстов"""
100
+ sequences = [self.generate_embedding(embd, prompt, temperature, top_k, max_len) for _ in range(seq)]
101
+ sequences = list(set(sequences)) # Удаление дубликатов
102
+
103
+ # Ранжирование
104
+ embeddings = self.sbert.encode(sequences)
105
+ similarities = [self.cosine_similarity(embd, emb) for emb in embeddings]
106
+ best_index = np.argmax(similarities)
107
+
108
+ return sequences[best_index]
109
+ ```
110
+
111
+ ---
112
+
113
+ ```bash
114
+ pip install sentence-transformers -q
115
+ ```
116
+
117
+ ```python
118
+ from sentence_transformers import SentenceTransformer
119
+
120
+ sbert = SentenceTransformer('FractalGPT/SbertDistil')
121
+ generator = TextEmbdGenerator('FractalGPT/EmbedderDecoder', sbert)
122
+ ```
123
+
124
+ ```python
125
+ embd = sbert.encode('там живут англичане')
126
+ generator.generate_with_ranker(embd, prompt = 'он всегда был в')
127
+ ```
128
+ ```bash
129
+ >>> он всегда был в Англии.
130
+ ```
131
+
132
+
133
+ ```python
134
+ embd = sbert.encode('там живут немцы')
135
+ generator.generate_with_ranker(embd, prompt = 'он всегда был в')
136
+ ```
137
+ ```bash
138
+ >>> он всегда был в Германии
139
+ ```
140
+
141
+ ```python
142
+ embd = sbert.encode('он сделает вывод на основе анализа ситуации')
143
+ generator.generate_with_ranker(embd)
144
+ ```
145
+ ```bash
146
+ >>> в процессе анализа ситуации необходимо выяснить:
147
+ ```
148
+
149
+
150
+ ```python
151
+ embd = sbert.encode('интересный фильм смотрел, фильм понравился')
152
+ generator.generate_with_ranker(embd, seq=5)
153
+ ```
154
+ ```bash
155
+ >>> фильм был снят по мотивам произведений
156
+ ```