Bronco92 commited on
Commit
e6b17ef
1 Parent(s): 1bce6d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -2
app.py CHANGED
@@ -1,3 +1,26 @@
1
- import gradio as gr
 
 
2
 
3
- gr.Interface.load("models/AkshatSurolia/ICD-10-Code-Prediction").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
 
5
+ # Load the model and tokenizer
6
+ tokenizer = AutoTokenizer.from_pretrained("AkshatSurolia/ICD-10-Code-Prediction")
7
+ model = AutoModelForSequenceClassification.from_pretrained("AkshatSurolia/ICD-10-Code-Prediction")
8
+
9
+ # Create a Streamlit input text box
10
+ input_text = st.text_input("Enter your text:")
11
+
12
+ # If input is provided
13
+ if input_text:
14
+ # Limit the input length
15
+ truncated_input = input_text[:512]
16
+
17
+ # Tokenize the input
18
+ tokens = tokenizer(truncated_input, truncation=True, padding=True, return_tensors="pt")
19
+
20
+ # Get model output
21
+ output = model(**tokens)
22
+
23
+ # The output of the model is a logits vector, so we take the argmax to get the predicted class index
24
+ predicted_class_idx = torch.argmax(output.logits, dim=-1).item()
25
+
26
+ st.write(f"Predicted class index: {predicted_class_idx}")