File size: 3,768 Bytes
3c82d9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288b444
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
from transformers import RobertaForTokenClassification, AutoTokenizer
model=RobertaForTokenClassification.from_pretrained('jiangchengchengNLP/Chinese_resume_extract')
tokenizer = AutoTokenizer.from_pretrained('jiangchengchengNLP/Chinese_resume_extract',do_lower_case=True)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device)
import json
label_list={
        0:'其他',
        1:'电话',
        2:'毕业时间', #毕业时间
        3:'出生日期',  #出生日期
        4:'项目名称',  #项目名称
        5:'毕业院校',  #毕业院校
        6:'职务',  #职务
        7:'籍贯',  #籍贯
        8:'学位',  #学位
        9:'性别',  #性别
        10:'姓名', #姓名
        11:'工作时间', #工作时间
        12:'落户市县', #落户市县
        13:'项目时间',  #项目时间
        14:'最高学历', #最高学历
        15:'工作单位',  #工作单位
        16:'政治面貌',  #政治面貌
        17:'工作内容', #工作内容
        18:'项目责任',  #项目责任
    }

def get_info(text):
    #文本处理
    text=text.strip()
    text=text.replace('\n',',')  # 将换行符替换为逗号
    text=text.replace('\r',',')  # 将回车符替换为逗号
    text=text.replace('\t',',')  # 将制表符替换为逗号
    text=text.replace(' ',',')  # 将空格替换为逗号
    #将连续的逗号合并成一个逗号
    while ',,' in text:
        text=text.replace(',,',',')
    block_list=[]
    if len(text)>300:
        #切块策略
        #先切分成句
        sentence_list=text.split(',')
        #然后拼接句子长度不超过300,一旦超过300,当前句子放到下一个块中
        boundary=300
        block_list=[]
        block=sentence_list[0]
        for i in range(1,len(sentence_list)):
            if len(block)+len(sentence_list[i])<=boundary:
                block+=sentence_list[i]
            else:
                block_list.append(block)
                block=sentence_list[i]
        block_list.append(block)
    else:
        block_list.append(text)
    _input = tokenizer(block_list, return_tensors='pt',padding=True,truncation=True)
    #如果有GPU,将输入数据移到GPU
    input_ids = _input['input_ids'].to(device)
    attention_mask = _input['attention_mask'].to(device)
    # 模型推理
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]

    # 获取预测的标签ID
    #print(logits.shape)
    ids = torch.argmax(logits, dim=-1)
    input_ids=input_ids.reshape(-1)
    #将张量在最后一个维度拼接,并以0为分界,拼接成句
    ids =ids.reshape(-1)
    # 按标签组合成提取内容
    extracted_info = {}
    word_list=[]
    flag=None
    for idx, label_id in enumerate(ids):
        label_id = label_id.item()
        if  label_id!= 0 and (flag==None or flag==label_id):  #不等于零时
            if flag==None:
                flag=label_id
            label = label_list[label_id]  # 获取对应的标签
            word_list.append(input_ids[idx].item())
            if label not in extracted_info:
                extracted_info[label] = []
        else:
            if word_list:
                sentence=''.join(tokenizer.decode(word_list))
                extracted_info[label].append(sentence)
                flag=None
            word_list=[]
            if label_id!= 0:
                label = label_list[label_id]  # 获取对应的标签
                word_list.append(input_ids[idx].item())
                if label not in extracted_info:
                    extracted_info[label] = []   
    # 返回JSON格式的提取内容
    return extracted_info