s3nh committed
Commit 9ffa5bd (parent: 8acee1d)

Update app.py

Files changed (1):
  1. app.py +39 -39
app.py CHANGED
@@ -63,7 +63,7 @@ def evaluate(instruction, input, model, tokenizer):
         result.append(output.split("### Response:")[1].strip())
     return ' '.join(el for el in result)
 
-def inference(model_name, text, input):
+def inference(model_name, text, input, temperature, top_p, num_beams):
     model = load_model(model_name)
     tokenizer = load_tokenizer(model_name)
     output = evaluate(instruction = text, input = input, model = model, tokenizer = tokenizer)
@@ -72,43 +72,7 @@ def inference(model_name, text, input):
 def choose_model(name):
     return load_model(name), load_tokenizer(name)
 
-with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
-    temperature = gr.Slider(
-        label="Temperature",
-        value=0.7,
-        minimum=0.0,
-        maximum=1.0,
-        step=0.1,
-        interactive=True,
-        info="Higher values produce more diverse outputs",
-    )
-    top_p = gr.Slider(
-        label="Top-p (nucleus sampling)",
-        value=0.9,
-        minimum=0.0,
-        maximum=1,
-        step=0.05,
-        interactive=True,
-        info="Higher values sample more low-probability tokens",
-    )
-    max_new_tokens = gr.Slider(
-        label="Max new tokens",
-        value=1024,
-        minimum=0,
-        maximum=2048,
-        step=4,
-        interactive=True,
-        info="The maximum number of new tokens",
-    )
-    repetition_penalty = gr.Slider(
-        label="Repetition Penalty",
-        value=1.2,
-        minimum=0.0,
-        maximum=10,
-        step=0.1,
-        interactive=True,
-        info="The parameter for repetition penalty. 1.0 means no penalty.",
-    )
+with
 
 io = gr.Interface(
     inference,
@@ -128,7 +92,7 @@ io = gr.Interface(
     #"stablelm-base-alpha-3b-Lora-polish",
     #"dolly-v2-3b-Lora-polish",
     #"LaMini-GPT-1.5B-Lora-polish"],
-    ]),
+    ],
     gr.Textbox(
         lines = 3,
         max_lines = 10,
@@ -142,6 +106,42 @@ io = gr.Interface(
         placeholder = "Add context here",
         interactive = True,
         show_label = False
+    ),
+    gr.Slider(
+        label="Temperature",
+        value=0.7,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.1,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.9,
+        minimum=0.0,
+        maximum=1,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=1024,
+        minimum=0,
+        maximum=2048,
+        step=4,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Number of beams",
+        value=2,
+        minimum=0.0,
+        maximum=5.0,
+        step=1.0,
+        interactive=True,
+        info="The number of beams for beam search. 1 means no beam search.",
     )],
     outputs = [gr.Textbox(lines = 1, label = 'Pythia410m', interactive = False)],
     cache_examples = False,
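
For context: gr.Interface calls its handler with one positional argument per input component, in the order the components appear in inputs=[...]. The diff adds four sliders but extends inference with only three new parameters (temperature, top_p, num_beams), so the "Max new tokens" slider has no matching parameter as shown. Below is a minimal sketch of the complete wiring; the checkpoint name, the prompt template, and the choice to fix a single model instead of keeping the dropdown are illustrative assumptions, not taken from the commit.

# Minimal sketch: slider components listed in gr.Interface(inputs=[...])
# arrive as positional arguments to the handler, which forwards them to
# model.generate(). Assumptions: "EleutherAI/pythia-410m" is an illustrative
# checkpoint, and the prompt template is inferred from the
# "### Response:" split in evaluate().
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "EleutherAI/pythia-410m"  # assumed placeholder, not from the commit
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

def inference(text, input, temperature, top_p, max_new_tokens, num_beams):
    # One parameter per input component, in the same order as inputs=[...].
    prompt = f"### Instruction:\n{text}\n\n### Input:\n{input}\n\n### Response:\n"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=int(max_new_tokens),  # sliders deliver floats
            num_beams=int(num_beams),
            pad_token_id=tokenizer.eos_token_id,
        )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output.split("### Response:")[1].strip()

io = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(lines=3, label="Instruction"),
        gr.Textbox(lines=3, label="Input"),
        gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
        gr.Slider(0, 2048, value=1024, step=4, label="Max new tokens"),
        gr.Slider(1, 5, value=2, step=1, label="Number of beams"),
    ],
    outputs=gr.Textbox(lines=1, label="Pythia410m"),
    cache_examples=False,
)

if __name__ == "__main__":
    io.launch()

With this wiring the number of input components matches the handler's arity, so Gradio can bind each slider value to its parameter at call time.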