from datasets import load_dataset
from gliner import GLiNER
import gradio as gr

# Load the British Library books dataset as a streaming iterator (nothing is
# downloaded up front) and shuffle the stream with a fixed seed
dataset_iter = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,  # Enable streaming
    trust_remote_code=True,
).shuffle(seed=42)

# Load the model
model = GLiNER.from_pretrained("max-long/textile_machines_3_oct", trust_remote_code=True)
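
# For reference, a minimal sketch of the entity format predict_entities is
# expected to return (field names based on the GLiNER API; the example values
# below are purely illustrative):
#
#   model.predict_entities("The mule spun the yarn.", ["Machine"], threshold=0.5)
#   # -> [{'text': 'mule', 'label': 'Machine', 'start': 4, 'end': 8, 'score': 0.92}]
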
def ner(text: str, labels: str, threshold: float):
    # Convert the user-provided comma-separated labels into a list
    labels_list = [label.strip() for label in labels.split(",")]
    # Predict entities using the fine-tuned GLiNER model
    entities = model.predict_entities(text, labels_list, flat_ner=True, threshold=threshold)
    # Build the highlighted HTML, inserting spans from the end of the text
    # backwards so that earlier character offsets remain valid
    highlighted_text = text
    for ent in sorted(entities, key=lambda x: x['start'], reverse=True):
        highlighted_text = (
            highlighted_text[:ent['start']] +
            f"<span style='background-color: yellow; font-weight: bold;'>{highlighted_text[ent['start']:ent['end']]}</span>" +
            highlighted_text[ent['end']:]
        )
    return highlighted_text, entities
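
# Rough usage sketch (illustrative; the exact entities depend on the model and
# threshold):
#   html, ents = ner(
#       "The machine is fed by means of an endless apron...", "Machine, Wool", 0.5
#   )
#   # `html` wraps each predicted span in a yellow <span>; `ents` is the raw entity list
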
with gr.Blocks(title="General NER Demo") as demo:
    gr.Markdown(
        """
# General Entity Recognition Demo

This demo selects a random text snippet from a subset of the British Library's books dataset and identifies entities in it using a fine-tuned GLiNER model. You can specify which entity types you want to find.
"""
    )

    # Text input, pre-filled with an example snippet
    input_text = gr.Textbox(
        value="The machine is fed by means of an endless apron, the wool entering at the smaller end...",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )
    with gr.Row():
        labels = gr.Textbox(
            value="Machine, Wool",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,  # Default confidence threshold
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
    # Define output components
    output_highlighted = gr.HTML(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    submit_btn = gr.Button("Find Entities!")
    refresh_btn = gr.Button("Get New Snippet")
    # Keep a single iterator over the shuffled stream so each click advances to
    # the next record instead of restarting iteration from the beginning
    snippet_iterator = iter(dataset_iter)

    def get_new_snippet():
        sample = next(snippet_iterator, None)
        if sample is not None:
            return sample['text']
        return "No more snippets available."  # Returned if the stream is exhausted
    # Connect the refresh button
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    # Connect the submit button
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold],
        outputs=[output_highlighted, output_entities],
    )
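
    # Note: Gradio passes the current values of the listed input components to fn
    # positionally, so input_text, labels, and threshold map onto ner's
    # (text, labels, threshold) parameters.
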
demo.queue()
demo.launch(debug=True)