Spaces:
Runtime error
Runtime error
File size: 1,723 Bytes
d26f6fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from transformers import DistilBertForTokenClassification, DistilBertTokenizer
import torch
# Directory that holds the fine-tuned DistilBERT checkpoint and tokenizer files.
DRIVE_BASE_PATH = "model/"
model_path = f"{DRIVE_BASE_PATH}"

# Load tokenizer and model once at import time; both are reused by predict_ner.
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForTokenClassification.from_pretrained(model_path)
def predict_ner(input_text):
    """Run token-level NER over *input_text* with the module-level model.

    Parameters
    ----------
    input_text : str
        Raw text to tag. It is tokenized with padding/truncation to at
        most 128 wordpiece positions.

    Returns
    -------
    dict
        ``{"text": input_text, "formatted_results": [...]}`` where each
        entry is ``"Token: <tok>, Label: <label>"`` for every non-special
        wordpiece token.
    """
    # Tokenize the input text (single sequence, PyTorch tensors).
    inputs = tokenizer(
        input_text, return_tensors="pt", padding=True, truncation=True, max_length=128
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Per-position predicted label ids -> human-readable label strings.
    labels = outputs.logits.argmax(dim=2)
    predicted_labels = [model.config.id2label[label_id] for label_id in labels[0].tolist()]

    # BUGFIX: map input ids directly to tokens. The previous
    # tokenize(decode(...)) round trip could produce a token sequence that
    # differs from the model input, misaligning tokens with their labels;
    # convert_ids_to_tokens is 1:1 with the label positions by construction.
    tokenized_text = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Pair tokens with labels, dropping special tokens (incl. padding).
    special_tokens = {"[CLS]", "[SEP]", "[PAD]"}
    token_label_pairs = [
        (token, label)
        for token, label in zip(tokenized_text, predicted_labels)
        if token not in special_tokens
    ]

    # Format the results vertically.
    formatted_results = [f"Token: {token}, Label: {label}" for token, label in token_label_pairs]
    return {"text": input_text, "formatted_results": formatted_results}
# Sample clinical sentence used to exercise the NER pipeline.
input_text = """Also , due to worsening renal function , she was started on octreotide / midodrine / albumin for hepatorenal
syndrome ( Cr 3.3 at its worst ) which resolved prior to her discharge ."""

# Run prediction and print the input followed by one token/label line each.
result = predict_ner(input_text)
print(result['text'])
print("\n".join(result['formatted_results']))
|