AleRive commited on
Commit
8e62ac4
·
verified ·
1 Parent(s): 4d867ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -14,7 +14,7 @@ app = FastAPI()
14
  os.makedirs("/tmp/huggingface", exist_ok=True)
15
 
16
  # Carica il modello Hugging Face
17
- model_name = "microsoft/DialoGPT-small"
18
  tokenizer = AutoTokenizer.from_pretrained(model_name)
19
  model = AutoModelForCausalLM.from_pretrained(model_name)
20
 
@@ -31,7 +31,12 @@ async def chat(request: Request):
31
 
32
  # Tokenizzazione e generazione della risposta
33
  inputs = tokenizer(prompt, return_tensors="pt")
34
- outputs = model.generate(inputs["input_ids"], max_length=50)
 
 
 
 
 
35
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
 
37
  return JSONResponse({"response": response})
 
14
  os.makedirs("/tmp/huggingface", exist_ok=True)
15
 
16
  # Carica il modello Hugging Face
17
+ model_name = "microsoft/DialoGPT-medium"
18
  tokenizer = AutoTokenizer.from_pretrained(model_name)
19
  model = AutoModelForCausalLM.from_pretrained(model_name)
20
 
 
31
 
32
  # Tokenizzazione e generazione della risposta
33
  inputs = tokenizer(prompt, return_tensors="pt")
34
+ outputs = model.generate(
35
+ inputs["input_ids"],
36
+ max_length=50,
37
+ pad_token_id=tokenizer.eos_token_id, # Aggiunto per evitare warning
38
+ attention_mask=inputs["attention_mask"] # Aggiunto per maggiore stabilità
39
+ )
40
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
 
42
  return JSONResponse({"response": response})