import streamlit as st
import requests
import time
from ast import literal_eval

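# Submit a generation request to the GPT-JT inference endpoint and poll for the
# result. st.cache memoizes the call, so identical settings don't re-query the API.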
@st.cache
def infer(prompt, 
          model_name, 
          max_new_tokens=10, 
          temperature=0.0, 
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n"):

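    # Map the UI-facing model name to the backend model identifier.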
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

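    # Build the job payload in the format expected by the planetd job queue.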
    my_post_dict = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": int(num_completions),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": stop.split(";"),
            "best_of": 1,
            "echo": False,
            "seed": int(seed),
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }
    
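    # Submit the job; the response contains the job id used for polling.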
    job_id = requests.post("https://planetd.shift.ml/jobs", json=my_post_dict).json()['id']
    
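    # Poll the job status once per second, for up to ~100 seconds.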
    for _ in range(100):
        time.sleep(1)
        ret = requests.get(f"https://planetd.shift.ml/job/{job_id}", json={'id': job_id}).json()
        if ret['status'] == 'finished':
            break
    else:
        # The loop exhausted without the job finishing; fail loudly instead of
        # indexing into a payload that may not contain the inference result yet.
        raise TimeoutError(f"Job {job_id} did not finish within 100 seconds")

    return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
    
    
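# Streamlit UI: parameter controls on the left, prompt and output on the right.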
st.title("GPT-JT")

col1, col2 = st.columns([1, 3])

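# Left column: generation parameters.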
with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.0")
    top_p = st.text_input('top_p', "1.0")
    num_completions = st.text_input('num_completions (only the best one will be returned)', "1")
    stop = st.text_input('stop sequences, separated by ";"', repr('\n'))
    seed = st.text_input('seed', "42")

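# Right column: prompt editor, generated output, and submit button.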
with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )
        
    generated_area = st.empty()
    generated_area.markdown("(Generate here)")
    
    button_submit = st.button("Submit")

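    # On click: echo the prompt immediately, run inference, then append the
    # generated text in italics. literal_eval turns the repr-escaped stop
    # string (e.g. '\n') back into the actual characters.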
    if button_submit:
        generated_area.markdown(prompt)
        report_text = infer(
            prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
            num_completions=num_completions, seed=seed, stop=literal_eval(stop),
        )
        generated_area.markdown(prompt + "_" + report_text + "_")