Bronco92's picture
Update app.py
e6b17ef
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Hugging Face Hub id of the sequence-classification checkpoint served by this app.
MODEL_NAME = "AkshatSurolia/ICD-10-Code-Prediction"
# Token budget per input. 512 is the usual positional limit for BERT-style
# encoders — NOTE(review): confirm against the model's config.
MAX_TOKENS = 512


@st.cache_resource
def load_model():
    """Load the tokenizer and classifier once and reuse them across reruns.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the weights would be re-loaded on each rerun.

    Returns:
        A ``(tokenizer, model)`` pair for ``MODEL_NAME``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    clf = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    clf.eval()  # inference only: make sure dropout is disabled
    return tok, clf


def predict(text, tok, clf):
    """Classify *text* and return ``(class_index, label)``.

    Truncation is done on tokens (not characters) so that as much of the
    input as the model can accept is actually used.
    """
    tokens = tok(text, truncation=True, max_length=MAX_TOKENS, return_tensors="pt")
    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        logits = clf(**tokens).logits
    # The model outputs one logit per class; the argmax is the predicted class.
    class_idx = int(torch.argmax(logits, dim=-1).item())
    # Fall back to the raw index if the config carries no label for it.
    label = clf.config.id2label.get(class_idx, str(class_idx))
    return class_idx, label


# Module-level handles, kept so the names `tokenizer` / `model` stay available.
tokenizer, model = load_model()

# Create a Streamlit input text box
input_text = st.text_input("Enter your text:")

# Only run the model when there is non-blank input.
if input_text.strip():
    predicted_class_idx, predicted_label = predict(input_text, tokenizer, model)
    st.write(f"Predicted class index: {predicted_class_idx}")
    st.write(f"Predicted label: {predicted_label}")