# Scraped from a Hugging Face Space page (Space status: Sleeping).
# File size: 5,923 bytes.
# Blame commits (per original line): b6ac0c6 ab64fd6 2f2c452 ab64fd6 2f2c452 a4c122a f1b5517 38c39a7 f1b5517 a4c122a f1b5517 2f2c452 ab64fd6 d9a6a4e 2f2c452 ab64fd6 2f2c452 6e9b058 2f2c452 6e9b058 2f2c452 6e9b058
#   2f2c452 6e9b058 2f2c452 6e9b058 2f2c452 ab64fd6 2f2c452 65003ad 6e9b058 a4c122a 2f2c452 ab64fd6 d9a6a4e 2f2c452 ab64fd6 a4c122a d9a6a4e 2f2c452 ab64fd6 2f2c452 d9a6a4e 6c2c9be d9a6a4e 2f2c452
import os

# Set the Hugging Face cache locations before transformers is imported so
# the paths take effect.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

import re
from html import escape
from typing import List, Tuple

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Single source of truth for the label scheme:
# (model label id, human-readable BIO tag, highlight colour).
# Keeping meaning and colour in one table stops the two lookup dicts below
# from drifting apart.
_LABEL_TABLE = (
    ('LABEL-0', 'NONE', '#cccccc'),
    ('LABEL-1', 'B-DATE', '#ffadad'),
    ('LABEL-2', 'I-DATE', '#ebdb98'),
    ('LABEL-3', 'B-TIME', '#586492'),
    ('LABEL-4', 'I-TIME', '#ffb788'),
    ('LABEL-5', 'B-DURATION', '#76abbb'),
    ('LABEL-6', 'I-DURATION', '#a0c4ff'),
    ('LABEL-7', 'B-SET', '#f84252'),
    # I-SET gets its own colour (distinct from I-DATE) so the two tags can be
    # told apart in the highlighted text and in the legend.
    ('LABEL-8', 'I-SET', '#caffbf'),
)
# Mapping of label to highlight colour (insertion order = label order).
LABEL_COLORS = {label: color for label, _, color in _LABEL_TABLE}
# Mapping of label to its human-readable BIO tag.
LABEL_MEANINGS = {label: meaning for label, meaning, _ in _LABEL_TABLE}
@st.cache_resource(show_spinner=True)
def load_model():
    """Load the Bio-RoBERTime tokenizer and token-classification model.

    The result is cached by Streamlit's ``cache_resource``, so the weights
    are fetched and built once rather than on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    repo_id = 'asdc/Bio-RoBERTime'
    tok = AutoTokenizer.from_pretrained(repo_id)
    clf = AutoModelForTokenClassification.from_pretrained(repo_id)
    return tok, clf
def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
    """Tag each word of *text* with a Bio-RoBERTime label.

    The text is tokenised and run through the cached model. The sub-word
    predictions are then merged back into words using the tokenizer's
    ``word_ids``. Each word takes the label predicted for its FIRST
    sub-token.

    Args:
        text: Raw input text. Tokenisation uses ``truncation=True``, so text
            beyond the model's maximum input length is not tagged.

    Returns:
        ``(word, label)`` pairs in input order. ``word`` is the decoded word
        with surrounding whitespace stripped; words that decode to only
        whitespace are omitted. ``label`` is a value from
        ``model.config.id2label``.
    """
    tokenizer, model = load_model()
    tokens = tokenizer(text, return_tensors="pt", truncation=True, is_split_into_words=False)
    with torch.no_grad():
        outputs = model(**tokens)
    # Logits are indexed [batch, token, label]: take the best label per token
    # for the single batch item.
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    labels = [model.config.id2label[pred] for pred in predictions]
    word_ids = tokens.word_ids(batch_index=0)
    input_ids = tokens["input_ids"][0].tolist()

    # One (token ids, label) entry per word, in order.
    groups: List[Tuple[List[int], str]] = []
    last_word_id = None
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            # Special tokens do not belong to any word.
            continue
        if word_id != last_word_id:
            # First sub-token of a new word: its prediction labels the word.
            # Token-classification fine-tuning commonly supervises only the
            # first sub-token, so later sub-token predictions are presumably
            # less reliable.
            groups.append(([], labels[idx]))
            last_word_id = word_id
        groups[-1][0].append(input_ids[idx])

    entities: List[Tuple[str, str]] = []
    for ids, label in groups:
        # Decoded words may keep the leading space from the source text;
        # strip it, since callers insert their own separators.
        word = tokenizer.decode(ids, skip_special_tokens=True).strip()
        if word:
            entities.append((word, label))
    return entities
def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
    """Render ``(token, label)`` pairs as inline HTML with entities highlighted.

    Labels are normalised by replacing ``_`` with ``-`` before lookup.
    Tokens labelled ``LABEL-0`` are emitted as plain (escaped) text. Every
    other token is wrapped in a ``ner-entity`` span with:

    - a background colour from ``LABEL_COLORS`` (``#eeeeee`` if unknown);
    - a ``data-tooltip`` taken from ``LABEL_MEANINGS`` (or the label itself).

    Each token is followed by a single space.

    Token text and tooltip values are HTML-escaped. The output is rendered
    with ``unsafe_allow_html=True``, and the tokens come from user input.
    """
    parts = []
    for token, label in ner_result:
        norm_label = label.replace('_', '-')
        safe_token = escape(token)
        if norm_label == 'LABEL-0':
            parts.append(f'{safe_token} ')
            continue
        color = LABEL_COLORS.get(norm_label, '#eeeeee')
        label_meaning = LABEL_MEANINGS.get(norm_label, norm_label)
        parts.append(
            f'<span class="ner-entity" style="background-color:{color};padding:2px 4px;border-radius:4px;margin:1px;" '
            f'data-tooltip="{escape(label_meaning)}">{safe_token}</span> '
        )
    # join() avoids quadratic string concatenation on long inputs.
    return ''.join(parts)
def extract_entities(ner_result: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Group consecutive tokens that share an entity label into entities.

    A token whose label normalises (``_`` -> ``-``) to ``LABEL-0`` is a
    non-entity. It is never reported, and it ends any entity in progress.
    Model labels spelled ``LABEL_0`` are therefore also treated as
    non-entities.

    Args:
        ner_result: ``(token, label)`` pairs in text order.

    Returns:
        ``(entity_text, label)`` pairs. ``entity_text`` is the run's tokens
        joined with single spaces. ``label`` is the run's label as it
        appeared in the input.
    """
    entities: List[Tuple[str, str]] = []
    current_entity: List[str] = []
    current_label = None
    for token, label in ner_result:
        if label.replace('_', '-') == 'LABEL-0':
            # A non-entity token closes any entity in progress.
            if current_entity:
                entities.append((' '.join(current_entity), current_label))
            current_entity = []
            current_label = None
        elif label == current_label:
            current_entity.append(token)
        else:
            # A different entity label starts a new entity.
            if current_entity:
                entities.append((' '.join(current_entity), current_label))
            current_entity = [token]
            current_label = label
    if current_entity:
        entities.append((' '.join(current_entity), current_label))
    return entities
def legend_html() -> str:
    """Build the colour legend as a wrapping flex row of label chips.

    The legend has one chip per entry of ``LABEL_COLORS`` except ``LABEL-0``
    (no entity), in dict order. Each chip reads ``<meaning> (<label>)``, with
    the meaning looked up in ``LABEL_MEANINGS``.
    """
    chips = [
        f'<span style="background-color:{color};padding:2px 8px;border-radius:4px;">{LABEL_MEANINGS[label]} ({label})</span>'
        for label, color in LABEL_COLORS.items()
        if label != 'LABEL-0'
    ]
    return '<div style="display:flex;flex-wrap:wrap;gap:8px;">' + ''.join(chips) + '</div>'
# ---- Streamlit page ----
st.title('LLM-powered Named Entity Recognition (NER)')
# CSS for the highlighted entity spans: on hover, show the span's
# data-tooltip attribute in a small box below it.
st.markdown(
    '''
<style>
.ner-entity {
position: relative;
cursor: pointer;
}
.ner-entity[data-tooltip]:hover:after {
content: attr(data-tooltip);
position: absolute;
left: 0;
top: 100%;
background: #222;
color: #fff;
padding: 2px 8px;
border-radius: 4px;
white-space: nowrap;
z-index: 10;
font-size: 0.9em;
margin-top: 2px;
}
</style>
''',
    unsafe_allow_html=True
)
st.markdown('**Legend:**')
st.markdown(legend_html(), unsafe_allow_html=True)
user_text = st.text_area('Enter text for NER:', height=150)
if user_text:
    ner_result = ner_with_robertime(user_text)
    # Model labels may be spelled 'LABEL_0'; normalise before comparing,
    # otherwise every result would count as an entity.
    has_entity = any(label.replace('_', '-') != 'LABEL-0' for _, label in ner_result)
    if has_entity:
        st.markdown('#### Entities Highlighted:')
        st.markdown(colorize_entities(ner_result), unsafe_allow_html=True)
        entities = extract_entities(ner_result)
        if entities:
            st.markdown('#### Detected Entities:')
            for ent, label in entities:
                norm_label = label.replace('_', '-')
                # Fall back like the highlighter does, rather than raising
                # KeyError on an unexpected label.
                color = LABEL_COLORS.get(norm_label, '#eeeeee')
                meaning = LABEL_MEANINGS.get(norm_label, norm_label)
                # Entity text comes from user input: escape it before it is
                # rendered as raw HTML.
                st.markdown(
                    f'- <span style="background-color:{color};padding:2px 8px;border-radius:4px;">{escape(ent)}</span> '
                    f'<span style="color:#888;">({escape(meaning)})</span>',
                    unsafe_allow_html=True,
                )
        else:
            st.info('No entities detected.')
    else:
        st.info('No entities detected.')
st.caption('Model: [asdc/Bio-RoBERTime](https://huggingface.co/asdc/Bio-RoBERTime)')