csabakecskemeti committed on
Commit
669d949
1 Parent(s): d91af6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -8
app.py CHANGED
@@ -1,19 +1,45 @@
1
  import gradio as gr
2
- from llama_cpp import Llama
3
  import os
4
- import llama_cpp_python as llama
5
-
6
-
7
 
8
  sbc_host_url = os.environ['URL']
9
 
10
- # Create a llama client
11
- client = llama.Client(host="sbc_host_url", port=3344)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  def chatty(prompt, messages):
15
- response = client.call("generate_text", prompt=prompt)
16
- return response
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  demo = gr.ChatInterface(
 
1
  import gradio as gr
 
2
  import os
3
+ import requests
4
+ import json
 
5
 
6
  sbc_host_url = os.environ['URL']
7
 
8
def get_completion(prompt: str, messages: str = '', n_predict: int = 128) -> str:
    """Request a completion from the llama.cpp server behind `sbc_host_url`.

    Wraps `prompt` (plus any prior conversation in `messages`) in the
    HUMAN/ASSISTANT template the model was tuned on and POSTs it to the
    server's completion endpoint.

    Args:
        prompt: The new user message.
        messages: Already-templated conversation history to prepend ('' for none).
        n_predict: Maximum number of tokens the server should generate.

    Returns:
        The generated text (the `content` field of the server's JSON reply).

    Raises:
        requests.HTTPError: If the server answers with a non-2xx status.
        requests.Timeout: If the server does not answer within the timeout.
    """
    prompt_templated = f'{messages}\n ### HUMAN:\n{prompt} \n ### ASSISTANT:'
    payload = {
        "prompt": prompt_templated,
        "n_predict": n_predict,
        # Stop sequences keep the model from writing the next HUMAN turn itself.
        "stop": ["### HUMAN:", "### ASSISTANT:", "HUMAN"],
        # Must be a real boolean, not the string "True"; streaming must stay off
        # because we parse the reply as a single JSON object below.
        "stream": False,
    }

    # `json=` serializes the payload and sets the Content-Type header for us;
    # the timeout prevents a hung server from blocking the UI forever.
    response = requests.post(sbc_host_url, json=payload, timeout=120)
    response.raise_for_status()
    return response.json()['content']
26
 
27
 
28
def chatty(prompt, messages):
    """Gradio ChatInterface callback: answer `prompt` given the chat history.

    Args:
        prompt: The user's latest message.
        messages: Gradio-style history — a list of (user, assistant) pairs.

    Returns:
        The assistant's reply text (everything after the final
        '### ASSISTANT:' marker in the model output).
    """
    # Re-serialize the (user, assistant) pairs into the HUMAN/ASSISTANT
    # template expected by get_completion. Empty history yields ''.
    past_messages = ''.join(
        f'\n### HUMAN: {user}\n### ASSISTANT: {assistant}'
        for user, assistant in messages
    )
    # NOTE: kept in a new variable instead of rebinding the `messages`
    # parameter, which the original code shadowed confusingly.
    completion = get_completion(prompt, past_messages)
    return completion.split('### ASSISTANT:')[-1]
43
 
44
 
45
  demo = gr.ChatInterface(