Jinpkk committed on
Commit
a7b7788
1 Parent(s): b0f7366

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -12,21 +12,30 @@ model_name_or_path = "Jinpkk/codeparrot-ds"
12
  tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
13
  model = TFGPT2LMHeadModel.from_pretrained(model_name_or_path)
14
 
 
15
  def generate_response(input_text):
16
  input_ids = tokenizer.encode(input_text, return_tensors='tf')
17
- beam_output = model.generate(input_ids, max_length=150, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
18
- output = tokenizer.decode(beam_output[0], skip_special_tokens=True)
19
-
20
- # Format the output as needed
21
- formatted_output = output.replace(input_text, "").strip() # Removing input text from output
22
- generated_text = formatted_output.split('[PAD]')[0].strip()
23
- return formatted_output
 
 
 
 
 
 
24
 
25
 
26
  interface = gr.Interface(
27
  fn=generate_response,
 
 
28
  description=description,
29
- title = title,
30
  examples=examples
31
  )
32
 
 
12
  tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
13
  model = TFGPT2LMHeadModel.from_pretrained(model_name_or_path)
14
 
15
+
16
  def generate_response(input_text):
17
  input_ids = tokenizer.encode(input_text, return_tensors='tf')
18
+ beam_output = model.generate(
19
+ input_ids,
20
+ max_length=128,
21
+ num_beams=5,
22
+ no_repeat_ngram_size=2,
23
+ early_stopping=True,
24
+ pad_token_id=tokenizer.eos_token_id
25
+ )
26
+ generated_text = tokenizer.decode(beam_output[0], skip_special_tokens=True)
27
+ generated_text = generated_text.split('[PAD]')[0].strip()
28
+ generated_text = generated_text.replace(input_text.strip(), '').strip()
29
+ return generated_text
30
+
31
 
32
 
33
  interface = gr.Interface(
34
  fn=generate_response,
35
+ inputs=gr.inputs.Textbox(lines=5, placeholder="Enter your text..."),
36
+ outputs=gr.outputs.Textbox(),
37
  description=description,
38
+ title=title,
39
  examples=examples
40
  )
41