|
import html
import random

import gradio as gr
from datasets import load_dataset
from gliner import GLiNER
|
|
|
|
|
# Source of random text snippets: the "train" split of the British Library
# books dataset, opened in streaming mode and shuffled with a fixed seed.
# NOTE(review): trust_remote_code=True is passed to both loaders below —
# confirm that the dataset and model repositories are trusted.
_books_stream = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
)
dataset_iter = _books_stream.shuffle(seed=42)

# Fine-tuned GLiNER model used by `ner` to predict entity spans.
model = GLiNER.from_pretrained(
    "max-long/textile_machines_3_oct",
    trust_remote_code=True,
)
|
|
|
def _highlight_entities(text, entities):
    """Return *text* as HTML-safe markup with each entity span highlighted."""
    pieces = []
    cursor = 0
    for ent in sorted(entities, key=lambda e: e["start"]):
        start, end = ent["start"], ent["end"]
        if start < cursor:
            # Overlaps a span that is already wrapped; nesting a second tag
            # inside it would corrupt the markup, so keep the earlier span.
            continue
        # Escape every slice of the raw text individually so that characters
        # such as "<" or "&" in the input cannot be interpreted as HTML.
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(
            "<span style='background-color: yellow; font-weight: bold;'>"
            f"{html.escape(text[start:end], quote=False)}</span>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)


def ner(text: str, labels: str, threshold: float):
    """Find entities in *text* and return highlighted HTML plus the raw spans.

    Args:
        text: The passage to analyse.
        labels: Comma-separated entity labels; surrounding whitespace is
            stripped and empty fragments (e.g. from a trailing comma) are
            ignored.
        threshold: Value forwarded to the model as its ``threshold`` argument.

    Returns:
        A ``(highlighted_html, entities)`` tuple. ``highlighted_html`` is the
        HTML-escaped input with each predicted span wrapped in a yellow, bold
        ``<span>``. ``entities`` is the list returned by
        ``model.predict_entities`` (each item is read here via its ``start``
        and ``end`` keys), or an empty list when the text is blank or no
        usable label was supplied.
    """
    labels_list = [label.strip() for label in labels.split(",") if label.strip()]

    # Nothing to predict: skip the model call and just return safe markup.
    if not text.strip() or not labels_list:
        return html.escape(text, quote=False), []

    entities = model.predict_entities(
        text, labels_list, flat_ner=True, threshold=threshold
    )
    return _highlight_entities(text, entities), entities
|
|
|
# Build the demo UI. Components and event handlers are all created inside the
# Blocks context below.
with gr.Blocks(title="General NER Demo") as demo:
    gr.Markdown(
        """
        # General Entity Recognition Demo
        This demo selects a random text snippet from a subset of the British Library's books dataset and identifies entities using a fine-tuned GLiNER model. You can specify the entities you want to find.
        """
    )

    # Text to analyse; pre-filled with a sample sentence.
    input_text = gr.Textbox(
        value="The machine is fed by means of an endless apron, the wool entering at the smaller end...",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )

    with gr.Row() as row:
        labels = gr.Textbox(
            value="Machine, Wool",
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )

    # Outputs: highlighted markup plus the raw entity list.
    output_highlighted = gr.HTML(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    submit_btn = gr.Button("Find Entities!")
    refresh_btn = gr.Button("Get New Snippet")

    # One long-lived iterator shared by every click. Running a fresh `for`
    # loop over `dataset_iter` on each call would ask it for a new iterator
    # every time and so start again from the first record, handing back the
    # same snippet on every click. `iter()` is a no-op if `dataset_iter` is
    # already an iterator.
    snippet_stream = iter(dataset_iter)

    def get_new_snippet():
        """Return the next non-blank ``text`` field from the snippet stream.

        At most ``max_attempts`` records are consumed per call while skipping
        records whose text is empty or whitespace-only. If the stream is
        exhausted, or no usable record turns up within that budget, the
        fallback message is returned instead.
        """
        max_attempts = 1000
        for _ in range(max_attempts):
            sample = next(snippet_stream, None)
            if sample is None:
                break  # stream exhausted
            text = sample['text']
            if text and text.strip():
                return text
        return "No more snippets available."

    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold],
        outputs=[output_highlighted, output_entities],
    )
|
|
|
# Enable the request queue, then start the app with debug output enabled.
demo.queue().launch(debug=True)