Update app.py
Browse files
app.py
CHANGED
@@ -10,29 +10,27 @@ def load_model():
|
|
10 |
model = AutoModelForCausalLM.from_pretrained(
|
11 |
model_name,
|
12 |
torch_dtype=torch.float16,
|
13 |
-
device_map="auto"
|
|
|
14 |
)
|
15 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
16 |
return model, tokenizer
|
17 |
|
18 |
model, tokenizer = load_model()
|
19 |
|
20 |
-
@spaces.GPU(duration=60)
|
21 |
def fix_code(input_code):
|
22 |
-
# Prepare the prompt
|
23 |
messages = [
|
24 |
{"role": "system", "content": "You are a helpful coding assistant. Please analyze the following code, identify any errors, and provide the corrected version."},
|
25 |
{"role": "user", "content": f"Please fix this code:\n\n{input_code}"}
|
26 |
]
|
27 |
|
28 |
-
# Apply chat template
|
29 |
text = tokenizer.apply_chat_template(
|
30 |
messages,
|
31 |
tokenize=False,
|
32 |
add_generation_prompt=True
|
33 |
)
|
34 |
|
35 |
-
# Tokenize and generate
|
36 |
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
37 |
generated_ids = model.generate(
|
38 |
**model_inputs,
|
@@ -41,7 +39,6 @@ def fix_code(input_code):
|
|
41 |
top_p=0.95,
|
42 |
)
|
43 |
|
44 |
-
# Decode only the new tokens
|
45 |
generated_ids = [
|
46 |
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
47 |
]
|
@@ -49,7 +46,6 @@ def fix_code(input_code):
|
|
49 |
|
50 |
return response
|
51 |
|
52 |
-
# Create Gradio interface
|
53 |
iface = gr.Interface(
|
54 |
fn=fix_code,
|
55 |
inputs=gr.Code(
|
|
|
10 |
model = AutoModelForCausalLM.from_pretrained(
|
11 |
model_name,
|
12 |
torch_dtype=torch.float16,
|
13 |
+
device_map="auto",
|
14 |
+
low_cpu_mem_usage=True # This requires Accelerate
|
15 |
)
|
16 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
17 |
return model, tokenizer
|
18 |
|
19 |
model, tokenizer = load_model()
|
20 |
|
21 |
+
@spaces.GPU(duration=60)
|
22 |
def fix_code(input_code):
|
|
|
23 |
messages = [
|
24 |
{"role": "system", "content": "You are a helpful coding assistant. Please analyze the following code, identify any errors, and provide the corrected version."},
|
25 |
{"role": "user", "content": f"Please fix this code:\n\n{input_code}"}
|
26 |
]
|
27 |
|
|
|
28 |
text = tokenizer.apply_chat_template(
|
29 |
messages,
|
30 |
tokenize=False,
|
31 |
add_generation_prompt=True
|
32 |
)
|
33 |
|
|
|
34 |
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
35 |
generated_ids = model.generate(
|
36 |
**model_inputs,
|
|
|
39 |
top_p=0.95,
|
40 |
)
|
41 |
|
|
|
42 |
generated_ids = [
|
43 |
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
44 |
]
|
|
|
46 |
|
47 |
return response
|
48 |
|
|
|
49 |
iface = gr.Interface(
|
50 |
fn=fix_code,
|
51 |
inputs=gr.Code(
|