File size: 7,912 Bytes
2ed761d
59d5f39
0011926
2ed761d
 
0011926
 
 
2ed761d
0011926
 
 
 
 
2ed761d
0011926
2ed761d
0011926
 
 
 
2ed761d
 
 
 
0011926
 
 
 
120a364
2ed761d
 
 
0011926
 
 
120a364
0011926
2ed761d
59d5f39
2ed761d
59d5f39
0011926
 
 
 
 
2ed761d
120a364
 
2ed761d
 
120a364
0011926
17a8a7f
 
 
120a364
 
0011926
120a364
0011926
 
 
 
 
 
07b63ff
0011926
 
 
 
 
120a364
6eb3f68
0011926
 
 
 
 
 
 
 
 
 
 
 
 
2ed761d
0011926
 
 
120a364
0011926
 
 
 
 
 
 
 
17a8a7f
0011926
 
 
 
 
 
 
 
 
2ed761d
2e476d1
 
 
17a8a7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e476d1
2ed761d
2e476d1
17a8a7f
 
be47202
17a8a7f
 
be47202
17a8a7f
 
0011926
17a8a7f
2ed761d
 
85b55f8
 
 
cf081b6
0011926
98d8dd9
0011926
2ed761d
6eb3f68
2ed761d
0011926
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoModelForSequenceClassification
import re

# Маппинг классов NER
LABELS = ["TOPONYM", "MEDICINE", "SYMPTOM", "ALLERGEN", "BODY_PART"]
ID2LABEL = {
    0: "O",
    1: "B-TOPONYM", 2: "I-TOPONYM",
    3: "B-MEDICINE", 4: "I-MEDICINE",
    5: "B-SYMPTOM", 6: "I-SYMPTOM",
    7: "B-ALLERGEN", 8: "I-ALLERGEN",
    9: "B-BODY_PART", 10: "I-BODY_PART"
}
LABEL2ID = {v: k for k, v in ID2LABEL.items()}

# Маппинг классов RE
REL_LABELS = ["has_symptom", "has_medicine", "no_relation"]
REL_ID2LABEL = {i: l for i, l in enumerate(REL_LABELS)}
REL_LABEL2ID = {l: i for i, l in enumerate(REL_LABELS)}

# Загрузка моделей
def load_models():
    ner_model = AutoModelForTokenClassification.from_pretrained(
        "DanielNRU/pollen-ner",
        num_labels=len(ID2LABEL),
        id2label=ID2LABEL,
        label2id=LABEL2ID
    )
    ner_tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")
    re_model = AutoModelForSequenceClassification.from_pretrained(
        "DanielNRU/pollen-re",
        num_labels=len(REL_LABELS),
        id2label=REL_ID2LABEL,
        label2id=REL_LABEL2ID
    )
    re_tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")
    return ner_model, ner_tokenizer, re_model, re_tokenizer

ner_model, ner_tokenizer, re_model, re_tokenizer = load_models()

def split_sentences(text):
    # Примитивный сплиттер, можно заменить на nltk
    return [s.strip() for s in re.split(r'[.!?]', text) if s.strip()]

def predict_entities(text):
    inputs = ner_tokenizer(text, return_tensors="pt", return_offsets_mapping=True, truncation=True, max_length=512)
    offset_mapping = inputs.pop('offset_mapping')[0]
    with torch.no_grad():
        outputs = ner_model(**inputs)
        predictions = outputs.logits.argmax(dim=-1)[0]
    entities = []
    current = None
    for idx, (pred, (start, end)) in enumerate(zip(predictions, offset_mapping)):
        start = int(start)
        end = int(end)
        if start == 0 and end == 0:
            continue
        label = ID2LABEL[pred.item()]
        if label.startswith('B-'):
            if current:
                entities.append(current)
            current = {'text': text[start:end], 'label': label[2:], 'start': start, 'end': end}
        elif label.startswith('I-') and current and label[2:] == current['label']:
            current['text'] += text[start:end]
            current['end'] = end
        else:
            if current:
                entities.append(current)
                current = None
    if current:
        entities.append(current)
    return entities

def insert_entity_markers(text, ent1, ent2):
    # Вставляет маркеры вокруг двух сущностей (как в обучении RE)
    if ent1['start'] < ent2['start']:
        first, second = ent1, ent2
    else:
        first, second = ent2, ent1
    first_tag = f"[{first['label']}]"
    first_end_tag = f"[/{first['label']}]"
    second_tag = f"[{second['label']}]"
    second_end_tag = f"[/{second['label']}]"
    text_marked = text[:second['start']] + second_tag + text[second['start']:second['end']] + second_end_tag + text[second['end']:]
    text_marked = text_marked[:first['start']] + first_tag + text_marked[first['start']:first['end']] + first_end_tag + text_marked[first['end']:]
    return text_marked

def predict_relations(text, entities):
    # Для каждой пары BODY_PART–SYMPTOM и BODY_PART–MEDICINE внутри предложения
    sentences = split_sentences(text)
    relations = []
    for sent in sentences:
        sent_start = text.find(sent)
        sent_end = sent_start + len(sent)
        ents_in_sent = [e for e in entities if e['start'] >= sent_start and e['end'] <= sent_end]
        for ent1 in ents_in_sent:
            for ent2 in ents_in_sent:
                if ent1 == ent2:
                    continue
                if ent1['label'] == 'BODY_PART' and ent2['label'] in ['SYMPTOM', 'MEDICINE']:
                    marked = insert_entity_markers(text, ent1, ent2)
                    inputs = re_tokenizer(marked, return_tensors='pt', truncation=True, max_length=256)
                    with torch.no_grad():
                        logits = re_model(**inputs).logits
                        pred = logits.argmax(-1).item()
                        rel_label = REL_ID2LABEL[pred]
                    if rel_label != 'no_relation':
                        relations.append({'head': ent1, 'tail': ent2, 'relation': rel_label})
    return relations

def analyze_text(text):
    entities = predict_entities(text)
    relations = predict_relations(text, entities)
    # Для HighlightedText: список (substring, label)
    highlights = []
    used_spans = set()
    for ent in entities:
        # Не добавлять дублирующиеся спаны
        span = (ent['start'], ent['end'], ent['label'])
        if span not in used_spans:
            highlights.append((ent['text'], ent['label']))
            used_spans.add(span)
    # Для отдельных полей
    toponyms = [ent['text'] for ent in entities if ent['label'] == 'TOPONYM']
    medicines = [ent['text'] for ent in entities if ent['label'] == 'MEDICINE']
    # Для симптомов — только из отношений has_symptom
    symptoms = []
    for rel in relations:
        if rel['relation'] == 'has_symptom' and rel['tail']['label'] == 'SYMPTOM':
            symptoms.append(f"{rel['head']['text']} {rel['tail']['text']}")
    allergens = [ent['text'] for ent in entities if ent['label'] == 'ALLERGEN']
    return highlights, ', '.join(toponyms), ', '.join(medicines), ', '.join(symptoms), ', '.join(allergens)

interface = gr.Interface(
    fn=analyze_text,
    inputs=gr.Textbox(lines=5, label="Текст"),
    outputs=[
        gr.HighlightedText(label="Сущности"),
        gr.Textbox(label="Топонимы"),
        gr.Textbox(label="Медицинские препараты"),
        gr.Textbox(label="Симптомы"),
        gr.Textbox(label="Аллергены"),
    ],
    title="PollenNER — извлечение сущностей и отношений",
    description="""Извлечение топонимов, лекарств, аллергенов, частей тела и симптомов из сообщений пользователей Пыльца Club.""",
    examples=[
        ["В Московской области у меня началась аллергия на пыльцу березы, потекли глаза, нос, принимаю Зиртек и Назонекс."],
        ["У ребенка в Новокузнецке чешутся глаза, уши и течет нос, врач прописал Кромогексал, Назонекс в нос."],
        ["В Санкт-Петербурге началось цветение ольхи, сильная реакция, принимаю Эриус, но глаза все равно слезятся."],
        ["В Калининграде вам будет намного лучше чем в Питере , местами есть пылящие березы ,но это уже как капля в море ."],
        ["Сергиев посад - слава Богу дожди. Дочь глаза чешет, горло чешется как бы тоже говорит (а так с марта на монтелукасте и зиртеке. Назонекс убрала, т.к кровь из носа часто идет."]
    ],
    cache_examples=True,
    flagging_mode="never"
)

if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)