NiranjanShetty committed on
Commit
bf3fea3
1 Parent(s): bed6748

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -53
app.py CHANGED
@@ -1,62 +1,102 @@
1
- # app.py
2
- from fastapi import FastAPI, File, UploadFile
3
- from pydantic import BaseModel
4
- import numpy as np
5
  import pandas as pd
6
- from PIL import Image
7
- import io
8
  import torch
9
- from torchvision import transforms
10
- from sklearn.feature_extraction.text import TfidfVectorizer
 
11
  from sklearn.metrics.pairwise import cosine_similarity
12
- import uvicorn
13
- from models import CRNN # Replace with your actual CRNN model import
14
-
15
- app = FastAPI()
16
-
17
- # Load drug names and create a TF-IDF model
18
- drug_names = pd.read_csv('drug_names.csv')['drug_name'].tolist()
19
- vectorizer = TfidfVectorizer()
20
- drug_name_vectors = vectorizer.fit_transform(drug_names)
21
-
22
- # Load pre-trained CRNN model
23
- model = CRNN() # Replace with your CRNN model initialization
24
- model.load_state_dict(torch.load('crnn_model.pth'))
25
- model.eval()
26
-
27
- # Define image transformation
28
- transform = transforms.Compose([
29
- transforms.Grayscale(),
30
- transforms.Resize((128, 32)),
31
- transforms.ToTensor(),
32
- transforms.Normalize((0.5,), (0.5,))
33
- ])
34
-
35
- class Prediction(BaseModel):
36
- predicted_drug_name: str
37
-
38
- @app.post("/predict", response_model=Prediction)
39
- async def predict(file: UploadFile = File(...)):
40
- image = Image.open(io.BytesIO(await file.read()))
41
- image = transform(image).unsqueeze(0)
42
-
43
  with torch.no_grad():
44
- output = model(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Convert output to text (assuming you have a decoding function)
47
- recognized_text = decode_output(output) # Replace with actual decoding logic
 
 
 
48
 
49
- # Predict drug name
50
- input_vector = vectorizer.transform([recognized_text])
51
- similarities = cosine_similarity(input_vector, drug_name_vectors)
52
- best_match_idx = np.argmax(similarities)
53
- predicted_drug_name = drug_names[best_match_idx]
54
 
55
- return Prediction(predicted_drug_name=predicted_drug_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- def decode_output(output):
58
- # Add your decoding logic here
59
- return "Sample Drug Name"
 
 
 
 
 
60
 
61
- if __name__ == "__main__":
62
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
 
 
2
  import torch
3
+ from transformers import BertTokenizer, BertModel
4
+ from symspellpy.symspellpy import SymSpell, Verbosity
5
+ import streamlit as st
6
  from sklearn.metrics.pairwise import cosine_similarity
7
+ import numpy as np
8
+
9
# Load the dataset of known drug names (expects a 'drug_names' column).
# Only the file read is wrapped in the try block so that the error message
# really refers to a read/parse failure and st.stop() is never called from
# inside the exception handler's protected region.
try:
    df = pd.read_csv('drug_names.csv')
except Exception as e:
    st.error(f"Error reading CSV file: {e}")
    st.stop()

st.write("CSV Columns:", df.columns.tolist())  # Debugging line to print column names
if 'drug_names' not in df.columns:
    st.error("Column 'drug_names' not found in the CSV file. Please check the column names.")
    st.stop()

# Preprocess the drug names: skip empty cells (pandas reads them as NaN, which
# has no .lower()) and coerce any non-string value to text before lowercasing.
drug_names = [str(name).lower() for name in df['drug_names'].dropna()]
24
+
25
# Load BERT model and tokenizer.
# Streamlit reruns the script on user interaction, so the heavy objects are
# built through a cached loader when the installed Streamlit version offers
# `st.cache_resource`; on older versions the loader simply runs uncached,
# which matches the previous behaviour.
def _load_bert():
    """Return the (tokenizer, model) pair for 'bert-base-uncased'."""
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    return bert_tokenizer, bert_model


_cache_resource = getattr(st, 'cache_resource', None)
if _cache_resource is not None:
    _load_bert = _cache_resource(_load_bert)

tokenizer, model = _load_bert()
28
+
29
# Function to get embeddings
def get_embeddings(text):
    """Return mean-pooled BERT embeddings for *text*.

    Args:
        text: A single string, or a list of strings to embed as one batch.

    Returns:
        A tensor with one row per input text, obtained by averaging the
        last hidden state over the real (non-padding) tokens of each text.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state
    # Weight each token by its attention mask so padding added for batching
    # does not dilute the average. For a single string the mask is all ones,
    # which reduces to a plain mean over the sequence dimension.
    mask = inputs['attention_mask'].unsqueeze(-1).type_as(hidden)
    token_counts = mask.sum(dim=1).clamp(min=1)
    return (hidden * mask).sum(dim=1) / token_counts
35
+
36
# Get embeddings for all drug names.
# torch.vstack cannot stack an empty list, so report an empty drug list
# explicitly instead of failing with an opaque tensor error.
if not drug_names:
    st.error("No drug names available to build embeddings.")
    st.stop()
drug_embeddings = torch.vstack([get_embeddings(name) for name in drug_names])
38
+
39
# Spell correction setup.
# The dictionary holds only real drug names: any other term (such as a column
# label) would become a valid "correction" target for misspelled input.
sym_spell = SymSpell(max_dictionary_edit_distance=2)
for name in drug_names:
    sym_spell.create_dictionary_entry(name, 1)
44
+
45
# Prediction function
def predict_drug_name(input_text):
    """Return the known drug name that best matches *input_text*.

    The text is lowercased and stripped, replaced by the closest dictionary
    term within edit distance 2 when one exists, and then compared with the
    precomputed drug-name embeddings by cosine similarity.

    Args:
        input_text: Partial or misspelled drug name typed by the user.

    Returns:
        The entry of ``drug_names`` with the highest cosine similarity.
    """
    query = input_text.strip().lower()

    # Correct spelling first so the model is run only once, on the text that
    # is actually used for matching.
    suggestions = sym_spell.lookup(query, Verbosity.CLOSEST, max_edit_distance=2)
    if suggestions:
        query = suggestions[0].term

    input_embedding = get_embeddings(query)

    # Calculate similarity against every known drug name and pick the best.
    similarities = cosine_similarity(input_embedding, drug_embeddings)
    best_match_index = int(np.argmax(similarities))
    return drug_names[best_match_index]
 
60
 
61
# Batch testing function
def test_model(test_file):
    """Compute prediction accuracy over a labelled CSV file.

    Args:
        test_file: Path or file-like object of a CSV with the columns
            'input_text' and 'correct_drug_name'.

    Returns:
        Accuracy as a percentage (0-100), or None when the file lacks the
        required columns or contains no usable rows.
    """
    # A file-like upload may already have been read (e.g. for a preview);
    # rewind it so pandas does not see an empty stream.
    if hasattr(test_file, 'seek'):
        test_file.seek(0)

    test_df = pd.read_csv(test_file)
    st.write("Test CSV Columns:", test_df.columns.tolist())  # Debugging line to print column names
    if 'input_text' not in test_df.columns or 'correct_drug_name' not in test_df.columns:
        st.error("Test file must contain 'input_text' and 'correct_drug_name' columns.")
        return None

    # Rows with a missing input or label cannot be scored; dropping them also
    # avoids calling string methods on NaN.
    test_df = test_df.dropna(subset=['input_text', 'correct_drug_name'])
    if test_df.empty:
        st.error("Test file contains no rows to evaluate.")
        return None

    correct_predictions = sum(
        # Compare case-insensitively; predictions are already lowercase.
        predict_drug_name(str(text)) == str(expected).lower()
        for text, expected in zip(test_df['input_text'], test_df['correct_drug_name'])
    )
    return (correct_predictions / len(test_df)) * 100
77
+
78
# Streamlit app
st.title("Doctor's Handwritten Prescription Prediction")

# Single input prediction
input_text = st.text_input("Enter the partial or misspelled drug name:")
if st.button("Predict"):
    if input_text:
        predicted_drug_name = predict_drug_name(input_text)
        st.write(f"Predicted Drug Name: {predicted_drug_name}")
    else:
        st.write("Please enter a drug name to predict.")

# Batch testing
st.header("Batch Testing")
uploaded_file = st.file_uploader("Choose a CSV file for batch testing", type="csv")
if uploaded_file is not None:
    st.write("Uploaded file preview:")
    test_df = pd.read_csv(uploaded_file)
    st.write(test_df.head())
    st.write("Test CSV Columns:", test_df.columns.tolist())  # Debugging line to print column names

    if st.button("Start Batch Testing"):
        # The preview read above left the upload's stream at its end; rewind
        # it so the batch test can parse the file from the beginning.
        uploaded_file.seek(0)
        accuracy = test_model(uploaded_file)
        if accuracy is not None:
            st.write(f"Accuracy: {accuracy:.2f}%")