ierhon committed on
Commit
8d2e061
1 Parent(s): d989b70

Add sliders

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  from keras.models import Model
5
  from keras.saving import load_model
6
  from keras.layers import *
 
7
  from keras.preprocessing.text import Tokenizer
8
  import os
9
  import hashlib
@@ -18,7 +19,7 @@ maxshift = 4
18
  def hash_str(data: str):
19
  return hashlib.md5(data.encode('utf-8')).hexdigest()
20
 
21
- def train(message: str, data: str):
22
  if "→" not in data or "\n" not in data:
23
  return "Dataset example:\nquestion→answer\nquestion→answer\netc."
24
  dset, responses = todset(data)
@@ -66,7 +67,7 @@ def train(message: str, data: str):
66
  X = np.array(X)
67
  y = np.array(y)
68
 
69
- model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy",])
70
 
71
  model.fit(X, y, epochs=16, batch_size=8, workers=4, use_multiprocessing=True)
72
  model.save(f"cache/{data_hash}")
@@ -75,5 +76,11 @@ def train(message: str, data: str):
75
  keras.backend.clear_session()
76
  return responses[np.argmax(prediction)]
77
 
78
- iface = gr.Interface(fn=train, inputs=["text", "text"], outputs="text")
 
 
 
 
 
 
79
  iface.launch()
 
4
  from keras.models import Model
5
  from keras.saving import load_model
6
  from keras.layers import *
7
+ from keras.optimizers import RMSProp
8
  from keras.preprocessing.text import Tokenizer
9
  import os
10
  import hashlib
 
19
  def hash_str(data: str):
20
  return hashlib.md5(data.encode('utf-8')).hexdigest()
21
 
22
+ def train(message: str, epochs: int, learning_rate: float, emb_size: int, inp_len: int, data: str):
23
  if "→" not in data or "\n" not in data:
24
  return "Dataset example:\nquestion→answer\nquestion→answer\netc."
25
  dset, responses = todset(data)
 
67
  X = np.array(X)
68
  y = np.array(y)
69
 
70
+ model.compile(optimizer=RMSProp(learning_rate=learning_rate), loss="sparse_categorical_crossentropy", metrics=["accuracy",])
71
 
72
  model.fit(X, y, epochs=16, batch_size=8, workers=4, use_multiprocessing=True)
73
  model.save(f"cache/{data_hash}")
 
76
  keras.backend.clear_session()
77
  return responses[np.argmax(prediction)]
78
 
79
+ iface = gr.Interface(fn=train, inputs=["text",
80
+ gr.inputs.Slider(1, 64, 32, step=1, label="Epochs"),
81
+ gr.inputs.Slider(0.00000001, 0.1, 0.001, step=0.00000001, label="Learning rate"),
82
+ gr.inputs.Slider(1, 256, 100, step=1, label="Embedding size"),
83
+ gr.inputs.Slider(1,128, step=1, label="Input Length"),
84
+ "text"],
85
+ outputs="text")
86
  iface.launch()