xzyao commited on
Commit
0a53800
1 Parent(s): 2ab9fff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -31
app.py CHANGED
@@ -1,8 +1,10 @@
1
 
 
2
  import streamlit as st
3
  import requests
4
  import time
5
  from ast import literal_eval
 
6
 
7
  def to_md(text):
8
  return text.replace("\n", "<br />")
@@ -38,41 +40,23 @@ def infer(
38
  temperature = 0.01
39
 
40
  my_post_dict = {
41
- "type": "general",
42
- "payload": {
43
- "max_tokens": int(max_new_tokens),
44
- "n": int(num_completions),
45
- "temperature": float(temperature),
46
- "top_p": float(top_p),
47
- "model": model_name_map[model_name],
48
- "prompt": [prompt],
49
- "request_type": "language-model-inference",
50
- "stop": stop,
51
- "best_of": 1,
52
- "echo": False,
53
- "seed": int(seed),
54
- "prompt_embedding": False,
55
- },
56
- "returned_payload": {},
57
- "status": "submitted",
58
- "source": "dalle",
59
  }
60
-
61
- job_id = requests.post("https://planetd.shift.ml/jobs", json=my_post_dict).json()['id']
62
-
63
- for i in range(100):
64
-
65
- time.sleep(0.5)
66
- ret = requests.get(f"https://planetd.shift.ml/job/{job_id}", json={'id': job_id}).json()
67
- if ret['status'] == 'finished':
68
- break
69
-
70
- generated_text = ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
71
 
72
  for stop_word in stop:
73
  if stop_word in generated_text:
74
  generated_text = generated_text[:generated_text.find(stop_word)]
75
-
76
  st.session_state.updated = True
77
 
78
  return generated_text
@@ -120,7 +104,7 @@ def main():
120
  st.session_state.prompt = "Please answer the following question:\n\nQuestion: In which country is Zurich located?\nAnswer:"
121
 
122
  if 'temperature' not in st.session_state:
123
- st.session_state.temperature = "0.1"
124
 
125
  if 'top_p' not in st.session_state:
126
  st.session_state.top_p = "1.0"
 
1
 
2
+
3
  import streamlit as st
4
  import requests
5
  import time
6
  from ast import literal_eval
7
+ from datetime import datetime
8
 
9
  def to_md(text):
10
  return text.replace("\n", "<br />")
 
40
  temperature = 0.01
41
 
42
  my_post_dict = {
43
+ "model": "Together-gpt-JT-6B-v1",
44
+ "prompt": prompt,
45
+ "top_p": top_p,
46
+ "top_k": top_k,
47
+ "temperature": temperature,
48
+ "max_tokens": max_new_tokens,
49
+ "stop": stop,
 
 
 
 
 
 
 
 
 
 
 
50
  }
51
+ print(f"send: {datetime.now()}")
52
+ response = requests.get("https://staging.together.xyz/api/inference", params=my_post_dict).json()
53
+ generated_text = response['output']['choices'][0]['text']
54
+ print(f"recv: {datetime.now()}")
 
 
 
 
 
 
 
55
 
56
  for stop_word in stop:
57
  if stop_word in generated_text:
58
  generated_text = generated_text[:generated_text.find(stop_word)]
59
+
60
  st.session_state.updated = True
61
 
62
  return generated_text
 
104
  st.session_state.prompt = "Please answer the following question:\n\nQuestion: In which country is Zurich located?\nAnswer:"
105
 
106
  if 'temperature' not in st.session_state:
107
+ st.session_state.temperature = "0.8"
108
 
109
  if 'top_p' not in st.session_state:
110
  st.session_state.top_p = "1.0"