Update app.py
Browse files
app.py
CHANGED
@@ -62,9 +62,27 @@ async def predict(request: PredictionRequest):
|
|
62 |
prompt = default_prompt + "\n\n" + request.text
|
63 |
else:
|
64 |
prompt = default_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
# Tokenize l'entrée et créez un attention mask
|
67 |
-
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
|
68 |
input_ids = inputs.input_ids.to(model.device)
|
69 |
attention_mask = inputs.attention_mask.to(model.device)
|
70 |
|
@@ -72,9 +90,10 @@ async def predict(request: PredictionRequest):
|
|
72 |
outputs = model.generate(
|
73 |
input_ids,
|
74 |
attention_mask=attention_mask,
|
75 |
-
max_length=3000,
|
76 |
do_sample=True
|
77 |
)
|
|
|
78 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
79 |
|
80 |
return {"generated_text": generated_text}
|
|
|
62 |
prompt = default_prompt + "\n\n" + request.text
|
63 |
else:
|
64 |
prompt = default_prompt
|
65 |
+
# Assurez-vous que le pad_token est défini
|
66 |
+
if tokenizer.pad_token is None:
|
67 |
+
tokenizer.pad_token = tokenizer.eos_token
|
68 |
+
|
69 |
+
# Définir une longueur maximale arbitraire pour la tokenization
|
70 |
+
max_length = 1024 # Vous pouvez ajuster cette valeur selon vos besoins
|
71 |
+
|
72 |
+
# Tokenize l'entrée sans troncation automatique
|
73 |
+
inputs = tokenizer(
|
74 |
+
prompt,
|
75 |
+
return_tensors="pt",
|
76 |
+
padding=True,
|
77 |
+
truncation=False,
|
78 |
+
max_length=None # Pas de longueur maximale pour la tokenization
|
79 |
+
)
|
80 |
+
|
81 |
+
# Tronquer manuellement si nécessaire
|
82 |
+
if inputs.input_ids.shape[1] > max_length:
|
83 |
+
inputs.input_ids = inputs.input_ids[:, :max_length]
|
84 |
+
inputs.attention_mask = inputs.attention_mask[:, :max_length]
|
85 |
|
|
|
|
|
86 |
input_ids = inputs.input_ids.to(model.device)
|
87 |
attention_mask = inputs.attention_mask.to(model.device)
|
88 |
|
|
|
90 |
outputs = model.generate(
|
91 |
input_ids,
|
92 |
attention_mask=attention_mask,
|
93 |
+
max_length=3000, # Longueur maximale pour la génération
|
94 |
do_sample=True
|
95 |
)
|
96 |
+
|
97 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
98 |
|
99 |
return {"generated_text": generated_text}
|