MadhuP commited on
Commit
e6a0986
·
1 Parent(s): c83a94f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -65,8 +65,8 @@ def check_by_url(txt_url):
65
  new_data = {"title": title, "content": normalized_content_with_style}
66
  # return new_data
67
 
68
- model = DistilBertForSequenceClassification.from_pretrained(".")
69
- tokenizer = DistilBertTokenizer.from_pretrained(".")
70
 
71
  label_mapping = {1: "SFW", 0: "NSFW"}
72
  test_encodings = tokenizer.encode_plus(
@@ -125,7 +125,9 @@ def predict_2(txt_url, normalized_content_with_style):
125
  confidence_scores_content,
126
  new_data,
127
  ) = (None, None, None, None, None)
 
128
  predicted_label_text, confidence_score_text = None, None
 
129
  if txt_url.startswith("http://") or txt_url.startswith("https://"):
130
  (
131
  predicted_label_title,
@@ -135,8 +137,8 @@ def predict_2(txt_url, normalized_content_with_style):
135
  new_data,
136
  ) = check_by_url(txt_url)
137
  elif txt_url.startswith(""):
138
- model = DistilBertForSequenceClassification.from_pretrained(".")
139
- tokenizer = DistilBertTokenizer.from_pretrained(".")
140
 
141
  test_encodings = tokenizer.encode_plus(
142
  normalized_content_with_style,
@@ -145,6 +147,7 @@ def predict_2(txt_url, normalized_content_with_style):
145
  max_length=512,
146
  return_tensors="pt",
147
  )
 
148
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
149
  test_input_ids = test_encodings["input_ids"].to(device)
150
  test_attention_mask = test_encodings["attention_mask"].to(device)
@@ -169,6 +172,7 @@ def predict_2(txt_url, normalized_content_with_style):
169
  confidence_score_text,
170
  #new,
171
  )
 
172
  def word_by_word(txt_url, normalized_content_with_style):
173
  if txt_url.startswith("http://") or txt_url.startswith("https://") or txt_url.startswith(""):
174
  (
@@ -181,8 +185,8 @@ def word_by_word(txt_url, normalized_content_with_style):
181
  confidence_score_text,
182
  ) = predict_2(txt_url, normalized_content_with_style)
183
 
184
- model = DistilBertForSequenceClassification.from_pretrained("")
185
- tokenizer = DistilBertTokenizer.from_pretrained(".")
186
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
187
  model = model.to(device)
188
  model.eval()
@@ -231,6 +235,8 @@ def word_by_word(txt_url, normalized_content_with_style):
231
  confidence_score_text,
232
  new_word,
233
  )
 
 
234
  demo = gr.Interface(
235
  fn=word_by_word,
236
  inputs=[
@@ -248,4 +254,5 @@ demo = gr.Interface(
248
  gr.outputs.Textbox(label="word-by-word").style(show_copy_button=True),
249
  ],
250
  )
251
- demo.launch()
 
 
65
  new_data = {"title": title, "content": normalized_content_with_style}
66
  # return new_data
67
 
68
+ model = DistilBertForSequenceClassification.from_pretrained("/content/LoadModel")
69
+ tokenizer = DistilBertTokenizer.from_pretrained("/content/LoadModel")
70
 
71
  label_mapping = {1: "SFW", 0: "NSFW"}
72
  test_encodings = tokenizer.encode_plus(
 
125
  confidence_scores_content,
126
  new_data,
127
  ) = (None, None, None, None, None)
128
+
129
  predicted_label_text, confidence_score_text = None, None
130
+
131
  if txt_url.startswith("http://") or txt_url.startswith("https://"):
132
  (
133
  predicted_label_title,
 
137
  new_data,
138
  ) = check_by_url(txt_url)
139
  elif txt_url.startswith(""):
140
+ model = DistilBertForSequenceClassification.from_pretrained("/content/LoadModel")
141
+ tokenizer = DistilBertTokenizer.from_pretrained("/content/LoadModel")
142
 
143
  test_encodings = tokenizer.encode_plus(
144
  normalized_content_with_style,
 
147
  max_length=512,
148
  return_tensors="pt",
149
  )
150
+
151
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152
  test_input_ids = test_encodings["input_ids"].to(device)
153
  test_attention_mask = test_encodings["attention_mask"].to(device)
 
172
  confidence_score_text,
173
  #new,
174
  )
175
+
176
  def word_by_word(txt_url, normalized_content_with_style):
177
  if txt_url.startswith("http://") or txt_url.startswith("https://") or txt_url.startswith(""):
178
  (
 
185
  confidence_score_text,
186
  ) = predict_2(txt_url, normalized_content_with_style)
187
 
188
+ model = DistilBertForSequenceClassification.from_pretrained("/content/LoadModel")
189
+ tokenizer = DistilBertTokenizer.from_pretrained("/content/LoadModel")
190
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191
  model = model.to(device)
192
  model.eval()
 
235
  confidence_score_text,
236
  new_word,
237
  )
238
+
239
+
240
  demo = gr.Interface(
241
  fn=word_by_word,
242
  inputs=[
 
254
  gr.outputs.Textbox(label="word-by-word").style(show_copy_button=True),
255
  ],
256
  )
257
+
258
+ demo.launch(debug=True, share= True)