sun-tana commited on
Commit
59bdd3b
β€’
1 Parent(s): 74ed80e
Files changed (1) hide show
  1. app.py +26 -26
app.py CHANGED
@@ -14,8 +14,8 @@ from transformers import TFAutoModel, AutoTokenizer
14
  from sklearn.model_selection import train_test_split
15
 
16
  # load the tokenizer and transformer model
17
- tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base",max_length=60) #xlm-roberta-base bert-base-multilingual-cased
18
- transformer_model = TFAutoModel.from_pretrained("xlm-roberta-base") #philschmid/tiny-bert-sst2-distilled
19
  max_seq_length = 32
20
 
21
  def create_model():
@@ -88,30 +88,30 @@ def predict(text):
88
  predicted_labels = model.predict(test_padded_sequences)
89
 
90
  for i in range(len(test_texts)):
91
- print(test_texts[i])
92
- valid = 1 if predicted_labels[0][i] > 0.5 else 0
93
- is_scene = 1 if predicted_labels[1][i] > 0.5 else 0
94
- has_num = 1 if predicted_labels[2][i] > 0.5 else 0
95
- print(f'is_valid : {valid}')
96
- print(f'is_scene : {is_scene}')
97
- print(f'has_num : {has_num}')
98
-
99
- turn = 1 if predicted_labels[3][i] > 0.5 else 0
100
- print(f'turn_on_off : {turn}')
101
- print(f'device : ΰΉ„ΰΈŸ')
102
-
103
- env_id = np.argmax(predicted_labels[5][i])
104
- env_label = env_decode[env_id]
105
-
106
- hour_id = np.argmax(predicted_labels[6][i])
107
- hour_label = hour_decode[hour_id]
108
-
109
- minute_id = np.argmax(predicted_labels[7][i])
110
- minute_label = minute_decode[minute_id]
111
- print(f'env : {env_label}')
112
- print(f'hour : {hour_label}')
113
- print(f'minute : {minute_label}')
114
- print('----')
115
  return 'hello'
116
 
117
  iface = gr.Interface(
 
14
  from sklearn.model_selection import train_test_split
15
 
16
  # load the tokenizer and transformer model
17
+ tokenizer = AutoTokenizer.from_pretrained("nlptown/flaubert_small_cased_sentiment",max_length=60) #xlm-roberta-base bert-base-multilingual-cased
18
+ transformer_model = TFAutoModel.from_pretrained("nlptown/flaubert_small_cased_sentiment") #philschmid/tiny-bert-sst2-distilled
19
  max_seq_length = 32
20
 
21
  def create_model():
 
88
  predicted_labels = model.predict(test_padded_sequences)
89
 
90
  for i in range(len(test_texts)):
91
+ print(test_texts[i])
92
+ valid = 1 if predicted_labels[0][i] > 0.5 else 0
93
+ is_scene = 1 if predicted_labels[1][i] > 0.5 else 0
94
+ has_num = 1 if predicted_labels[2][i] > 0.5 else 0
95
+ print(f'is_valid : {valid}')
96
+ print(f'is_scene : {is_scene}')
97
+ print(f'has_num : {has_num}')
98
+
99
+ turn = 1 if predicted_labels[3][i] > 0.5 else 0
100
+ print(f'turn_on_off : {turn}')
101
+ print(f'device : ΰΉ„ΰΈŸ')
102
+
103
+ env_id = np.argmax(predicted_labels[5][i])
104
+ env_label = env_decode[env_id]
105
+
106
+ hour_id = np.argmax(predicted_labels[6][i])
107
+ hour_label = hour_decode[hour_id]
108
+
109
+ minute_id = np.argmax(predicted_labels[7][i])
110
+ minute_label = minute_decode[minute_id]
111
+ print(f'env : {env_label}')
112
+ print(f'hour : {hour_label}')
113
+ print(f'minute : {minute_label}')
114
+ print('----')
115
  return 'hello'
116
 
117
  iface = gr.Interface(