Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -34,23 +34,23 @@ DSPY_PREFIX_URL = "luna-code/dspy-codegen-350M-mono-prefix"
|
|
34 |
CS_EVO_PREFIX_URL = "luna-code/cs-codegen-350M-mono-evo-prefix"
|
35 |
|
36 |
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_URL)
|
37 |
-
basemodel = AutoModelForCausalLM.from_pretrained(CHECKPOINT_URL)
|
38 |
|
39 |
-
sql_prefix = PeftModel.from_pretrained(basemodel, SQLMODEL_PREFIX_URL)
|
40 |
-
sfepy_prefix = PeftModel.from_pretrained(basemodel, SFEPY_PREFIX_URL)
|
41 |
-
megengine_prefix = PeftModel.from_pretrained(basemodel, MEGENGINE_PREFIX_URL)
|
42 |
-
main_evo_prefix = PeftModel.from_pretrained(basemodel, MAIN_EVO_PREFIX_URL)
|
43 |
|
44 |
-
sqlmodel_fft = AutoModelForCausalLM.from_pretrained(SQLMODEL_FFT_URL)
|
45 |
-
sfepy_fft = AutoModelForCausalLM.from_pretrained(SFEPY_FFT_URL)
|
46 |
-
megengine_fft = AutoModelForCausalLM.from_pretrained(MEGENGINE_FFT_URL)
|
47 |
-
main_evo_fft = AutoModelForCausalLM.from_pretrained(MAIN_EVO_FFT_URL)
|
48 |
-
main_fd_fft = AutoModelForCausalLM.from_pretrained(MAIN_FD_FFT_URL)
|
49 |
|
50 |
-
langchain_prefix = PeftModel.from_pretrained(basemodel, LANGCHAIN_PREFIX_URL)
|
51 |
-
llamaindex_prefix = PeftModel.from_pretrained(basemodel, LLAMAINDEX_PREFIX_URL)
|
52 |
-
dspy_prefix = PeftModel.from_pretrained(basemodel, DSPY_PREFIX_URL)
|
53 |
-
cs_evo_prefix = PeftModel.from_pretrained(basemodel, CS_EVO_PREFIX_URL)
|
54 |
|
55 |
# basemodel = ""
|
56 |
# sql_prefix = ""
|
@@ -147,8 +147,7 @@ theme = gr.themes.Monochrome(
|
|
147 |
)
|
148 |
|
149 |
def stream(model, code, generate_kwargs):
|
150 |
-
|
151 |
-
input_ids = tokenizer(code, return_tensors="pt").to("cuda")
|
152 |
generated_ids = model.generate(**input_ids, **generate_kwargs)
|
153 |
return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
|
154 |
|
@@ -183,6 +182,8 @@ def generate(
|
|
183 |
output = stream(model_map["Main Evo FFT"], prompt, generate_kwargs)
|
184 |
elif method == "Full Data FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
|
185 |
output = stream(model_map["Main FD FFT"], prompt, generate_kwargs)
|
|
|
|
|
186 |
else:
|
187 |
output = ""
|
188 |
|
@@ -241,7 +242,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
|
|
241 |
gr.Markdown(description)
|
242 |
with gr.Row():
|
243 |
library = gr.Dropdown(
|
244 |
-
["SQLModel", "SfePy", "MegEngine", "LangChain", "LlamaIndex", "
|
245 |
value="LangChain",
|
246 |
label="Library",
|
247 |
info="Choose a library from the list",
|
|
|
34 |
CS_EVO_PREFIX_URL = "luna-code/cs-codegen-350M-mono-evo-prefix"
|
35 |
|
36 |
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_URL)
|
37 |
+
basemodel = AutoModelForCausalLM.from_pretrained(CHECKPOINT_URL, device_map="auto")
|
38 |
|
39 |
+
sql_prefix = PeftModel.from_pretrained(basemodel, SQLMODEL_PREFIX_URL, device_map="auto")
|
40 |
+
sfepy_prefix = PeftModel.from_pretrained(basemodel, SFEPY_PREFIX_URL, device_map="auto")
|
41 |
+
megengine_prefix = PeftModel.from_pretrained(basemodel, MEGENGINE_PREFIX_URL, device_map="auto")
|
42 |
+
main_evo_prefix = PeftModel.from_pretrained(basemodel, MAIN_EVO_PREFIX_URL, device_map="auto")
|
43 |
|
44 |
+
sqlmodel_fft = AutoModelForCausalLM.from_pretrained(SQLMODEL_FFT_URL, device_map="auto")
|
45 |
+
sfepy_fft = AutoModelForCausalLM.from_pretrained(SFEPY_FFT_URL, device_map="auto")
|
46 |
+
megengine_fft = AutoModelForCausalLM.from_pretrained(MEGENGINE_FFT_URL, device_map="auto")
|
47 |
+
main_evo_fft = AutoModelForCausalLM.from_pretrained(MAIN_EVO_FFT_URL, device_map="auto")
|
48 |
+
main_fd_fft = AutoModelForCausalLM.from_pretrained(MAIN_FD_FFT_URL, device_map="auto")
|
49 |
|
50 |
+
langchain_prefix = PeftModel.from_pretrained(basemodel, LANGCHAIN_PREFIX_URL, device_map="auto")
|
51 |
+
llamaindex_prefix = PeftModel.from_pretrained(basemodel, LLAMAINDEX_PREFIX_URL, device_map="auto")
|
52 |
+
dspy_prefix = PeftModel.from_pretrained(basemodel, DSPY_PREFIX_URL, device_map="auto")
|
53 |
+
cs_evo_prefix = PeftModel.from_pretrained(basemodel, CS_EVO_PREFIX_URL, device_map="auto")
|
54 |
|
55 |
# basemodel = ""
|
56 |
# sql_prefix = ""
|
|
|
147 |
)
|
148 |
|
149 |
def stream(model, code, generate_kwargs):
|
150 |
+
input_ids = tokenizer(code, return_tensors="pt").to(device)
|
|
|
151 |
generated_ids = model.generate(**input_ids, **generate_kwargs)
|
152 |
return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
|
153 |
|
|
|
182 |
output = stream(model_map["Main Evo FFT"], prompt, generate_kwargs)
|
183 |
elif method == "Full Data FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
|
184 |
output = stream(model_map["Main FD FFT"], prompt, generate_kwargs)
|
185 |
+
elif method == "Evo Prefix" and library in ["LangChain", "LlamaIndex", "DSPy"]:
|
186 |
+
output = stream(model_map["CS Evo Prefix"], prompt, generate_kwargs)
|
187 |
else:
|
188 |
output = ""
|
189 |
|
|
|
242 |
gr.Markdown(description)
|
243 |
with gr.Row():
|
244 |
library = gr.Dropdown(
|
245 |
+
["SQLModel", "SfePy", "MegEngine", "LangChain", "LlamaIndex", "DSPy"],
|
246 |
value="LangChain",
|
247 |
label="Library",
|
248 |
info="Choose a library from the list",
|