ss108 committed on
Commit
ba67dcd
1 Parent(s): fee9d2e

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +26 -0
inference.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
2
+ import torch
3
+
4
# Hugging Face Hub identifier of the checkpoint. It is the single source of
# truth for both loaders below, so the tokenizer and the model cannot drift
# apart when the checkpoint is changed.
MODEL_NAME = "ss108/legal-citation-bert"

# Both objects are created once, at import time, and reused by every call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
# Put the model in evaluation mode (disables training-only behavior such as
# dropout) since this module is used for inference only.
model.eval()
8
+
9
def predict(text):
    """Label each token of ``text`` with the model's predicted class.

    Args:
        text: Input handed to the module-level ``tokenizer``. If a batch
            (list of strings) is passed, only the first sequence is
            labelled; the remaining sequences are ignored.

    Returns:
        A list of ``{'token': ..., 'label': ...}`` dicts in tokenizer
        order. Tokens are whatever the tokenizer produces (these may be
        sub-word pieces); the tokenizer's special tokens are omitted.
    """
    # truncation=True clips over-long inputs to the tokenizer's configured
    # maximum length instead of passing them to the model unbounded.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Highest-scoring class id per token, for the first sequence only.
    label_ids = torch.argmax(logits, dim=-1)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    # Hoisted out of the comprehension so neither lookup is re-evaluated
    # for every token.
    id2label = model.config.id2label
    special_tokens = set(tokenizer.all_special_tokens)

    return [
        {'token': token, 'label': id2label[label_id]}
        for token, label_id in zip(tokens, label_ids)
        if token not in special_tokens
    ]