Daeyongkwon98 committed on
Commit
32006fa
•
1 Parent(s): ed12022

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -18,14 +18,14 @@ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float
18
 
19
  # 생성 설정 (Gradio UI에서 제어할 수 있는 변수들)
20
  default_generation_config = GenerationConfig(
21
- temperature=0.7,
22
- top_k=50,
23
- top_p=0.95,
24
  do_sample=True,
25
  num_beams=1,
26
  repetition_penalty=1.1,
27
  min_new_tokens=10,
28
- max_new_tokens=512
29
  )
30
 
31
  # 응답 생성 함수
@@ -47,7 +47,12 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
47
 
48
  # λͺ¨λΈ μž…λ ₯ 생성
49
  inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
50
- response_ids = model.generate(**inputs, generation_config=generation_config)
 
 
 
 
 
51
 
52
  # λͺ¨λΈ 응닡 λ””μ½”λ”©
53
  response_text = tokenizer.decode(response_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
@@ -64,9 +69,9 @@ demo = gr.ChatInterface(
64
  respond,
65
  additional_inputs=[
66
  gr.Textbox(value="You are a friendly Chatbot that recommends music.", label="System message"),
67
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
68
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
69
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
70
  ],
71
  )
72
 
 
18
 
19
  # 생성 설정 (Gradio UI에서 제어할 수 있는 변수들)
20
  default_generation_config = GenerationConfig(
21
+ temperature=0.1,
22
+ top_k=30,
23
+ top_p=0.5,
24
  do_sample=True,
25
  num_beams=1,
26
  repetition_penalty=1.1,
27
  min_new_tokens=10,
28
+ max_new_tokens=30
29
  )
30
 
31
  # 응답 생성 함수
 
47
 
48
  # λͺ¨λΈ μž…λ ₯ 생성
49
  inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
50
+ response_ids = model.generate(
51
+ **inputs,
52
+ generation_config=generation_config,
53
+ eos_token_id=tokenizer.eos_token_id, # 종료 토큰 설정
54
+ pad_token_id=tokenizer.eos_token_id # pad_token_id도 종료 토큰으로 설정
55
+ )
56
 
57
  # λͺ¨λΈ 응닡 λ””μ½”λ”©
58
  response_text = tokenizer.decode(response_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
 
69
  respond,
70
  additional_inputs=[
71
  gr.Textbox(value="You are a friendly Chatbot that recommends music.", label="System message"),
72
+ gr.Slider(minimum=1, maximum=2048, value=30, step=1, label="Max new tokens"),
73
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
74
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Top-p (nucleus sampling)"),
75
  ],
76
  )
77