# Hugging Face Spaces demo: legal-citation NER (token classification).
import gradio as gr
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
# IOB2 tag inventory for the token-classification head.
# Each entity field contributes a B- (begin) and I- (inside) tag, in that order.

# Full case citations: party names, volume, reporter, page, pincite, court, year.
CASE_LABELS = [
    f"{prefix}-{field}"
    for field in ("CASE_NAME", "VOLUME", "REPORTER", "PAGE", "PIN", "COURT", "YEAR")
    for prefix in ("B", "I")
]

# Statutory citations: title, code name, section number.
STAT_LABELS = [
    f"{prefix}-{field}"
    for field in ("TITLE", "CODE", "SECTION")
    for prefix in ("B", "I")
]

# Short-form citations (id., supra).
SHORT_LABELS = [
    f"{prefix}-{field}"
    for field in ("ID", "SUPRA")
    for prefix in ("B", "I")
]

# Complete label space; "O" (outside any entity) is the final index.
ALL_LABELS = [*CASE_LABELS, *STAT_LABELS, *SHORT_LABELS, "O"]
# Fetch the fine-tuned citation tagger and its matching tokenizer from the
# Hugging Face Hub (downloads on first run, then served from the local cache).
model_name = "ss108/legal-citation-bert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
def ner(text):
    """Tag legal-citation entities in *text* at the wordpiece level.

    Args:
        text: Raw input string to tag.

    Returns:
        list[str]: One ``"token - label"`` string per tokenizer token,
        including the tokenizer's special tokens (e.g. ``[CLS]``/``[SEP]``)
        and subword pieces.
    """
    # truncation=True prevents a runtime error on inputs longer than the
    # model's maximum sequence length; padding is a no-op for one sequence
    # but kept for parity with the original behavior.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    model.eval()
    with torch.no_grad():
        logits = model(**encoded).logits

    # One predicted label id per token position (batch of one -> index [0]).
    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])

    # NOTE(review): ids index ALL_LABELS positionally; this assumes the
    # model's id2label ordering matches ALL_LABELS — confirm against the
    # checkpoint's config.
    return [f"{token} - {ALL_LABELS[i]}" for token, i in zip(tokens, predicted_ids)]
# Example citations offered beneath the input box.
examples = [
    "sentencing him to 24 months’ imprisonment on one count of possessing heroin with intent to distribute, 21 U.S.C. §§ 841(a); Fexler v. Hock, 123 U.S. 456, 499 (2021)",
    "See, e.g., Morton, 456 F. 3d at 181 (affirming that ice cream is good).",
]

# Wire the tagger into a simple Gradio UI: free text in, JSON token list out.
demo = gr.Interface(
    fn=ner,
    inputs=gr.Textbox(placeholder="Enter sentence here..."),
    outputs="json",
    examples=examples,
)

# Start the web server.
demo.launch()