henry2024 commited on
Commit
f0c31b4
1 Parent(s): b070a6c

Upload 2 files

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -15,7 +15,10 @@ from timeit import default_timer as timer
15
  from typing import Tuple, Dict
16
  from sklearn.feature_extraction.text import TfidfVectorizer
17
  from spacy.lang.de.stop_words import STOP_WORDS
18
- vectorizer = TfidfVectorizer(stop_words=list(STOP_WORDS))
 
 
 
19
  '''
20
  import nltk
21
  from nltk.corpus import stopwords
@@ -94,8 +97,8 @@ train_data, test_data= train_test_split(df, test_size=0.15, random_state=42 )
94
  train_data['label'].value_counts().sort_index()
95
  test_data['label'].value_counts().sort_index()
96
  vectorizer.fit(train_data.text)
97
- vectorizer.get_feature_names_out()[: 100]
98
- vectorizer= vectorizer
99
  #########################################################################################################################
100
  if torch.cuda.is_available():
101
  device = "cuda"
@@ -169,9 +172,9 @@ with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;}
169
  #base_model = /home/henry/Desktop/ARIN_7102/download/phi-2 # gru_model # embedder = SentenceTransformer("/home/henry/Desktop/ARIN_7102/download/bge-small-en-v1.5", device="cuda")
170
  if base_model == "gru_model":
171
  # Model and transforms preparation
172
- model= RNN_model().to(device)
173
  # Load state dict
174
- model.load_state_dict(torch.load(f= 'pretrained_gru_model.pth', map_location= device))
175
  # Random greetings in list format
176
  greetings = ["hello!",'hello', 'hii !', 'hi', "hi there!", "hi there!", "heyy", 'good morning', 'good afternoon', 'good evening', "hey", "how are you", "how are you?", "how is it going", "how is it going?", "what's up?",
177
  "how are you?", "hey, how are you?", "what is popping", "good to see you!", "howdy!", "hi, nice to meet you.", "hiya!", "hi", "hi, what's new?", "hey, how's your day?", "hi, how have you been?", "greetings"]
 
15
  from typing import Tuple, Dict
16
  from sklearn.feature_extraction.text import TfidfVectorizer
17
  from spacy.lang.de.stop_words import STOP_WORDS
18
+ from model import ImprovedGRUModel
19
+ import nltk_utils
20
+ # vectorizer = TfidfVectorizer(stop_words=list(STOP_WORDS))
21
+ vectorizer= nltk_utils.vectorizer()
22
  '''
23
  import nltk
24
  from nltk.corpus import stopwords
 
97
  train_data['label'].value_counts().sort_index()
98
  test_data['label'].value_counts().sort_index()
99
  vectorizer.fit(train_data.text)
100
+ # vectorizer.get_feature_names_out()[: 100]
101
+ # vectorizer= vectorizer
102
  #########################################################################################################################
103
  if torch.cuda.is_available():
104
  device = "cuda"
 
172
  #base_model = /home/henry/Desktop/ARIN_7102/download/phi-2 # gru_model # embedder = SentenceTransformer("/home/henry/Desktop/ARIN_7102/download/bge-small-en-v1.5", device="cuda")
173
  if base_model == "gru_model":
174
  # Model and transforms preparation
175
+ model= ImprovedGRUModel().to(device)
176
  # Load state dict
177
+ model.load_state_dict(torch.load(f= 'gru_model.pth', map_location= device))
178
  # Random greetings in list format
179
  greetings = ["hello!",'hello', 'hii !', 'hi', "hi there!", "hi there!", "heyy", 'good morning', 'good afternoon', 'good evening', "hey", "how are you", "how are you?", "how is it going", "how is it going?", "what's up?",
180
  "how are you?", "hey, how are you?", "what is popping", "good to see you!", "howdy!", "hi, nice to meet you.", "hiya!", "hi", "hi, what's new?", "hey, how's your day?", "hi, how have you been?", "greetings"]