terryyz commited on
Commit
6854c08
1 Parent(s): 56eb7ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -146,12 +146,6 @@ theme = gr.themes.Monochrome(
146
  ],
147
  )
148
 
149
- def stream(model, code, generate_kwargs):
150
- input_ids = tokenizer(code, return_tensors="pt").to(device)
151
- # generated_ids = model.generate(**input_ids, **generate_kwargs)
152
- generated_ids = model.generate(**input_ids)
153
- return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
154
-
155
  @spaces.GPU
156
  def generate(
157
  prompt, temperature=0.6, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, library="LangChain", method="Prefix"
@@ -174,20 +168,25 @@ def generate(
174
  if method == "Base":
175
  output = stream(basemodel, prompt, generate_kwargs)
176
  elif method == "Prefix":
177
- output = stream(model_map[library + " Prefix"], prompt, generate_kwargs)
178
  elif method == "Evo Prefix" and library in ["SQLModel", "SfePy", "MegEngine"]:
179
- output = stream(model_map["Main Evo Prefix"], prompt, generate_kwargs)
180
  elif method == "FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
181
- output = stream(model_map[library + " FFT"], prompt, generate_kwargs)
182
  elif method == "Evo FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
183
- output = stream(model_map["Main Evo FFT"], prompt, generate_kwargs)
184
  elif method == "Full Data FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
185
- output = stream(model_map["Main FD FFT"], prompt, generate_kwargs)
186
  elif method == "Evo Prefix" and library in ["LangChain", "LlamaIndex", "DSPy"]:
187
- output = stream(model_map["CS Evo Prefix"], prompt, generate_kwargs)
188
  else:
189
  output = ""
190
- return output
 
 
 
 
 
191
 
192
 
193
  examples = [
 
146
  ],
147
  )
148
 
 
 
 
 
 
 
149
  @spaces.GPU
150
  def generate(
151
  prompt, temperature=0.6, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, library="LangChain", method="Prefix"
 
168
  if method == "Base":
169
  output = stream(basemodel, prompt, generate_kwargs)
170
  elif method == "Prefix":
171
+ output = model_map[library + " Prefix"]
172
  elif method == "Evo Prefix" and library in ["SQLModel", "SfePy", "MegEngine"]:
173
+ output = model_map["Main Evo Prefix"]
174
  elif method == "FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
175
+ output = model_map[library + " FFT"]
176
  elif method == "Evo FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
177
+ output = model_map["Main Evo FFT"]
178
  elif method == "Full Data FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
179
+ output = model_map["Main FD FFT"]
180
  elif method == "Evo Prefix" and library in ["LangChain", "LlamaIndex", "DSPy"]:
181
+ model = model_map["CS Evo Prefix"]
182
  else:
183
  output = ""
184
+ model.to(device)
185
+ input_ids = tokenizer(code, return_tensors="pt").to(device)
186
+ # generated_ids = model.generate(**input_ids, **generate_kwargs)
187
+ generated_ids = model.generate(**input_ids)
188
+
189
+ return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
190
 
191
 
192
  examples = [