loubnabnl HF staff committed on
Commit
4da09e3
1 Parent(s): 0aa1e7c

add python extension to prompt

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -15,8 +15,11 @@ model = AutoModelForCausalLM.from_pretrained("facebook/incoder-6B", low_cpu_mem_
15
 
16
  MAX_LENGTH = 2048
17
  BOS = "<|endoftext|>"
 
 
18
  def generate(gen_prompt, max_tokens, temperature=0.6, seed=42):
19
  set_seed(seed)
 
20
  input_ids = tokenizer(gen_prompt, return_tensors="pt").input_ids
21
  current_length = input_ids.flatten().size(0)
22
  max_length = max_tokens + current_length
@@ -26,6 +29,7 @@ def generate(gen_prompt, max_tokens, temperature=0.6, seed=42):
26
  generated_text = tokenizer.decode(output.flatten())
27
  if generated_text.startswith(BOS):
28
  generated_text = generated_text[len(BOS):]
 
29
  return generated_text
30
 
31
  iface = gr.Interface(
 
15
 
16
  MAX_LENGTH = 2048
17
  BOS = "<|endoftext|>"
18
+ EXTENSION = "<| file ext=.py |>\n"
19
+
20
  def generate(gen_prompt, max_tokens, temperature=0.6, seed=42):
21
  set_seed(seed)
22
+ gen_prompt = EXTENSION + gen_prompt
23
  input_ids = tokenizer(gen_prompt, return_tensors="pt").input_ids
24
  current_length = input_ids.flatten().size(0)
25
  max_length = max_tokens + current_length
 
29
  generated_text = tokenizer.decode(output.flatten())
30
  if generated_text.startswith(BOS):
31
  generated_text = generated_text[len(BOS):]
32
+ generated_text = generated_text[len(EXTENSION):]
33
  return generated_text
34
 
35
  iface = gr.Interface(