EricaCorral committed on
Commit
d4f6c5c
1 Parent(s): 94afdd8

Marian classes didn't work; rolled back to AutoTokenizer

Files changed (1)
  1. app.py +6 -5
app.py CHANGED
@@ -1,18 +1,19 @@
 from pypinyin import pinyin
-from transformers import MarianMTModel, MarianTokenizer
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from LAC import LAC
 import gradio as gr
 import torch
 
-model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
+model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
 model.eval()
-tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
+tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
 lac = LAC(mode="seg")
 
 def make_request(chinese_text):
     with torch.no_grad():
-        generated_tokens = model.generate(**tokenizer(chinese_text, return_tensors="pt", padding=True))
-        return [tokenizer.decode(generated_tokens, skip_special_tokens=True) for t in generated_tokens]
+        encoded_zh = tokenizer.prepare_seq2seq_batch([chinese_text], return_tensors="pt")
+        generated_tokens = model.generate(**encoded_zh)
+        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
 
 def generatepinyin(input):
     pinyin_list = pinyin(input)
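
Side note on the new make_request body: prepare_seq2seq_batch is deprecated in recent transformers releases, and calling the tokenizer directly is the suggested replacement. Below is a minimal sketch of that alternative against the same Helsinki-NLP/opus-mt-zh-en checkpoint; the translate helper is a hypothetical stand-in for app.py's make_request, not code from this repository.

# Sketch only: same checkpoint as app.py, but the deprecated
# prepare_seq2seq_batch call is replaced with a direct tokenizer call.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")

def translate(chinese_text):
    with torch.no_grad():
        # Equivalent encoding without prepare_seq2seq_batch
        encoded_zh = tokenizer([chinese_text], return_tensors="pt", padding=True)
        generated_tokens = model.generate(**encoded_zh)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

print(translate("你好，世界"))  # prints a one-element list with the English translation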