"""Gradio demo that scores a (symptom, sick) text pair with an XLM-RoBERTa classifier.

NOTE(review): several imports below (paddle, paddlenlp, pandas, tqdm, cachetools,
Electra*, Auto*, pipeline, ...) are not used by the code in this file; consider
pruning them, since paddle/paddlenlp are heavy to load.
"""
import functools
import os
import huggingface_hub
from transformers import (
    XLMRobertaForSequenceClassification,
    XLMRobertaTokenizerFast,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
import random
import paddlenlp.datasets
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer, ElectraModel
import paddle
import pandas as pd
import re
import tqdm
import cachetools
import typing
import torch
import torch.utils.data as Data
import datetime
from transformers import pipeline
import gradio


def collate(batch, tokenizer, useFirstDim=True):
    """Pad tokenized examples into one id tensor and collect their labels.

    Args:
        batch: sequence of dataset items. With ``useFirstDim=True`` each item is
            indexable and ``item[0]`` is a tokenizer output (presumably a
            ``BatchEncoding``). With ``useFirstDim=False`` each item is the
            tokenizer output itself. In both cases ``item[-1]`` is used as the
            label.
        tokenizer: tokenizer whose ``pad_token_id`` is used as the padding value.
        useFirstDim: selects where the encoding lives inside each item (see above).

    Returns:
        ``(input_ids, labels)``: a LongTensor of right-padded token ids with the
        batch on the first axis, and a LongTensor of labels.
    """
    encodings = [item[0] if useFirstDim else item for item in batch]
    # Assumes each encoding holds one sequence; [0] drops its leading batch axis.
    sequences = [enc.data.get('input_ids')[0] for enc in encodings]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=tokenizer.pad_token_id)
    # NOTE(review): with useFirstDim=False, item[-1] indexes the encoding object
    # itself rather than a separate label field; confirm that path is used/valid.
    labels = [item[-1] for item in batch]
    return torch.LongTensor(input_ids), torch.LongTensor(labels)


def launchGradioNLI():
    """Download the symptom/sick classifier and serve it through a Gradio text UI.

    The UI takes two text inputs (symptom, sick) and shows one text output. The
    app is started with ``app.launch()``.
    """
    # Download the model repository from the Hugging Face Hub (local folder path).
    folder = huggingface_hub.snapshot_download('qiaokuoyuan/symptom-sick-2c')

    # Load the classifier and its tokenizer from the downloaded folder.
    model = XLMRobertaForSequenceClassification.from_pretrained(folder)
    tokenizer = XLMRobertaTokenizerFast.from_pretrained(folder)
    # Put the model into inference mode (disables dropout).
    model.eval()

    def getSickDistributionTensorByOneSymptom(symptom, sick):
        """Score a single (symptom, sick) text pair.

        Returns:
            The raw (pre-softmax) logit the model assigns to class index 1, as a
            string so Gradio can display it in a text box.
        """
        # Encode the two texts as one sentence pair. If the pair is too long,
        # only the first segment (the symptom) is truncated.
        encoded = tokenizer(
            [[symptom, sick]],
            add_special_tokens=True,
            return_tensors='pt',
            padding=True,
            truncation='only_first',
        )
        # The model is a PyTorch module, so torch.no_grad() (not paddle.no_grad())
        # is what stops autograd from recording this forward pass.
        with torch.no_grad():
            logits = model(
                input_ids=encoded['input_ids'],
                attention_mask=encoded['attention_mask'],
            ).logits
        return str(logits[0][1].item())

    app = gradio.Interface(
        fn=getSickDistributionTensorByOneSymptom,
        inputs=['text', 'text'],
        outputs='text',
    )
    app.launch()


if __name__ == '__main__':
    launchGradioNLI()