SpiceyToad committed on
Commit
f1cf6cf
1 Parent(s): 3ec8d1c

Upload app.py

Browse files

optimize app.py

Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -23,6 +23,10 @@ model = AutoModelForCausalLM.from_pretrained(
23
  token=HF_API_TOKEN
24
  )
25
 
 
 
 
 
26
  @app.post("/generate")
27
  async def generate_text(request: Request):
28
  data = await request.json()
@@ -31,15 +35,15 @@ async def generate_text(request: Request):
31
 
32
  # Tokenize with padding and attention mask
33
  inputs = tokenizer(
34
- prompt,
35
- return_tensors="pt",
36
- padding=True,
37
  truncation=True
38
  ).to(model.device)
39
 
40
  outputs = model.generate(
41
- inputs["input_ids"],
42
- attention_mask=inputs["attention_mask"],
43
  max_length=max_length
44
  )
45
 
 
23
  token=HF_API_TOKEN
24
  )
25
 
26
+ # Ensure tokenizer has a padding token
27
+ if tokenizer.pad_token is None:
28
+ tokenizer.pad_token = tokenizer.eos_token # Use the EOS token as the padding token
29
+
30
  @app.post("/generate")
31
  async def generate_text(request: Request):
32
  data = await request.json()
 
35
 
36
  # Tokenize with padding and attention mask
37
  inputs = tokenizer(
38
+ prompt,
39
+ return_tensors="pt",
40
+ padding=True,
41
  truncation=True
42
  ).to(model.device)
43
 
44
  outputs = model.generate(
45
+ inputs["input_ids"],
46
+ attention_mask=inputs["attention_mask"],
47
  max_length=max_length
48
  )
49