Update app_models/gpt_MODEL.py
app_models/gpt_MODEL.py CHANGED (+2 -2)
@@ -10,7 +10,7 @@ model = GPT2LMHeadModel.from_pretrained(model_path)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-def generate_text(prompt_text, length, temperature):
+def generate_text(prompt_text, length, temperature, beams):
     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(device)
 
@@ -22,7 +22,7 @@ def generate_text(prompt_text, length, temperature):
         top_p=0.9,
         repetition_penalty=1.2,
         do_sample=True,
-        num_return_sequences=
+        num_return_sequences=beams,
     )
 
     # Decode the generated text
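For context, here is a minimal sketch of what the full generate_text could look like after this commit. The diff only shows the lines around the two edits, so everything else below is an assumption rather than code from this repo: the tokenizer class, the placeholder model_path, how length and temperature are wired into model.generate, and the decoding step.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_path = "gpt2"  # placeholder; the repo's actual model_path is not shown in the diff
tokenizer = GPT2Tokenizer.from_pretrained(model_path)  # assumed tokenizer class
model = GPT2LMHeadModel.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def generate_text(prompt_text, length, temperature, beams):
    # Encode the prompt and move it to the same device as the model
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(device)

    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=length,        # assumed wiring of the `length` argument
        temperature=temperature,  # assumed wiring of the `temperature` argument
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True,
        num_return_sequences=beams,
    )

    # Decode the generated text (assumed; the decoding lines are not in the diff)
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in output_sequences]

One note on the naming: because do_sample=True is set and num_beams is not, num_return_sequences=beams makes transformers return that many independently sampled continuations rather than beam-search beams, so a caller passing beams=3 gets three alternative completions of the same prompt.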