Aristotle commited on
Commit
a768eaa
1 Parent(s): e26ed33
Files changed (5) hide show
  1. app.py +45 -19
  2. classes.pkl +2 -2
  3. model.pkl +3 -0
  4. requirements.txt +2 -3
  5. vectorizer.pkl +3 -0
app.py CHANGED
@@ -1,32 +1,58 @@
1
  import gradio as gr
2
- import onnxruntime
3
  import numpy as np
4
  import pickle
5
- threshold = 0.5
6
- onnx_session = onnxruntime.InferenceSession("bert.onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Load the instance back
9
  with open('classes.pkl', 'rb') as file:
10
  mlb = pickle.load(file)
11
 
12
- with open('tokenizer.pkl', 'rb') as file:
13
- tokenizer = pickle.load(file)
14
 
 
 
15
 
16
- # Create a function to predict tags using the ONNX model
17
- def predict_tags_onnx(text):
18
- encoded_text = tokenizer(text , padding=True, truncation=True, return_tensors='pt')
19
- input_ids = encoded_text["input_ids"].numpy()
20
- attention_mask = encoded_text["attention_mask"].numpy()
21
-
22
- # Run the ONNX model
23
- outputs = np.asarray(onnx_session.run(None, {"input_ids": input_ids , "attention_mask":attention_mask}))
24
 
25
- # Post-process the outputs as needed
26
- #predicted_labels = torch.sigmoid(outputs).cpu().numpy()
27
- predicted_tags = mlb.classes_[np.where(np.squeeze((outputs > threshold).astype(int)).flatten() == 1)]
28
-
29
- return predicted_tags
30
 
31
- iface = gr.Interface(fn=predict_tags_onnx, inputs="text", outputs="text")
32
  iface.launch()
 
1
  import gradio as gr
 
2
  import numpy as np
3
  import pickle
4
+
5
+
6
+
7
+ import nltk
8
+ from nltk.corpus import stopwords
9
+ from nltk.tokenize import word_tokenize
10
+ from nltk.stem import WordNetLemmatizer
11
+ from sklearn.feature_extraction.text import CountVectorizer
12
+
13
+ # Initialize NLTK resources (download if needed)
14
+ nltk.download('punkt')
15
+ nltk.download('wordnet')
16
+ nltk.download('stopwords')
17
+
18
+ # Text preprocessing functions
19
+
20
+ def preprocess_text(text):
21
+ # Tokenization
22
+ words = word_tokenize(text.lower()) # Convert to lowercase and tokenize
23
+
24
+ # Remove stopwords
25
+ stop_words = set(stopwords.words('english'))
26
+ words = [word for word in words if word not in stop_words]
27
+
28
+ # Lemmatization
29
+ lemmatizer = WordNetLemmatizer()
30
+ words = [lemmatizer.lemmatize(word) for word in words]
31
+
32
+ return ' '.join(words)
33
+
34
+
35
+
36
+
37
+
38
+ def predict_tags(text):
39
+ return mlb[np.where(model.predict(vectorizer.transform([preprocess_text(text)])).flatten() == 1)]
40
+
41
+
42
+
43
 
44
  # Load the instance back
45
  with open('classes.pkl', 'rb') as file:
46
  mlb = pickle.load(file)
47
 
48
+ with open('vectorizer.pkl', 'rb') as file:
49
+ vectorizer = pickle.load(file)
50
 
51
+ with open('model.pkl', 'rb') as file:
52
+ model = pickle.load(file)
53
 
 
 
 
 
 
 
 
 
54
 
55
+ # Create a function to predict tags using the ONNX model
 
 
 
 
56
 
57
+ iface = gr.Interface(fn=predict_tags, inputs="text", outputs="text")
58
  iface.launch()
classes.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:203390ce8d1e08f4ce6f73a1216738464c89d23caec7ec32e122936a26a90412
3
- size 922
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12ef1b82b64966b26fc03ac0f6567a673ef8751474b032c8d262b9a544925633
3
+ size 3192
model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c98ccde9742b241457b261f8fbe2e4d190f5ef011a9fa80c5a3dfb99adda165
3
+ size 22233200
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- onnxruntime==1.15.1
2
- torch==2.0.1
3
  scikit-learn==1.2.2
4
- transformers==4.21.2
 
 
 
 
1
  scikit-learn==1.2.2
2
+ nltk==3.8.1
3
+
vectorizer.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48ea212d9f95d5829baabf000004a6a425cf834cfc0bb566e37dd58a211b89da
3
+ size 6554145