bambadij commited on
Commit
7949d6d
·
verified ·
1 Parent(s): 51ae396

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -62,9 +62,27 @@ async def predict(request: PredictionRequest):
62
  prompt = default_prompt + "\n\n" + request.text
63
  else:
64
  prompt = default_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- # Tokenize l'entrée et créez un attention mask
67
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
68
  input_ids = inputs.input_ids.to(model.device)
69
  attention_mask = inputs.attention_mask.to(model.device)
70
 
@@ -72,9 +90,10 @@ async def predict(request: PredictionRequest):
72
  outputs = model.generate(
73
  input_ids,
74
  attention_mask=attention_mask,
75
- max_length=3000,
76
  do_sample=True
77
  )
 
78
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
79
 
80
  return {"generated_text": generated_text}
 
62
  prompt = default_prompt + "\n\n" + request.text
63
  else:
64
  prompt = default_prompt
65
+ # Assurez-vous que le pad_token est défini
66
+ if tokenizer.pad_token is None:
67
+ tokenizer.pad_token = tokenizer.eos_token
68
+
69
+ # Définir une longueur maximale arbitraire pour la tokenization
70
+ max_length = 1024 # Vous pouvez ajuster cette valeur selon vos besoins
71
+
72
+ # Tokenize l'entrée sans troncation automatique
73
+ inputs = tokenizer(
74
+ prompt,
75
+ return_tensors="pt",
76
+ padding=True,
77
+ truncation=False,
78
+ max_length=None # Pas de longueur maximale pour la tokenization
79
+ )
80
+
81
+ # Tronquer manuellement si nécessaire
82
+ if inputs.input_ids.shape[1] > max_length:
83
+ inputs.input_ids = inputs.input_ids[:, :max_length]
84
+ inputs.attention_mask = inputs.attention_mask[:, :max_length]
85
 
 
 
86
  input_ids = inputs.input_ids.to(model.device)
87
  attention_mask = inputs.attention_mask.to(model.device)
88
 
 
90
  outputs = model.generate(
91
  input_ids,
92
  attention_mask=attention_mask,
93
+ max_length=3000, # Longueur maximale pour la génération
94
  do_sample=True
95
  )
96
+
97
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
98
 
99
  return {"generated_text": generated_text}