qiuhuachuan committed on
Commit
66536b3
1 Parent(s): 1613789

Upload local_use.py

Browse files
Files changed (1) hide show
  1. local_use.py +47 -19
local_use.py CHANGED
@@ -6,10 +6,10 @@ from torch import nn
6
 
7
  label_mapping = {0: 'NSFW', 1: 'SFW'}
8
 
9
- config = BertConfig.from_pretrained('qiuhuachuan/NSFW-detector',
10
  num_labels=2,
11
  finetuning_task='text classification')
12
- tokenizer = BertTokenizer.from_pretrained('qiuhuachuan/NSFW-detector',
13
  use_fast=False,
14
  never_split=['[user]', '[bot]'])
15
  tokenizer.vocab['[user]'] = tokenizer.vocab.pop('[unused1]')
@@ -22,7 +22,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
22
  self.num_labels = config.num_labels
23
  self.config = config
24
 
25
- self.bert = BertModel.from_pretrained('bert-base-cased')
26
  classifier_dropout = (config.classifier_dropout
27
  if config.classifier_dropout is not None else
28
  config.hidden_dropout_prob)
@@ -71,19 +71,47 @@ model.load_state_dict(torch.load('./NSFW-detector/pytorch_model.bin'))
71
  model.cuda()
72
  model.eval()
73
 
74
- text = '''I'm open to exploring a variety of toys, including vibrators, wands, and clamps. I also love exploring different kinds of restraints and bondage equipment. I'm open to trying out different kinds of toys and exploring different levels of intensity.'''
75
- result = tokenizer.encode_plus(text=text,
76
- padding='max_length',
77
- max_length=512,
78
- truncation=True,
79
- add_special_tokens=True,
80
- return_token_type_ids=True,
81
- return_tensors='pt')
82
- result = result.to('cuda')
83
-
84
- with torch.no_grad():
85
- logits = model(**result)
86
- predictions = logits.argmax(dim=-1)
87
- pred_label_idx = predictions.item()
88
- pred_label = label_mapping[pred_label_idx]
89
- print('predicted label is:', pred_label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  label_mapping = {0: 'NSFW', 1: 'SFW'}
8
 
9
+ config = BertConfig.from_pretrained('./NSFW-detector',
10
  num_labels=2,
11
  finetuning_task='text classification')
12
+ tokenizer = BertTokenizer.from_pretrained('./NSFW-detector',
13
  use_fast=False,
14
  never_split=['[user]', '[bot]'])
15
  tokenizer.vocab['[user]'] = tokenizer.vocab.pop('[unused1]')
 
22
  self.num_labels = config.num_labels
23
  self.config = config
24
 
25
+ self.bert = BertModel.from_pretrained('./NSFW-detector')
26
  classifier_dropout = (config.classifier_dropout
27
  if config.classifier_dropout is not None else
28
  config.hidden_dropout_prob)
 
71
  model.cuda()
72
  model.eval()
73
 
74
+ # The following object contains two types of input:
75
+ # a single utterance and a single-turn dialogue
76
+ to_predict_items = [{
77
+ 'text':
78
+ '''Give some movie recommendations to get women in the mood for sex'''
79
+ }, {
80
+ 'text': 'I break through walls to find more trials'
81
+ }, {
82
+ 'history': {
83
+ 'user':
84
+ 'Give some movie recommendations to get women in the mood for sex',
85
+ 'bot':
86
+ '''I apologize, but I cannot assist in creating or providing information related to NSFW content or any tools for such purposes. If you have any other non-explicit requests or need assistance with different topics, please feel free to ask, and I'll be happy to help.'''
87
+ }
88
+ }, {
89
+ 'history': {
90
+ 'user':
91
+ 'Give some movie recommendations to get women in the mood for sex',
92
+ 'bot': '''Sure.'''
93
+ }
94
+ }]
95
+
96
+ for item in to_predict_items:
97
+ if 'history' in item:
98
+ text = '[user] ' + item['history'][
99
+ 'user'] + ' [SEP] ' + '[bot] ' + item['history']['bot']
100
+ else:
101
+ text = item['text']
102
+ result = tokenizer.encode_plus(text=text,
103
+ padding='max_length',
104
+ max_length=512,
105
+ truncation=True,
106
+ add_special_tokens=True,
107
+ return_token_type_ids=True,
108
+ return_tensors='pt')
109
+ result = result.to('cuda')
110
+
111
+ with torch.no_grad():
112
+ logits = model(**result)
113
+ predictions = logits.argmax(dim=-1)
114
+ pred_label_idx = predictions.item()
115
+ pred_label = label_mapping[pred_label_idx]
116
+ print('text:', text)
117
+ print('predicted label is:', pred_label)