NiranjanShetty committed on
Commit
b6d08b7
1 Parent(s): 2ce634f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -50
app.py CHANGED
@@ -1,79 +1,66 @@
1
  import pandas as pd
2
  import torch
3
- from transformers import DistilBertTokenizer, DistilBertModel
4
  from symspellpy.symspellpy import SymSpell, Verbosity
5
  import streamlit as st
 
6
  import numpy as np
7
- import pickle
8
- import faiss
9
 
10
  # Load the dataset
11
- def load_drug_names(file_path):
12
- try:
13
- df = pd.read_csv(file_path)
14
- st.write("CSV Columns:", df.columns.tolist()) # Debugging line to print column names
15
- if 'drug_names' in df.columns:
16
- drug_names = df['drug_names'].dropna().tolist() # Drop NaN values
17
- return [name.lower() for name in drug_names]
18
- else:
19
- st.error("Column 'drug_names' not found in the CSV file. Please check the column names.")
20
- st.stop()
21
- except Exception as e:
22
- st.error(f"Error reading CSV file: {e}")
23
- st.stop()
24
 
25
- drug_names = load_drug_names('drug_names.csv')
 
26
 
27
- # Load DistilBERT model and tokenizer
28
- tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
29
- model = DistilBertModel.from_pretrained('distilbert-base-uncased')
30
 
31
  # Function to get embeddings
32
- def get_embeddings(texts):
33
- inputs = tokenizer(texts, return_tensors='pt', truncation=True, padding=True)
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
- return outputs.last_hidden_state.mean(dim=1).numpy()
37
 
38
- # Check if precomputed embeddings exist
39
- try:
40
- with open('drug_embeddings.pkl', 'rb') as f:
41
- drug_embeddings = pickle.load(f)
42
- except FileNotFoundError:
43
- # Get embeddings for all drug names
44
- embeddings = get_embeddings(drug_names)
45
- drug_embeddings = np.vstack(embeddings)
46
-
47
- # Save embeddings for future use
48
- with open('drug_embeddings.pkl', 'wb') as f:
49
- pickle.dump(drug_embeddings, f)
50
-
51
- # Build FAISS index
52
- dimension = drug_embeddings.shape[1]
53
- index = faiss.IndexFlatL2(dimension)
54
- index.add(drug_embeddings)
55
 
56
  # Spell correction setup
57
  sym_spell = SymSpell(max_dictionary_edit_distance=2)
 
58
  for name in drug_names:
59
  sym_spell.create_dictionary_entry(name, 1)
60
 
61
  # Prediction function
62
  def predict_drug_name(input_text):
63
- input_text = input_text.lower().strip()
 
 
 
64
  suggestions = sym_spell.lookup(input_text, Verbosity.CLOSEST, max_edit_distance=2)
65
  if suggestions:
66
- corrected_text = suggestions[0].term
67
- else:
68
- corrected_text = input_text
69
 
70
- input_embedding = get_embeddings([corrected_text])
71
- distances, indices = index.search(input_embedding, 1)
72
- predicted_drug_name = drug_names[indices[0][0]]
 
 
 
 
 
 
73
 
74
- st.write(f"Input Text: {input_text}")
75
- st.write(f"Corrected Text: {corrected_text}")
76
- return predicted_drug_name
 
 
 
 
77
 
78
  # Streamlit app
79
  st.title("Doctor's Handwritten Prescription Prediction")
@@ -86,3 +73,16 @@ if st.button("Predict"):
86
  st.write(f"Predicted Drug Name: {predicted_drug_name}")
87
  else:
88
  st.write("Please enter a drug name to predict.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import torch
3
+ from transformers import BertTokenizer, BertModel
4
  from symspellpy.symspellpy import SymSpell, Verbosity
5
  import streamlit as st
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
  import numpy as np
 
 
8
 
9
  # Load the dataset
10
# Load the dataset of known drug names (expects a 'drug_names' column).
df = pd.read_csv('drug_names.csv')

# Preprocess: drop missing entries first — pandas reads empty cells as NaN
# (float), and float has no .lower(), so the bare comprehension would crash.
drug_names = [name.lower() for name in df['drug_names'].dropna()]
15
 
16
+ # Load BERT model and tokenizer
17
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
18
+ model = BertModel.from_pretrained('bert-base-uncased')
19
 
20
  # Function to get embeddings
21
def get_embeddings(text):
    """Return the mean-pooled BERT embedding for *text*, shape (batch, hidden)."""
    encoded = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # Inference only — no gradients needed.
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    # Average over the token dimension to get one vector per input.
    return hidden.mean(dim=1)
26
 
27
+ # Get embeddings for all drug names
28
+ drug_embeddings = torch.vstack([get_embeddings(name) for name in drug_names])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Spell correction setup
31
# Spell-correction dictionary built from the known drug names (edit distance <= 2).
# NOTE: the former placeholder entry "drug_name" was removed — it polluted the
# dictionary, letting lookup() return the non-drug token "drug_name" as a "correction".
sym_spell = SymSpell(max_dictionary_edit_distance=2)
for name in drug_names:
    sym_spell.create_dictionary_entry(name, 1)
35
 
36
  # Prediction function
37
def predict_drug_name(input_text):
    """Return the known drug name closest to *input_text*.

    Lowercases the input, applies SymSpell correction (edit distance <= 2),
    then picks the drug whose precomputed BERT embedding has the highest
    cosine similarity to the (corrected) input's embedding.
    """
    input_text = input_text.lower()

    # Correct spelling BEFORE embedding: the original computed the embedding
    # first and recomputed it after correction — a wasted BERT forward pass.
    suggestions = sym_spell.lookup(input_text, Verbosity.CLOSEST, max_edit_distance=2)
    if suggestions:
        input_text = suggestions[0].term

    # Single embedding computation for the final (possibly corrected) text.
    input_embedding = get_embeddings(input_text)

    # Nearest neighbour by cosine similarity against all drug-name embeddings.
    similarities = cosine_similarity(input_embedding, drug_embeddings)
    best_match_index = int(np.argmax(similarities))
    return drug_names[best_match_index]
51
+
52
+ # Batch testing function
53
def test_model(test_file):
    """Return prediction accuracy (%) over a test CSV.

    The CSV must have columns 'input_text' and 'correct_drug_name'.
    Returns 0.0 for an empty file instead of raising ZeroDivisionError.
    """
    test_df = pd.read_csv(test_file)
    if test_df.empty:  # guard: len(test_df) == 0 would divide by zero below
        return 0.0

    correct_predictions = 0
    for _, row in test_df.iterrows():
        predicted_drug_name = predict_drug_name(row['input_text'])
        # predict_drug_name returns lowercase names; compare case-insensitively.
        if predicted_drug_name == row['correct_drug_name'].lower():
            correct_predictions += 1

    return (correct_predictions / len(test_df)) * 100
64
 
65
  # Streamlit app
66
  st.title("Doctor's Handwritten Prescription Prediction")
 
73
  st.write(f"Predicted Drug Name: {predicted_drug_name}")
74
  else:
75
  st.write("Please enter a drug name to predict.")
76
+
77
# Batch testing UI: upload a CSV, preview it, then score the model on it.
st.header("Batch Testing")
uploaded_file = st.file_uploader("Choose a CSV file for batch testing", type="csv")
if uploaded_file is not None:
    st.write("Uploaded file preview:")
    test_df = pd.read_csv(uploaded_file)
    st.write(test_df.head())

    if st.button("Start Batch Testing"):
        # The preview read above consumed the upload buffer; without a rewind,
        # test_model() would read from EOF and fail / see an empty file.
        uploaded_file.seek(0)
        accuracy = test_model(uploaded_file)
        st.write(f"Accuracy: {accuracy:.2f}%")
88
+