KYO30 committed on
Commit
67bb651
·
verified ·
1 Parent(s): 88adda9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -4,26 +4,30 @@ import torch
4
  from threading import Thread
5
 
6
  # --- 1. λͺ¨λΈ λ‘œλ“œ (Space의 GPU ν™œμš©) ---
7
- # μš”μ²­ν•˜μ‹  λͺ¨λΈ μ΄λ¦„μž…λ‹ˆλ‹€.
8
- MODEL_NAME = "kakaocorp/kanana-1.5-2.1b-instruct-2505"
9
 
10
  print(f"λͺ¨λΈμ„ λ‘œλ”© μ€‘μž…λ‹ˆλ‹€: {MODEL_NAME} (Space GPU μ‚¬μš©)")
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_NAME,
14
- dtype=torch.float16, # πŸ’₯ μˆ˜μ •: 'torch_dtype' λŒ€μ‹  'dtype' μ‚¬μš©
15
- device_map="auto"
16
  )
17
  print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ!")
18
 
19
  # --- 2. 챗봇 응닡 ν•¨μˆ˜ (Gradioκ°€ 이 ν•¨μˆ˜λ₯Ό 호좜) ---
 
 
20
  def predict(message, history):
21
 
22
  # Kanana의 ν”„λ‘¬ν”„νŠΈ ν˜•μ‹: <bos>user\n{prompt}\n<eos>assistant\n
23
  history_prompt = ""
 
24
  for user_msg, assistant_msg in history:
25
  history_prompt += f"<bos>user\n{user_msg}\n<eos>assistant\n{assistant_msg}\n"
26
 
 
27
  final_prompt = history_prompt + f"<bos>user\n{message}\n<eos>assistant\n"
28
 
29
  inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
@@ -31,35 +35,37 @@ def predict(message, history):
31
  # --- μ‹€μ‹œκ°„ 타이핑 효과(슀트리밍)λ₯Ό μœ„ν•œ μ„€μ • ---
32
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
33
 
 
34
  generation_kwargs = dict(
35
- inputs,
36
  streamer=streamer,
37
- max_new_tokens=1024,
38
  eos_token_id=tokenizer.eos_token_id,
39
  pad_token_id=tokenizer.pad_token_id,
40
- temperature=0.7,
41
- do_sample=True
42
  )
43
 
44
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
45
  thread.start()
46
 
 
47
  generated_text = ""
48
  for new_text in streamer:
49
  generated_text += new_text
50
- yield generated_text
51
 
52
  # --- 3. Gradio 챗봇 UI 생성 ---
53
- # πŸ’₯ μˆ˜μ •: 였λ₯˜κ°€ λ°œμƒν•œ 'retry_btn'κ³Ό 'undo_btn' 인자λ₯Ό μ œκ±°ν–ˆμŠ΅λ‹ˆλ‹€.
54
  chatbot_ui = gr.ChatInterface(
55
  fn=predict, # 챗봇이 μ‚¬μš©ν•  ν•¨μˆ˜
56
  title="Kanana 1.5 챗봇 ν…ŒμŠ€νŠΈ πŸ€–",
57
  description=f"{MODEL_NAME} λͺ¨λΈμ„ ν…ŒμŠ€νŠΈν•©λ‹ˆλ‹€.",
58
- theme="soft",
59
- examples=[["ν•œκ΅­μ˜ μˆ˜λ„λŠ” μ–΄λ””μ•Ό?"], ["AI에 λŒ€ν•΄ 3μ€„λ‘œ μš”μ•½ν•΄μ€˜."]],
60
- # retry_btn=None, <-- 이 뢀뢄이 였λ₯˜ 원인 (제거)
61
- # undo_btn="이전 λŒ€ν™” μ‚­μ œ", <-- 이 뢀뢄도 μ΅œμ‹  버전에선 이름이 λ‹€λ₯Ό 수 μžˆμ–΄ 제거
62
- clear_btn="전체 λŒ€ν™” μ΄ˆκΈ°ν™”" # 'clear_btn'은 아직 μœ νš¨ν•©λ‹ˆλ‹€.
63
  )
64
 
65
- # ---
 
 
 
4
  from threading import Thread
5
 
6
  # --- 1. λͺ¨λΈ λ‘œλ“œ (Space의 GPU ν™œμš©) ---
7
+ # 2505 λͺ¨λΈμ€ 아직 μ‘΄μž¬ν•˜μ§€ μ•Šμ•„, ν˜„μž¬ μ΅œμ‹  λͺ¨λΈμΈ 2405둜 μˆ˜μ •ν–ˆμŠ΅λ‹ˆλ‹€.
8
+ MODEL_NAME = "kakaocorp/kanana-1.5-2.1b-instruct-2405"
9
 
10
  print(f"λͺ¨λΈμ„ λ‘œλ”© μ€‘μž…λ‹ˆλ‹€: {MODEL_NAME} (Space GPU μ‚¬μš©)")
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_NAME,
14
+ torch_dtype=torch.float16, # λ©”λͺ¨λ¦¬ μ ˆμ•½μ„ μœ„ν•΄ 16λΉ„νŠΈ μ‚¬μš©
15
+ device_map="auto" # μ€‘μš”: μ•Œμ•„μ„œ GPU에 ν• λ‹Ή
16
  )
17
  print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ!")
18
 
19
  # --- 2. 챗봇 응닡 ν•¨μˆ˜ (Gradioκ°€ 이 ν•¨μˆ˜λ₯Ό 호좜) ---
20
+ # message: μ‚¬μš©μžκ°€ μž…λ ₯ν•œ λ©”μ‹œμ§€
21
+ # history: 이전 λŒ€ν™” 기둝 (Gradioκ°€ μžλ™μœΌλ‘œ 관리)
22
  def predict(message, history):
23
 
24
  # Kanana의 ν”„λ‘¬ν”„νŠΈ ν˜•μ‹: <bos>user\n{prompt}\n<eos>assistant\n
25
  history_prompt = ""
26
+ # 이전 λŒ€ν™” 기둝(history)을 Kanana ν”„λ‘¬ν”„νŠΈ ν˜•μ‹μœΌλ‘œ λ³€ν™˜
27
  for user_msg, assistant_msg in history:
28
  history_prompt += f"<bos>user\n{user_msg}\n<eos>assistant\n{assistant_msg}\n"
29
 
30
+ # ν˜„μž¬ λ©”μ‹œμ§€λ₯Ό ν”„λ‘¬ν”„νŠΈμ— μΆ”κ°€
31
  final_prompt = history_prompt + f"<bos>user\n{message}\n<eos>assistant\n"
32
 
33
  inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
 
35
  # --- μ‹€μ‹œκ°„ 타이핑 효과(슀트리밍)λ₯Ό μœ„ν•œ μ„€μ • ---
36
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
37
 
38
+ # λͺ¨λΈ 생성(generate) μž‘μ—…μ„ 별도 μŠ€λ ˆλ“œμ—μ„œ μ‹€ν–‰
39
  generation_kwargs = dict(
40
+ **inputs, # inputs λ”•μ…”λ„ˆλ¦¬μ˜ λͺ¨λ“  ν‚€-κ°’ μŒμ„ 인자둜 전달
41
  streamer=streamer,
42
+ max_new_tokens=1024, # μ΅œλŒ€ 생성 토큰 수
43
  eos_token_id=tokenizer.eos_token_id,
44
  pad_token_id=tokenizer.pad_token_id,
45
+ temperature=0.7, # μ°½μ˜μ„± 쑰절
46
+ do_sample=True # μƒ˜ν”Œλ§ μ‚¬μš©
47
  )
48
 
49
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
50
  thread.start()
51
 
52
+ # μŠ€νŠΈλ¦¬λ¨Έμ—μ„œ λ‚˜μ˜€λŠ” ν…μŠ€νŠΈλ₯Ό λ°”λ‘œλ°”λ‘œ λ°˜ν™˜ (yield)
53
  generated_text = ""
54
  for new_text in streamer:
55
  generated_text += new_text
56
+ yield generated_text # ν…μŠ€νŠΈλ₯Ό ν•œ κΈ€μžμ”© μ‹€μ‹œκ°„μœΌλ‘œ 보냄
57
 
58
  # --- 3. Gradio 챗봇 UI 생성 ---
59
+ # gr.ChatInterfaceλ₯Ό μ“°λ©΄ UIκ°€ 챗봇 ν˜•νƒœλ‘œ μžλ™ μƒμ„±λ©λ‹ˆλ‹€.
60
  chatbot_ui = gr.ChatInterface(
61
  fn=predict, # 챗봇이 μ‚¬μš©ν•  ν•¨μˆ˜
62
  title="Kanana 1.5 챗봇 ν…ŒμŠ€νŠΈ πŸ€–",
63
  description=f"{MODEL_NAME} λͺ¨λΈμ„ ν…ŒμŠ€νŠΈν•©λ‹ˆλ‹€.",
64
+ theme="soft", # ν…Œλ§ˆ μ„€μ •
65
+ examples=[["ν•œκ΅­μ˜ μˆ˜λ„λŠ” μ–΄λ””μ•Ό?"], ["AI에 λŒ€ν•΄ 3μ€„λ‘œ μš”μ•½ν•΄μ€˜."]]
66
+ # retry_btn, undo_btn, clear_btn νŒŒλΌλ―Έν„°λŠ” ν˜„μž¬ Gradio λ²„μ „μ—μ„œ μ§€μ›λ˜μ§€ μ•Šμ•„ μ‚­μ œν–ˆμŠ΅λ‹ˆλ‹€.
 
 
67
  )
68
 
69
+ # --- 4. μ•± μ‹€ν–‰ ---
70
+ # .launch()둜 Spaceμ—μ„œ 앱을 μ‹€ν–‰μ‹œν‚΅λ‹ˆλ‹€.
71
+ chatbot_ui.launch()