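"""Streamlit chat app ("Turing Test") backed by an OpenAI-compatible vLLM endpoint hosted on Modal."""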
import streamlit as st
import os
import openai

def get_completion(client, model_id, messages, args):
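    """Call the chat completions endpoint with the given messages and sampling args.

    Returns the API response object, or None if the request fails.
    """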
    completion_args = {
        "model": model_id,
        "messages": messages,
        "frequency_penalty": args.frequency_penalty,
        "max_tokens": args.max_tokens,
        "n": args.n,
        "presence_penalty": args.presence_penalty,
        "seed": args.seed,
        "stop": args.stop,
        "stream": args.stream,
        "temperature": args.temperature,
        "top_p": args.top_p,
    }

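    # Drop any arguments left as None so only explicitly set values are sent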
    completion_args = {
        k: v for k, v in completion_args.items() if v is not None
    }

    try:
        response = client.chat.completions.create(**completion_args)
        return response
    except Exception as e:
        print(f"Error during API call: {e}")
        return None

# Page configuration
st.set_page_config(page_title="Turing Test")

# OpenAI Client Setup
with st.sidebar:
    st.title('🦙💬 Welcome to Turing Test')
    
    # Hardcoded API key
    openai_api_key = "super-secret-token"
    # st.success('API key already provided!', icon='✅')

    os.environ['OPENAI_API_KEY'] = openai_api_key

    # OpenAI-compatible vLLM endpoint served on Modal
    api_base = "https://turingtest--example-vllm-openai-compatible-serve.modal.run/v1"

    client = openai.OpenAI(api_key=openai_api_key, base_url=api_base)

    # Add system prompt input
    st.subheader('System Prompt')
    system_prompt = st.text_area("Enter a system prompt:",
                                 "you are roleplaying as an old grandma",
                                 help="This message sets the behavior of the AI.")
    st.subheader('Models and parameters')
    selected_model = st.sidebar.selectbox('Choose a model', ['meta-llama/Meta-Llama-3.1-8B-Instruct'], key='selected_model')
    temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.8, step=0.1) 
    top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.95, step=0.01)
    max_length = st.sidebar.slider('max_length', min_value=32, max_value=1024, value=32, step=8)
    

# Store chat history
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "system", "content": system_prompt},
        {"role": "assistant", "content": "Hello!"}
    ]
    
# Display chat messages (excluding system message)
for message in st.session_state.messages[1:]:
    with st.chat_message(message["role"]):
        st.write(message["content"])

def clear_chat_history():
    st.session_state.messages = [
        {"role": "system", "content": system_prompt},
        {"role": "assistant", "content": "Hello!"}
    ]
st.sidebar.button('Clear Chat History', on_click=clear_chat_history)

# Function for generating a model response using the OpenAI client API
def generate_llama2_response(prompt_input, model, temperature, top_p, max_length):
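    """Send the full chat history (including the updated system prompt) to the model and return the reply text."""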
    
    class Args:
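        """Container for the sampling parameters expected by get_completion."""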
        def __init__(self):
            self.frequency_penalty = 0
            self.max_tokens = max_length
            self.n = 1 
            self.presence_penalty = 0
            self.seed = 42
            self.stop = None
            self.stream = False
            self.temperature = temperature
            self.top_p = top_p

    args = Args()
    
    # Update system message before each completion
    st.session_state.messages[0] = {"role": "system", "content": system_prompt}
    
    response = get_completion(client, model, st.session_state.messages, args)
    
    if response:
        return response.choices[0].message.content
    else:
        return "Sorry, there was an error generating a response."
        
# User-provided prompt 
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)
        
# Generate response if last message is not from assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = generate_llama2_response(prompt, selected_model, temperature, top_p, max_length)
            # stream=False returns the full text at once, so render it directly
            placeholder = st.empty()
            placeholder.markdown(response)
    message = {"role": "assistant", "content": response}
    st.session_state.messages.append(message)