from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
import gradio as gr
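# Label inventory for the token classifier, in BIO format:
# "B-" marks the first token of an entity span, "I-" marks continuation
# tokens, and "O" (appended below) marks tokens outside any entity.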
CASE_LABELS = [
    "B-CASE_NAME",
    "I-CASE_NAME",
    "B-VOLUME",
    "I-VOLUME",
    "B-REPORTER",
    "I-REPORTER",
    "B-PAGE",
    "I-PAGE",
    "B-PIN",
    "I-PIN",
    "B-COURT",
    "I-COURT",
    "B-YEAR",
    "I-YEAR",
]
STAT_LABELS = [
    "B-TITLE",
    "I-TITLE",
    "B-CODE",
    "I-CODE",
    "B-SECTION",
    "I-SECTION",
]
SHORT_LABELS = ["B-ID", "I-ID", "B-SUPRA", "I-SUPRA"]
ALL_LABELS = CASE_LABELS + STAT_LABELS + SHORT_LABELS + ["O"]
# Load your custom model and tokenizer
model_name = "ss108/legal-citation-bert"
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Run token-level NER with the custom model and return token/label pairs
def ner(text):
    tokenized_input = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    model.eval()
    with torch.no_grad():
        outputs = model(**tokenized_input)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    # Map predicted label ids to names; assumes ALL_LABELS matches the
    # label order the model was trained with
    predicted_labels = [ALL_LABELS[p] for p in predictions[0].tolist()]
    print(predicted_labels)  # debug output
    tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"][0])
    # Pair each subword token with its predicted label
    entities = []
    for token, label in zip(tokens, predicted_labels):
        span = f"{token} - {label}"
        print(span)  # debug output
        entities.append(span)
    return entities
# Define examples
examples = [
"sentencing him to 24 months’ imprisonment on one count of possessing heroin with intent to distribute, 21 U.S.C. §§ 841(a); Fexler v. Hock, 123 U.S. 456, 499 (2021)",
"See, e.g., Morton, 456 F. 3d at 181 (affirming that ice cream is good)."
]
# Create Gradio interface
demo = gr.Interface(
    ner,
    gr.Textbox(placeholder="Enter sentence here..."),
    "json",  # render the output as JSON
    examples=examples,
)
# Launch the interface
demo.launch()