ai-forever commited on
Commit
24341af
1 Parent(s): a460b42

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -16
README.md CHANGED
@@ -163,25 +163,15 @@ We compare our solution with both open automatic spell checkers and the ChatGPT
163
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
164
 
165
  tokenizer = AutoTokenizer.from_pretrained("ai-forever/sage-fredt5-distilled-95m")
166
- model = AutoModelForSeq2SeqLM.from_pretrained("ai-forever/sage-fredt5-distilled-95m")
167
- model.to("cuda:0")
168
 
169
  sentence = "И не чсно прохожим в этот день непогожйи почему я веселый такйо"
170
- text = "<LM>" + sentence
171
- with torch.inference_mode():
172
- encodings = tokenizer(text, max_length=None, padding="longest", truncation=False, return_tensors="pt")
173
- for k, v in encodings.items():
174
- encodings[k] = v.to("cuda:0")
175
- res = model.generate(
176
- **encodings,
177
- use_cache=True,
178
- max_length = encodings["input_ids"].size(1) * 1.5
179
- )
180
- res = res.cpu().tolist()
181
- res = tokenizer.batch_decode(res, skip_special_tokens=True)
182
-
183
- print(res)
184
  # ["И не ясно прохожим в этот день непогожий, почему я весёлый такой?"]
 
185
  ```
186
 
187
  ## Limitations
 
163
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
164
 
165
  tokenizer = AutoTokenizer.from_pretrained("ai-forever/sage-fredt5-distilled-95m")
166
+ model = AutoModelForSeq2SeqLM.from_pretrained("ai-forever/sage-fredt5-distilled-95m", device_map='cuda')
 
167
 
168
  sentence = "И не чсно прохожим в этот день непогожйи почему я веселый такйо"
169
+ inputs = tokenizer(sentence, max_length=None, padding="longest", truncation=False, return_tensors="pt")
170
+ outputs = model.generate(**inputs.to(model.device), max_length = inputs["input_ids"].size(1) * 1.5)
171
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
172
+
 
 
 
 
 
 
 
 
 
 
173
  # ["И не ясно прохожим в этот день непогожий, почему я весёлый такой?"]
174
+
175
  ```
176
 
177
  ## Limitations