terryyz commited on
Commit
905d227
1 Parent(s): 3baf16c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -17
app.py CHANGED
@@ -34,23 +34,23 @@ DSPY_PREFIX_URL = "luna-code/dspy-codegen-350M-mono-prefix"
34
  CS_EVO_PREFIX_URL = "luna-code/cs-codegen-350M-mono-evo-prefix"
35
 
36
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_URL)
37
- basemodel = AutoModelForCausalLM.from_pretrained(CHECKPOINT_URL)
38
 
39
- sql_prefix = PeftModel.from_pretrained(basemodel, SQLMODEL_PREFIX_URL)
40
- sfepy_prefix = PeftModel.from_pretrained(basemodel, SFEPY_PREFIX_URL)
41
- megengine_prefix = PeftModel.from_pretrained(basemodel, MEGENGINE_PREFIX_URL)
42
- main_evo_prefix = PeftModel.from_pretrained(basemodel, MAIN_EVO_PREFIX_URL)
43
 
44
- sqlmodel_fft = AutoModelForCausalLM.from_pretrained(SQLMODEL_FFT_URL)
45
- sfepy_fft = AutoModelForCausalLM.from_pretrained(SFEPY_FFT_URL)
46
- megengine_fft = AutoModelForCausalLM.from_pretrained(MEGENGINE_FFT_URL)
47
- main_evo_fft = AutoModelForCausalLM.from_pretrained(MAIN_EVO_FFT_URL)
48
- main_fd_fft = AutoModelForCausalLM.from_pretrained(MAIN_FD_FFT_URL)
49
 
50
- langchain_prefix = PeftModel.from_pretrained(basemodel, LANGCHAIN_PREFIX_URL)
51
- llamaindex_prefix = PeftModel.from_pretrained(basemodel, LLAMAINDEX_PREFIX_URL)
52
- dspy_prefix = PeftModel.from_pretrained(basemodel, DSPY_PREFIX_URL)
53
- cs_evo_prefix = PeftModel.from_pretrained(basemodel, CS_EVO_PREFIX_URL)
54
 
55
  # basemodel = ""
56
  # sql_prefix = ""
@@ -147,8 +147,7 @@ theme = gr.themes.Monochrome(
147
  )
148
 
149
  def stream(model, code, generate_kwargs):
150
- model.to(device)
151
- input_ids = tokenizer(code, return_tensors="pt").to("cuda")
152
  generated_ids = model.generate(**input_ids, **generate_kwargs)
153
  return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
154
 
@@ -183,6 +182,8 @@ def generate(
183
  output = stream(model_map["Main Evo FFT"], prompt, generate_kwargs)
184
  elif method == "Full Data FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
185
  output = stream(model_map["Main FD FFT"], prompt, generate_kwargs)
 
 
186
  else:
187
  output = ""
188
 
@@ -241,7 +242,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
241
  gr.Markdown(description)
242
  with gr.Row():
243
  library = gr.Dropdown(
244
- ["SQLModel", "SfePy", "MegEngine", "LangChain", "LlamaIndex", "DSpy"],
245
  value="LangChain",
246
  label="Library",
247
  info="Choose a library from the list",
 
34
  CS_EVO_PREFIX_URL = "luna-code/cs-codegen-350M-mono-evo-prefix"
35
 
36
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_URL)
37
+ basemodel = AutoModelForCausalLM.from_pretrained(CHECKPOINT_URL, device_map="auto")
38
 
39
+ sql_prefix = PeftModel.from_pretrained(basemodel, SQLMODEL_PREFIX_URL, device_map="auto")
40
+ sfepy_prefix = PeftModel.from_pretrained(basemodel, SFEPY_PREFIX_URL, device_map="auto")
41
+ megengine_prefix = PeftModel.from_pretrained(basemodel, MEGENGINE_PREFIX_URL, device_map="auto")
42
+ main_evo_prefix = PeftModel.from_pretrained(basemodel, MAIN_EVO_PREFIX_URL, device_map="auto")
43
 
44
+ sqlmodel_fft = AutoModelForCausalLM.from_pretrained(SQLMODEL_FFT_URL, device_map="auto")
45
+ sfepy_fft = AutoModelForCausalLM.from_pretrained(SFEPY_FFT_URL, device_map="auto")
46
+ megengine_fft = AutoModelForCausalLM.from_pretrained(MEGENGINE_FFT_URL, device_map="auto")
47
+ main_evo_fft = AutoModelForCausalLM.from_pretrained(MAIN_EVO_FFT_URL, device_map="auto")
48
+ main_fd_fft = AutoModelForCausalLM.from_pretrained(MAIN_FD_FFT_URL, device_map="auto")
49
 
50
+ langchain_prefix = PeftModel.from_pretrained(basemodel, LANGCHAIN_PREFIX_URL, device_map="auto")
51
+ llamaindex_prefix = PeftModel.from_pretrained(basemodel, LLAMAINDEX_PREFIX_URL, device_map="auto")
52
+ dspy_prefix = PeftModel.from_pretrained(basemodel, DSPY_PREFIX_URL, device_map="auto")
53
+ cs_evo_prefix = PeftModel.from_pretrained(basemodel, CS_EVO_PREFIX_URL, device_map="auto")
54
 
55
  # basemodel = ""
56
  # sql_prefix = ""
 
147
  )
148
 
149
  def stream(model, code, generate_kwargs):
150
+ input_ids = tokenizer(code, return_tensors="pt").to(device)
 
151
  generated_ids = model.generate(**input_ids, **generate_kwargs)
152
  return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
153
 
 
182
  output = stream(model_map["Main Evo FFT"], prompt, generate_kwargs)
183
  elif method == "Full Data FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
184
  output = stream(model_map["Main FD FFT"], prompt, generate_kwargs)
185
+ elif method == "Evo Prefix" and library in ["LangChain", "LlamaIndex", "DSPy"]:
186
+ output = stream(model_map["CS Evo Prefix"], prompt, generate_kwargs)
187
  else:
188
  output = ""
189
 
 
242
  gr.Markdown(description)
243
  with gr.Row():
244
  library = gr.Dropdown(
245
+ ["SQLModel", "SfePy", "MegEngine", "LangChain", "LlamaIndex", "DSPy"],
246
  value="LangChain",
247
  label="Library",
248
  info="Choose a library from the list",