chansung commited on
Commit
a1430d2
1 Parent(s): dd486e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -1
app.py CHANGED
@@ -1,9 +1,23 @@
1
  import gradio as gr
2
 
 
 
 
3
  from styles import MODEL_SELECTION_CSS
4
  from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
5
  from templates import templates
6
 
 
 
 
 
 
 
 
 
 
 
 
7
  ex_file = open("examples.txt", "r")
8
  examples = ex_file.read().split("\n")
9
  ex_btns = []
@@ -36,10 +50,28 @@ def fill_up_placeholders(txt):
36
  "" if len(placeholders) >= 1 else txt
37
  )
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
40
  with gr.Column() as chat_view:
41
  idx = gr.State(0)
42
- chat_state = gr.State()
 
 
43
  local_data = gr.JSON({}, visible=False)
44
 
45
  with gr.Row():
@@ -169,4 +201,10 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
169
  visible=False
170
  )
171
 
 
 
 
 
 
 
172
  demo.launch()
 
1
  import gradio as gr
2
 
3
+ from llama2 import GradioLLaMA2ChatPPManager
4
+ from llama2 import gen_text
5
+
6
  from styles import MODEL_SELECTION_CSS
7
  from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
8
  from templates import templates
9
 
10
+ from pingpong.context import CtxLastWindowStrategy
11
+
12
+ TOKEN = os.getenv('HF_TOKEN')
13
+ MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'
14
+
15
+ def build_prompts(ppmanager, global_context, win_size=3):
16
+ dummy_ppm = copy.deepcopy(ppmanager)
17
+ dummy_ppm.ctx = global_context
18
+ lws = CtxLastWindowStrategy(win_size)
19
+ return lws(dummy_ppm)
20
+
21
  ex_file = open("examples.txt", "r")
22
  examples = ex_file.read().split("\n")
23
  ex_btns = []
 
50
  "" if len(placeholders) >= 1 else txt
51
  )
52
 
53
+ async def chat_stream(idx, local_data, instruction_txtbox, chat_state):
54
+ res = [
55
+ chat_state["ppmanager_type"].from_json(json.dumps(ppm))
56
+ for ppm in local_data
57
+ ]
58
+
59
+ ppm = res[idx]
60
+ ppm.add_pingpong(
61
+ PingPong(instruction_txtbox, "")
62
+ )
63
+ prompt = build_prompts(ppm, "global context", 3)
64
+ for result in await gen_text(prompt, hf_model=MODEL_ID, hf_token=TOKEN):
65
+ ppm.append_pong(result)
66
+ yield ppm.build_uis(), str(res)
67
+
68
+
69
  with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
70
  with gr.Column() as chat_view:
71
  idx = gr.State(0)
72
+ chat_state = gr.State({
73
+ "ppmanager_type": GradioLLaMA2ChatPPManager
74
+ })
75
  local_data = gr.JSON({}, visible=False)
76
 
77
  with gr.Row():
 
201
  visible=False
202
  )
203
 
204
+ instruction_txtbox.submit(
205
+ chat_stream,
206
+ [idx, local_data, instruction_txtbox, chat_state]
207
+ [chatbot, local_data]
208
+ )
209
+
210
  demo.launch()