ldhldh committed on
Commit
6fce119
•
1 Parent(s): 8ba9f95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -13
app.py CHANGED
@@ -5,7 +5,6 @@ from gradio import routes
5
  from typing import List, Type
6
  from petals import AutoDistributedModelForCausalLM
7
  from transformers import AutoTokenizer
8
- import npc_data
9
  import requests, os, re, asyncio, json
10
 
11
  loop = asyncio.get_event_loop()
@@ -31,7 +30,7 @@ routes.get_types = get_types
31
 
32
  # App code
33
 
34
- model_name = "daekeun-ml/Llama-2-ko-instruct-13B"
35
 
36
  #daekeun-ml/Llama-2-ko-instruct-13B
37
  #quantumaikr/llama-2-70b-fb16-korean
@@ -39,6 +38,12 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
39
 
40
  model = None
41
 
 
 
 
 
 
 
42
  def check(model_name):
43
  data = requests.get("https://health.petals.dev/api/v1/state").json()
44
  out = []
@@ -72,23 +77,51 @@ def chat(id, npc, text):
72
 
73
  if check(model_name):
74
 
75
- prom = f"""<s><<SYS>>
76
- {npc_data.system_message}<</SYS>>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- {npc_data.get_story(npc)}
 
 
79
 
80
- {npc}(์ด)๊ฐ€ ํ•  ๋ง์„ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”. ํ•œ ๋ฌธ์žฅ๋งŒ ์ž‘์„ฑํ•˜์„ธ์š”.
 
81
 
82
- [INST]
83
- User:
84
- {user_message}
85
- [/INST]
86
  """
87
 
88
  inputs = tokenizer(prom, return_tensors="pt")["input_ids"]
89
- outputs = model.generate(inputs, max_new_tokens=80)
90
- print(tokenizer.decode(outputs[0]))
91
- output = tokenizer.decode(outputs[0])[len(prom):-1]
92
  print(output)
93
  else:
94
  output = "no model"
 
5
  from typing import List, Type
6
  from petals import AutoDistributedModelForCausalLM
7
  from transformers import AutoTokenizer
 
8
  import requests, os, re, asyncio, json
9
 
10
  loop = asyncio.get_event_loop()
 
30
 
31
  # App code
32
 
33
+ model_name = "petals-team/StableBeluga2"
34
 
35
  #daekeun-ml/Llama-2-ko-instruct-13B
36
  #quantumaikr/llama-2-70b-fb16-korean
 
38
 
39
  model = None
40
 
41
+ history = {
42
+ "":{
43
+
44
+ }
45
+ }
46
+
47
  def check(model_name):
48
  data = requests.get("https://health.petals.dev/api/v1/state").json()
49
  out = []
 
77
 
78
  if check(model_name):
79
 
80
+ global history
81
+ if not npc in npc_story:
82
+ return "no npc"
83
+
84
+ if not npc in history:
85
+ history[npc] = {}
86
+ if not id in history[npc]:
87
+ history[npc][id] = ""
88
+ if len(history[npc][id].split("###")) > 10:
89
+ history[npc][id] = "###" + history[npc][id].split("###", 3)[3]
90
+ npc_list = str([k for k in npc_story.keys()]).replace('\'', '')
91
+ town_story = f"""[{id}์˜ ๋งˆ์„]
92
+ ์™ธ๋”ด ๊ณณ์˜ ์กฐ๊ทธ๋งŒ ์„ฌ์— ์—ฌ๋Ÿฌ ์ฃผ๋ฏผ๋“ค์ด ๋ชจ์—ฌ ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
93
+
94
+ ํ˜„์žฌ {npc_list}์ด ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค."""
95
+
96
+ system_message = f"""1. ๋‹น์‹ ์€ ํ•œ๊ตญ์–ด์— ๋Šฅ์ˆ™ํ•ฉ๋‹ˆ๋‹ค.
97
+ 2. ๋‹น์‹ ์€ ์ง€๊ธˆ ์—ญํ• ๊ทน์„ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. {npc}์˜ ๋ฐ˜์‘์„ ์ƒ์ƒํ•˜๊ณ  ๋งค๋ ฅ์ ์ด๊ฒŒ ํ‘œํ˜„ํ•ฉ๋‹ˆ๋‹ค.
98
+ 3. ๋‹น์‹ ์€ {npc}์ž…๋‹ˆ๋‹ค. {npc}์˜ ์ž…์žฅ์—์„œ ์ƒ๊ฐํ•˜๊ณ  ๋งํ•ฉ๋‹ˆ๋‹ค.
99
+ 4. ์ฃผ์–ด์ง€๋Š” ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๊ฐœ์—ฐ์„ฑ์žˆ๊ณ  ์‹ค๊ฐ๋‚˜๋Š” {npc}์˜ ๋Œ€์‚ฌ๋ฅผ ์™„์„ฑํ•˜์„ธ์š”.
100
+ 5. ์ฃผ์–ด์ง€๋Š” {npc}์˜ ์ •๋ณด๋ฅผ ์‹ ์ค‘ํ•˜๊ฒŒ ์ฝ๊ณ , ๊ณผํ•˜์ง€ ์•Š๊ณ  ๋‹ด๋ฐฑํ•˜๊ฒŒ ์บ๋ฆญํ„ฐ๋ฅผ ์—ฐ๊ธฐํ•˜์„ธ์š”.
101
+ 6. User์˜ ์—ญํ• ์„ ์ ˆ๋Œ€๋กœ ์นจ๋ฒ”ํ•˜์ง€ ๋งˆ์„ธ์š”. ๊ฐ™์€ ๋ง์„ ๋ฐ˜๋ณตํ•˜์ง€ ๋งˆ์„ธ์š”.
102
+ 7. {npc}์˜ ๋งํˆฌ๋ฅผ ์ง€์ผœ์„œ ์ž‘์„ฑํ•˜์„ธ์š”."""
103
+
104
+ prom = f"""<<SYS>>
105
+ {system_message}<</SYS>>
106
+
107
+ {town_story}
108
+
109
+ ### ์บ๋ฆญํ„ฐ ์ •๋ณด: {npc_story[npc]}
110
 
111
+ ### ๋ช…๋ น์–ด:
112
+ {npc}์˜ ์ •๋ณด๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ {npc}์ด ํ•  ๋ง์„ ์ƒํ™ฉ์— ๋งž์ถฐ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”.
113
+ {history[npc][id]}
114
 
115
+ ### User:
116
+ {text}
117
 
118
+ ### {npc}:
 
 
 
119
  """
120
 
121
  inputs = tokenizer(prom, return_tensors="pt")["input_ids"]
122
+ outputs = model.generate(inputs, do_sample=True, temperature=0.6, top_p=0.75, max_new_tokens=100)
123
+ output = tokenizer.decode(outputs[0])[len(prom)+3:-1].split("<")[0].split("###")[0].replace(". ", ".\n")
124
+ print(outputs)
125
  print(output)
126
  else:
127
  output = "no model"