ValdeciRodrigues commited on
Commit
a5a92fd
·
verified ·
1 Parent(s): 9b90156

Update logic/generator.py

Browse files
Files changed (1) hide show
  1. logic/generator.py +20 -3
logic/generator.py CHANGED
@@ -2,7 +2,10 @@
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
 
5
  model_id = "stabilityai/stable-code-3b"
 
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_id)
7
  model = AutoModelForCausalLM.from_pretrained(
8
  model_id,
@@ -12,17 +15,31 @@ model = AutoModelForCausalLM.from_pretrained(
12
 
13
  def generate_code(prompt):
14
  try:
 
15
  formatted_prompt = f"# Escreva um código Python que faça o seguinte:\n# {prompt}\n"
16
- inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
17
  outputs = model.generate(
18
  **inputs,
19
- max_new_tokens=256,
20
  do_sample=True,
21
  temperature=0.3,
22
  top_k=50,
23
- top_p=0.95
 
24
  )
 
 
25
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
26
  return result.strip()
 
27
  except Exception as e:
28
  return f"Erro ao gerar código: {str(e)}"
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
+ # Modelo poderoso, mas exige cuidado com limite de tokens
6
  model_id = "stabilityai/stable-code-3b"
7
+
8
+ # Carregamento otimizado
9
  tokenizer = AutoTokenizer.from_pretrained(model_id)
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_id,
 
15
 
16
  def generate_code(prompt):
17
  try:
18
+ # Instrução formatada para guiar o modelo
19
  formatted_prompt = f"# Escreva um código Python que faça o seguinte:\n# {prompt}\n"
20
+
21
+ # Tokenização com truncamento seguro para evitar overflow
22
+ inputs = tokenizer(
23
+ formatted_prompt,
24
+ return_tensors="pt",
25
+ truncation=True,
26
+ max_length=512 # entrada limitada para evitar travamento
27
+ ).to(model.device)
28
+
29
+ # Geração com finalização forçada via EOS token
30
  outputs = model.generate(
31
  **inputs,
32
+ max_new_tokens=256, # reduzido para caber nos limites de GPU
33
  do_sample=True,
34
  temperature=0.3,
35
  top_k=50,
36
+ top_p=0.95,
37
+ eos_token_id=tokenizer.eos_token_id # 🚨 Essencial para evitar loop eterno
38
  )
39
+
40
+ # Decodifica e retorna o texto limpo
41
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
42
  return result.strip()
43
+
44
  except Exception as e:
45
  return f"Erro ao gerar código: {str(e)}"