|
from transformers import pipeline |
|
|
|
# Two token-classification ("ner") pipelines, presumably compared side by side:
# one for the plain fine-tuned BioRED checkpoint, one for the augmented
# checkpoint. Both use the same aggregation strategy so their outputs can be
# rendered by the same code. Both are constructed at module import time.
baseline_classifier = pipeline(
    task="ner",
    model="Dagobert42/biored-finetuned",
    aggregation_strategy="simple",
)
augmented_classifier = pipeline(
    task="ner",
    model="Dagobert42/biored-augmented",
    aggregation_strategy="simple",
)
|
|
|
def annotate_sentence(sentence, predictions):
    """Split *sentence* into plain-text gaps and colored entity spans.

    Args:
        sentence: The original text the predictions were computed on.
        predictions: Iterable of entity dicts, each with ``start`` and ``end``
            character offsets into *sentence* and an ``entity_group`` label
            (the ``word`` key, if present, is not used). Assumed to be sorted
            by ``start``, non-overlapping, and to carry integer offsets --
            TODO confirm for the tokenizer in use.

    Returns:
        A list mixing plain ``str`` segments (text between entities) with
        ``(text, entity_group, color)`` tuples, in sentence order. Empty gaps
        are omitted. For sorted, non-overlapping predictions, joining the text
        of every item reproduces *sentence*.
    """
    colors = {
        'null': '#bfbfbf',
        'GeneOrGeneProduct': '#aad4aa',
        'DiseaseOrPhenotypicFeature': '#f8b400',
        'ChemicalEntity': '#a4c2f4',
        'OrganismTaxon': '#ffb6c1',
        'SequenceVariant': '#e2b0ff',
        'CellLine': '#ffcc99',
    }
    # Labels missing from the palette get the neutral color instead of
    # raising KeyError.
    fallback_color = colors['null']

    output = []
    cursor = 0  # index just past the last emitted entity
    for p in predictions:
        start, end = p['start'], p['end']
        gap = sentence[cursor:start]
        if gap:
            output.append(gap)
        group = p['entity_group']
        # Slice the original text rather than using p['word']: the pipeline
        # builds 'word' by decoding tokens, which can change casing/spacing.
        output.append((sentence[start:end], group, colors.get(group, fallback_color)))
        cursor = end

    # Use the cursor, not the loop variable, so an empty `predictions`
    # returns the whole sentence instead of raising UnboundLocalError.
    tail = sentence[cursor:]
    if tail:
        output.append(tail)
    return output