MadhuP commited on
Commit
3f79358
·
1 Parent(s): 7158a75

Update app.py

Browse files

adding a feature of :- per word classification

Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -65,8 +65,8 @@ def check_by_url(txt_url):
65
  new_data = {"title": title, "content": normalized_content_with_style}
66
  # return new_data
67
 
68
- model = DistilBertForSequenceClassification.from_pretrained("/content/LoadModel")
69
- tokenizer = DistilBertTokenizer.from_pretrained("/content/LoadModel")
70
 
71
  label_mapping = {1: "SFW", 0: "NSFW"}
72
  test_encodings = tokenizer.encode_plus(
@@ -137,8 +137,8 @@ def predict_2(txt_url, normalized_content_with_style):
137
  new_data,
138
  ) = check_by_url(txt_url)
139
  elif txt_url.startswith(""):
140
- model = DistilBertForSequenceClassification.from_pretrained("/content/LoadModel")
141
- tokenizer = DistilBertTokenizer.from_pretrained("/content/LoadModel")
142
 
143
  test_encodings = tokenizer.encode_plus(
144
  normalized_content_with_style,
@@ -185,8 +185,8 @@ def word_by_word(txt_url, normalized_content_with_style):
185
  confidence_score_text,
186
  ) = predict_2(txt_url, normalized_content_with_style)
187
 
188
- model = DistilBertForSequenceClassification.from_pretrained("/content/LoadModel")
189
- tokenizer = DistilBertTokenizer.from_pretrained("/content/LoadModel")
190
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191
  model = model.to(device)
192
  model.eval()
@@ -251,7 +251,7 @@ demo = gr.Interface(
251
  gr.outputs.Textbox(label="Description").style(show_copy_button=True),
252
  gr.outputs.Textbox(label="Text_prediction_score"),
253
  gr.outputs.Textbox(label="Text_confidence_score"),
254
- gr.outputs.Textbox(label="word-by-word").style(show_copy_button=True),
255
  ],
256
  )
257
 
 
65
  new_data = {"title": title, "content": normalized_content_with_style}
66
  # return new_data
67
 
68
+ model = DistilBertForSequenceClassification.from_pretrained(".")
69
+ tokenizer = DistilBertTokenizer.from_pretrained(".")
70
 
71
  label_mapping = {1: "SFW", 0: "NSFW"}
72
  test_encodings = tokenizer.encode_plus(
 
137
  new_data,
138
  ) = check_by_url(txt_url)
139
  elif txt_url.startswith(""):
140
+ model = DistilBertForSequenceClassification.from_pretrained(".")
141
+ tokenizer = DistilBertTokenizer.from_pretrained(".")
142
 
143
  test_encodings = tokenizer.encode_plus(
144
  normalized_content_with_style,
 
185
  confidence_score_text,
186
  ) = predict_2(txt_url, normalized_content_with_style)
187
 
188
+ model = DistilBertForSequenceClassification.from_pretrained(".")
189
+ tokenizer = DistilBertTokenizer.from_pretrained(".")
190
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191
  model = model.to(device)
192
  model.eval()
 
251
  gr.outputs.Textbox(label="Description").style(show_copy_button=True),
252
  gr.outputs.Textbox(label="Text_prediction_score"),
253
  gr.outputs.Textbox(label="Text_confidence_score"),
254
+ gr.outputs.Textbox(label="per word classification").style(show_copy_button=True),
255
  ],
256
  )
257