File size: 1,723 Bytes
d26f6fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import DistilBertForTokenClassification, DistilBertTokenizer
import torch

# Directory holding the fine-tuned DistilBERT NER checkpoint
# (config, weights, and tokenizer files).
DRIVE_BASE_PATH = "model/"
model_path = DRIVE_BASE_PATH  # was f"{DRIVE_BASE_PATH}" — the f-string added nothing

# Load the token-classification model and tokenizer once at import time so
# predict_ner() can reuse them across calls. from_pretrained() returns the
# model already in eval mode, which is what inference needs.
model = DistilBertForTokenClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained(model_path)

def predict_ner(input_text):
    """Run token-level NER over *input_text* with the module-level model.

    Args:
        input_text: Raw string to tag.

    Returns:
        dict with keys:
            "text": the original input string, unchanged.
            "formatted_results": list of "Token: <tok>, Label: <label>"
                strings, one per non-special wordpiece token.
    """
    # Tokenize; inputs longer than the 128-token window are truncated.
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Highest-scoring label id at each token position of the single sequence.
    label_ids = outputs.logits.argmax(dim=2)[0].tolist()
    predicted_labels = [model.config.id2label[label_id] for label_id in label_ids]

    # BUG FIX: the original recovered tokens via
    # tokenizer.tokenize(tokenizer.decode(ids)), a decode/re-tokenize round
    # trip that can produce a different token sequence and silently misalign
    # tokens with their predicted labels. convert_ids_to_tokens maps each id
    # 1:1, preserving alignment with the logits.
    tokenized_text = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Pair tokens with labels, dropping special tokens (now also [PAD],
    # which padding=True can introduce).
    special_tokens = {"[CLS]", "[SEP]", "[PAD]"}
    token_label_pairs = [
        (token, label)
        for token, label in zip(tokenized_text, predicted_labels)
        if token not in special_tokens
    ]

    # Format the results vertically, one "Token: ..., Label: ..." line each.
    formatted_results = [f"Token: {token}, Label: {label}" for token, label in token_label_pairs]

    return {"text": input_text, "formatted_results": formatted_results}


if __name__ == "__main__":
    # Demo run, guarded so importing this module for predict_ner() does not
    # trigger model inference as a side effect.
    input_text = """Also , due to worsening renal function , she was started on octreotide / midodrine / albumin for hepatorenal
syndrome ( Cr 3.3 at its worst ) which resolved prior to her discharge ."""
    result = predict_ner(input_text)
    print(result['text'])
    for item in result['formatted_results']:
        print(item)