ierhon commited on
Commit
74c3b09
1 Parent(s): 09acac3

Optimize with argmax

Browse files
Files changed (1) hide show
  1. app.py +1 -7
app.py CHANGED
@@ -72,14 +72,8 @@ def train(data: str, message: str):
72
  model.save(f"cache/{data_hash}")
73
  tokens = tokenizer.texts_to_sequences([message,])[0]
74
  prediction = model.predict(np.array([(list(tokens)+[0,]*inp_len)[:inp_len],]))[0]
75
- max_o = 0
76
- max_v = 0
77
- for ind, i in enumerate(prediction):
78
- if max_v < i:
79
- max_v = i
80
- max_o = ind
81
  keras.backend.clear_session()
82
- return responses[ind]
83
 
84
  iface = gr.Interface(fn=train, inputs=["text", "text"], outputs="text")
85
  iface.launch()
 
72
  model.save(f"cache/{data_hash}")
73
  tokens = tokenizer.texts_to_sequences([message,])[0]
74
  prediction = model.predict(np.array([(list(tokens)+[0,]*inp_len)[:inp_len],]))[0]
 
 
 
 
 
 
75
  keras.backend.clear_session()
76
+ return responses[np.argmax(prediction)]
77
 
78
  iface = gr.Interface(fn=train, inputs=["text", "text"], outputs="text")
79
  iface.launch()