# Scraped file-view header (not code): hahafofo · "add chatglm" · 390173a · 3.29 kB
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from .singleton import Singleton
from transformers import (
EncoderDecoderModel,
AutoTokenizer
)
# Torch device string: "cuda" when a CUDA device is available, otherwise "cpu".
# NOTE(review): nothing in the visible part of this module reads `device`;
# confirm whether models/inputs were meant to be moved onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"
@Singleton
class Models(object):
    """Lazy container for the translation models and their tokenizers.

    Nothing is loaded when the object is created. The first read of one of
    the supported attributes falls through to ``__getattr__``, which loads
    the matching model/tokenizer pair, stores both on the instance and
    returns the requested one. Later reads find the stored attribute through
    normal lookup and never reach ``__getattr__`` again.

    Supported lazy attributes:
        zh2en_model, zh2en_tokenizer
        en2zh_model, en2zh_tokenizer
        wenyanwen2modern_tokenizer, wenyanwen2modern_model
    """

    def __getattr__(self, item):
        """Load the pair that provides ``item`` and return ``item``.

        Python calls this hook only after normal attribute lookup has failed.

        Args:
            item: Name of the attribute being read.

        Returns:
            The freshly loaded model or tokenizer stored under ``item``.

        Raises:
            AttributeError: If ``item`` is not one of the supported lazy
                attributes. Falling back to ``getattr(self, item)`` for an
                unknown name would re-enter this method without end and
                surface as ``RecursionError``, which breaks ``hasattr`` and
                ``getattr(obj, name, default)``.
        """
        # Defensive: already stored on the instance, return it directly
        # without going through attribute lookup again.
        if item in self.__dict__:
            return self.__dict__[item]

        if item in ('zh2en_model', 'zh2en_tokenizer',):
            self.zh2en_model, self.zh2en_tokenizer = self.load_zh2en_model()
        elif item in ('en2zh_model', 'en2zh_tokenizer',):
            self.en2zh_model, self.en2zh_tokenizer = self.load_en2zh_model()
        elif item in ('wenyanwen2modern_tokenizer', 'wenyanwen2modern_model',):
            # This loader returns (tokenizer, model), the reverse of the
            # two loaders above, so the unpacking order is reversed too.
            self.wenyanwen2modern_tokenizer, self.wenyanwen2modern_model = \
                self.load_wenyanwen2modern_model()
        else:
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, item)
            )
        # The pair is now in the instance __dict__, so this resolves through
        # normal lookup and does not recurse.
        return getattr(self, item)

    @classmethod
    def load_wenyanwen2modern_model(cls):
        """Load the "raynardj/wenyanwen-ancient-translate-to-modern" checkpoint.

        Returns:
            A ``(tokenizer, model)`` tuple. Note the order: tokenizer first.
        """
        PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
        tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
        model = EncoderDecoderModel.from_pretrained(PRETRAINED)
        return tokenizer, model

    @classmethod
    def load_en2zh_model(cls):
        """Load the "Helsinki-NLP/opus-mt-en-zh" checkpoint.

        Returns:
            A ``(model, tokenizer)`` tuple; ``eval()`` has been called on
            the model.
        """
        en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
        en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
        return en2zh_model, en2zh_tokenizer

    @classmethod
    def load_zh2en_model(cls):
        """Load the 'Helsinki-NLP/opus-mt-zh-en' checkpoint.

        Returns:
            A ``(model, tokenizer)`` tuple; ``eval()`` has been called on
            the model.
        """
        zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
        zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
        return zh2en_model, zh2en_tokenizer
# Shared Models object used by the translation helpers in this module.
# Obtained through `instance()`, which is presumably supplied by the
# Singleton decorator from .singleton (not visible here — confirm).
models = Models.instance()
def wenyanwen2modern(text: str) -> str:
    """Translate ``text`` with the wenyanwen-to-modern model.

    The text is encoded as a single-item batch, truncated/padded to 128
    tokens, decoded with 3-beam search (output capped at 256 tokens), and
    returned with every space character removed.

    Args:
        text: The sentence to translate.

    Returns:
        The first (only) decoded sequence, without special tokens or spaces.
    """
    tokenizer_options = {
        "truncation": True,
        "max_length": 128,
        "padding": "max_length",
        "return_tensors": 'pt',
    }
    tokenizer = models.wenyanwen2modern_tokenizer
    encoded = tokenizer([text], **tokenizer_options)

    with torch.no_grad():
        generated_ids = models.wenyanwen2modern_model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            num_beams=3,
            max_length=256,
            # NOTE(review): hard-coded start id; presumably the tokenizer's
            # [CLS] id — confirm against the checkpoint's vocabulary.
            bos_token_id=101,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        decoded_batch = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )

    # Only one sentence was submitted, so take the single result and strip
    # the spaces from the decoded string.
    return decoded_batch[0].replace(" ", "")
def zh2en(text: str) -> str:
    """Translate ``text`` with the zh2en model and return the result.

    Args:
        text: The sentence to translate.

    Returns:
        The first (only) decoded sequence with special tokens removed.
    """
    with torch.no_grad():
        tokenizer = models.zh2en_tokenizer
        # Single-item batch encoded as PyTorch tensors.
        batch = tokenizer([text], return_tensors="pt")
        generated = models.zh2en_model.generate(**batch)
        translations = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return translations[0]
def en2zh(text: str) -> str:
    """Translate ``text`` with the en2zh model and return the result.

    Args:
        text: The sentence to translate.

    Returns:
        The first (only) decoded sequence with special tokens removed.
    """
    with torch.no_grad():
        tokenizer = models.en2zh_tokenizer
        # Single-item batch encoded as PyTorch tensors.
        batch = tokenizer([text], return_tensors="pt")
        generated = models.en2zh_model.generate(**batch)
        translations = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return translations[0]
if __name__ == "__main__":
    # Manual smoke test: run the sample through wenyanwen2modern, translate
    # the result to English, then translate that English back to Chinese.
    # The sample is bound to `source_text` rather than `input` so the builtin
    # input() is not shadowed at module scope.
    source_text = "飞流直下三千尺,疑是银河落九天"
    modern_text = wenyanwen2modern(source_text)
    en = zh2en(modern_text)
    print(source_text, en)
    zh = en2zh(en)
    print(en, zh)