|
import os |
|
import sys |
|
import uvicorn |
|
import torch |
|
from fastapi import Body, FastAPI |
|
from transformers import T5Tokenizer, MT5ForConditionalGeneration |
|
import pytorch_lightning as pl |
|
# Make the parent directory of this file importable.
sys.path.append(os.path.abspath(os.path.join(
    os.path.dirname(__file__), os.path.pardir)))

# Expose only physical GPU 5 to this process; the single visible GPU is then
# addressed as index 0 below.
os.environ["CUDA_VISIBLE_DEVICES"] = '5'
# NOTE(review): MASTER_ADDR / MASTER_PORT are presumably read by the
# distributed backend used by torch / pytorch_lightning -- confirm they are
# still needed for single-process serving.
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '6000'

# Fall back to CPU when no CUDA device is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Log the device that was actually selected (the original printed the
# literal word 'device', which carried no information).
print(f'device: {device}')

# Base MT5 weights / tokenizer files, and the fine-tuned Lightning checkpoint.
pretrain_model_path = '/cognitive_comp/ganruyi/hf_models/google/mt5-large'
model_path = '/cognitive_comp/ganruyi/fengshen/mt5_large_summary/ckpt/epoch-0-last.ckpt'

tokenizer = T5Tokenizer.from_pretrained(pretrain_model_path)
print('load tokenizer')
|
|
|
|
|
class MT5FinetuneSummary(pl.LightningModule):
    """LightningModule that holds the MT5 network under the ``model`` attribute.

    The constructor takes no arguments and builds the network from
    ``pretrain_model_path``. The request handler in this file reaches the
    underlying network through ``.model``.
    """

    def __init__(self):
        super().__init__()
        # Build the conditional-generation network from the pretrained path
        # and register it as a submodule.
        backbone = MT5ForConditionalGeneration.from_pretrained(
            pretrain_model_path)
        self.model = backbone
|
|
|
|
|
# Instantiate the Lightning module from the fine-tuned checkpoint file.
model = MT5FinetuneSummary.load_from_checkpoint(model_path)
print('load checkpoint')

# Move the module to the selected device and switch it to evaluation mode;
# both calls operate in place on `model`.
model.to(device).eval()

# ASGI application object that the route decorator below registers against.
app = FastAPI()
print('server start')
|
|
|
|
|
|
|
|
|
@app.post('/mt5_summary')
async def flask_gen(text: str = Body('', title='原文', embed=True),
                    n_sample: int = 5, length: int = 32,
                    is_beam_search: bool = False):
    """Generate candidate summaries for ``text`` with the fine-tuned MT5 model.

    Args:
        text: Source text, declared as an embedded body field. Only the first
            128 characters are used.
        n_sample: Number of sequences to return. With beam search it is also
            used as the beam width.
        length: Value passed to ``generate`` as ``max_length``.
        is_beam_search: When true, decode with beam search; otherwise sample.
            Annotated as ``bool`` so the framework converts values such as
            ``false`` / ``0`` instead of handing over a non-empty (truthy)
            string.

    Returns:
        A list with one decoded string per generated sequence.
    """
    # Slicing is a no-op for inputs of 128 characters or fewer.
    text = 'summary:' + text[:128]
    print(text)

    inputs = tokenizer.encode_plus(
        text, max_length=128, padding='max_length', truncation=True,
        return_tensors='pt')

    # Arguments shared by both decoding strategies.
    generate_kwargs = dict(
        input_ids=inputs['input_ids'].to(device),
        attention_mask=inputs['attention_mask'].to(device),
        max_length=length,
        repetition_penalty=2.5,
        num_return_sequences=n_sample,
    )
    if is_beam_search:
        generate_kwargs.update(
            num_beams=n_sample,
            length_penalty=1.0,
            early_stopping=True,
        )
    else:
        generate_kwargs.update(
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
        )

    # NOTE(review): generate() is called synchronously inside this coroutine,
    # so the handler does not yield to the event loop while it runs.
    generated_ids = model.model.generate(**generate_kwargs)

    return [
        tokenizer.decode(sample, skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)
        for sample in generated_ids
    ]
|
|
|
|
|
if __name__ == '__main__':
    # Start the uvicorn server for `app` when this file is run as a script.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=6607,
        log_level="debug",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|