whyumesh commited on
Commit
ea899f2
1 Parent(s): 579c265

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -10,29 +10,27 @@ def load_model():
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_name,
12
  torch_dtype=torch.float16,
13
- device_map="auto"
 
14
  )
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
  return model, tokenizer
17
 
18
  model, tokenizer = load_model()
19
 
20
- @spaces.GPU(duration=60) # Adjust duration based on your needs
21
  def fix_code(input_code):
22
- # Prepare the prompt
23
  messages = [
24
  {"role": "system", "content": "You are a helpful coding assistant. Please analyze the following code, identify any errors, and provide the corrected version."},
25
  {"role": "user", "content": f"Please fix this code:\n\n{input_code}"}
26
  ]
27
 
28
- # Apply chat template
29
  text = tokenizer.apply_chat_template(
30
  messages,
31
  tokenize=False,
32
  add_generation_prompt=True
33
  )
34
 
35
- # Tokenize and generate
36
  model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
37
  generated_ids = model.generate(
38
  **model_inputs,
@@ -41,7 +39,6 @@ def fix_code(input_code):
41
  top_p=0.95,
42
  )
43
 
44
- # Decode only the new tokens
45
  generated_ids = [
46
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
47
  ]
@@ -49,7 +46,6 @@ def fix_code(input_code):
49
 
50
  return response
51
 
52
- # Create Gradio interface
53
  iface = gr.Interface(
54
  fn=fix_code,
55
  inputs=gr.Code(
 
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_name,
12
  torch_dtype=torch.float16,
13
+ device_map="auto",
14
+ low_cpu_mem_usage=True # This requires Accelerate
15
  )
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  return model, tokenizer
18
 
19
  model, tokenizer = load_model()
20
 
21
+ @spaces.GPU(duration=60)
22
  def fix_code(input_code):
 
23
  messages = [
24
  {"role": "system", "content": "You are a helpful coding assistant. Please analyze the following code, identify any errors, and provide the corrected version."},
25
  {"role": "user", "content": f"Please fix this code:\n\n{input_code}"}
26
  ]
27
 
 
28
  text = tokenizer.apply_chat_template(
29
  messages,
30
  tokenize=False,
31
  add_generation_prompt=True
32
  )
33
 
 
34
  model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
35
  generated_ids = model.generate(
36
  **model_inputs,
 
39
  top_p=0.95,
40
  )
41
 
 
42
  generated_ids = [
43
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
44
  ]
 
46
 
47
  return response
48
 
 
49
  iface = gr.Interface(
50
  fn=fix_code,
51
  inputs=gr.Code(