Awlly commited on
Commit
28e3315
·
verified ·
1 Parent(s): a7f98a4

Update app_models/gpt_MODEL.py

Browse files
Files changed (1) hide show
  1. app_models/gpt_MODEL.py +2 -2
app_models/gpt_MODEL.py CHANGED
@@ -10,7 +10,7 @@ model = GPT2LMHeadModel.from_pretrained(model_path)
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  model.to(device)
12
 
13
- def generate_text(prompt_text, length, temperature):
14
  encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
15
  encoded_prompt = encoded_prompt.to(device)
16
 
@@ -22,7 +22,7 @@ def generate_text(prompt_text, length, temperature):
22
  top_p=0.9,
23
  repetition_penalty=1.2,
24
  do_sample=True,
25
- num_return_sequences=1,
26
  )
27
 
28
  # Decode the generated text
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  model.to(device)
12
 
13
+ def generate_text(prompt_text, length, temperature, beams):
14
  encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
15
  encoded_prompt = encoded_prompt.to(device)
16
 
 
22
  top_p=0.9,
23
  repetition_penalty=1.2,
24
  do_sample=True,
25
+ num_return_sequences=beams,
26
  )
27
 
28
  # Decode the generated text