sumesh4C committed on
Commit ca7ae72 · verified · 1 Parent(s): c36c35f

Update tasks/text.py

Files changed (1)
  1. tasks/text.py +39 -1
tasks/text.py CHANGED
@@ -15,6 +15,40 @@ import pickle
 import torch
 import os
 
+import nltk
+from nltk.corpus import stopwords
+import spacy
+
+nltk.download('stopwords')
+# Get the list of English stop words from NLTK
+nltk_stop_words = stopwords.words('english')
+# Load the spaCy model for English
+nlp = spacy.load("en_core_web_sm")
+
+
+def process_text(text):
+    """
+    Process text by:
+    1. Lowercasing
+    2. Removing punctuation and non-alphanumeric characters
+    3. Removing stop words
+    4. Lemmatization
+    """
+    # Step 1: Tokenization & Processing with spaCy
+    doc = nlp(text.lower())  # Process text with spaCy
+
+    # Step 2: Filter out stop words, non-alphanumeric characters, punctuation, and apply lemmatization
+    processed_tokens = [
+        re.sub(r'[^a-zA-Z0-9]', '', token.lemma_)  # Remove non-alphanumeric characters
+        for token in doc
+        if token.text not in nltk_stop_words and token.text not in string.punctuation
+    ]
+
+    # Optional: Filter out empty strings resulting from the regex replacement
+    processed_tokens = " ".join([word for word in processed_tokens if word])
+
+    return processed_tokens
+
 router = APIRouter()
 
 DESCRIPTION = "TF-IDF + RF"
@@ -71,8 +105,12 @@ async def evaluate_text(request: TextEvaluationRequest):
     current_file_path = os.path.abspath(__file__)
     current_dir = os.path.dirname(current_file_path)
 
+    with open(os.path.join(current_dir, "tf-idf_vectorizer.pkl"), "rb") as tfidf_file:
+        tfidf_vectorizer = pickle.load(tfidf_file)
+
+
     # Make predictions using the loaded model
-    predictions = predict(test_dataset, os.path.join(current_dir, "tf-idf_vectorizer.pkl"), os.path.join(current_dir, "random_forest_model.pkl"))
+    predictions = predict(test_dataset, tfidf_vectorizer, os.path.join(current_dir, "random_forest_model.pkl"))
     predictions = [LABEL_MAPPING[label] for label in predictions]
 
     #--------------------------------------------------------------------------------------------
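
The process_text helper added in the first hunk can be sanity-checked on its own. Below is a minimal, self-contained sketch (not part of the commit) that reproduces the same preprocessing steps; it assumes the en_core_web_sm model has been downloaded (python -m spacy download en_core_web_sm) and that re and string, which the hunk uses without importing, are imported elsewhere near the top of tasks/text.py.

# Standalone sketch of the preprocessing added in this commit: lowercase,
# lemmatize with spaCy, drop NLTK stop words and punctuation, and strip
# non-alphanumeric characters from each lemma.
import re
import string

import nltk
import spacy
from nltk.corpus import stopwords

nltk.download('stopwords')
nltk_stop_words = stopwords.words('english')
nlp = spacy.load("en_core_web_sm")  # assumes the model has been installed


def process_text(text):
    doc = nlp(text.lower())
    tokens = [
        re.sub(r'[^a-zA-Z0-9]', '', token.lemma_)
        for token in doc
        if token.text not in nltk_stop_words and token.text not in string.punctuation
    ]
    # Drop tokens that the regex reduced to empty strings
    return " ".join(word for word in tokens if word)


print(process_text("The researchers were studying rapidly melting glaciers in 2023!"))
# Expected shape of the output (exact lemmas depend on the spaCy model):
# researcher study rapidly melt glacier 2023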
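
The second hunk changes the call site so that predict receives the already-unpickled TF-IDF vectorizer object instead of a path to its pickle file. predict itself is defined elsewhere in tasks/text.py and is not shown in this diff; the sketch below is only a hypothetical illustration of a signature compatible with the new call, and both the "quote" column name and the use of a pickled scikit-learn random forest are assumptions, not details confirmed by the diff.

import pickle


def predict(test_dataset, tfidf_vectorizer, model_path):
    # Hypothetical: preprocess the raw texts with the process_text helper
    # defined above in the same module; the "quote" column name is an
    # assumption about the dataset schema.
    texts = [process_text(text) for text in test_dataset["quote"]]

    # Transform with the already-fitted vectorizer passed in by the caller,
    # rather than re-loading it from disk inside this function.
    features = tfidf_vectorizer.transform(texts)

    # Load the pickled classifier and return its labels, which the caller
    # then maps through LABEL_MAPPING.
    with open(model_path, "rb") as model_file:
        model = pickle.load(model_file)
    return model.predict(features)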