Alfasign commited on
Commit
a499ce9
·
1 Parent(s): 58e2c4e
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import openai
3
+ import re
4
+ import csv
5
+ import base64
6
+ from io import StringIO
7
+ import threading
8
+ from queue import Queue
9
+
10
# --- Page title and sidebar configuration ----------------------------------
st.title("EinfachChatProjekt")

# API key is typed into the sidebar and applied to the openai module globally.
api_key = st.sidebar.text_input("API Key:", value="sk-")
openai.api_key = api_key

# BUG FIX: the default was the string "TRUE", which is merely truthy;
# st.checkbox expects a real bool for its default value.
show_notes = st.sidebar.checkbox("Show Notes", value=True)
data_section = st.sidebar.text_area("CSV or Text Data:")
paste_data = st.sidebar.button("Paste Data")
num_concurrent_calls = st.sidebar.number_input("Concurrent Calls:", min_value=1, max_value=2000, value=50, step=1)
generate_all = st.sidebar.button("Generate All")
reset = st.sidebar.button("Reset")
add_row = st.sidebar.button("Add row")

# Model / sampling parameters forwarded verbatim to the chat completion call.
model = st.sidebar.selectbox("Model:", ["gpt-4", "gpt-3.5-turbo"])
temperature = st.sidebar.slider("Temperature:", 0.0, 1.0, 0.6, step=0.01)
max_tokens = st.sidebar.number_input("Max Tokens:", min_value=1, max_value=8192, value=2000, step=1)
top_p = st.sidebar.slider("Top P:", 0.0, 1.0, 1.0, step=0.01)
system_message = st.sidebar.text_area("System Message:")

# Number of message rows currently shown; persisted across reruns.
row_count = st.session_state.get("row_count", 1)

# "Add row" appends one more message row and persists the new count.
if add_row:
    row_count += 1
    st.session_state.row_count = row_count
32
+
33
# "Paste Data": split the pasted text into one entry per line.
# With notes enabled, lines are consumed as alternating note/message pairs.
if paste_data:
    data = StringIO(data_section.strip())
    # NOTE(review): delimiter='\n' makes each physical line one csv row;
    # quotechar allows quoted multi-line values — verify that is intended.
    reader = csv.reader(data, delimiter='\n', quotechar='"')
    # BUG FIX: blank lines parse to empty rows, so row[0] raised IndexError;
    # skip empty rows.
    messages = [row[0] for row in reader if row]
    if show_notes:
        # Pairs of lines: even index = note, odd index = message.
        row_count = len(messages) // 2
        for i in range(row_count):
            st.session_state[f"note{i}"] = messages[i * 2]
            st.session_state[f"message{i}"] = messages[i * 2 + 1]
    else:
        row_count = len(messages)
        for i, message in enumerate(messages):
            st.session_state[f"message{i}"] = message
    st.session_state.row_count = row_count
47
+
48
# "Reset": collapse back to a single empty row and wipe all per-row state.
if reset:
    # BUG FIX: clearing a fixed 100 rows left stale session state behind
    # whenever a paste had created more rows than that; clear up to the
    # actual previous count (keeping 100 as a floor for compatibility).
    prev_count = st.session_state.get("row_count", row_count)
    row_count = 1
    st.session_state.row_count = row_count
    for i in range(max(prev_count, 100)):
        st.session_state[f"note{i}"] = ""
        st.session_state[f"message{i}"] = ""
        st.session_state[f"response{i}"] = ""
        st.session_state[f"prompt_tokens{i}"] = 0
        st.session_state[f"response_tokens{i}"] = 0
        st.session_state[f"word_count{i}"] = 0
58
+
59
def generate_response(i, message):
    """Run one chat completion for row *i*.

    Returns a 5-tuple ``(i, response_text, prompt_tokens, response_tokens,
    word_count)``. On any exception the error text is returned in the
    response slot with zeroed counters, so callers never need try/except.
    """
    try:
        completion = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": message},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        text = completion.choices[0].message.content
        used_prompt = completion.usage['prompt_tokens']
        # Response tokens = total minus prompt (usage has no separate field
        # in this API shape).
        used_response = completion.usage['total_tokens'] - used_prompt
        words = len(re.findall(r'\w+', text))
        return (i, text, used_prompt, used_response, words)
    except Exception as exc:
        # Best-effort: surface the error message as the "response".
        return (i, str(exc), 0, 0, 0)
81
+
82
def worker(q, results):
    """Drain jobs from *q* until a ``None`` sentinel, pushing each
    generate_response result onto *results*.

    NOTE(review): appears unused — "Generate All" spawns WorkerThread
    instances instead, and this name is later shadowed by the loop
    variable in ``for worker in workers``.
    """
    while True:
        job = q.get()
        if job is None:
            break
        results.put(generate_response(*job))
85
+
86
class WorkerThread(threading.Thread):
    """Daemon thread that pulls ``(index, message)`` jobs from
    *input_queue*, runs generate_response on each, and pushes the result
    tuple onto *output_queue*.

    ``task_done()`` is always signalled — even when the call raises — so
    ``Queue.join()`` on the input queue cannot deadlock.
    """

    def __init__(self, input_queue, output_queue):
        super().__init__()
        self.input_queue = input_queue
        self.output_queue = output_queue
        # Daemonize so lingering workers never block process exit.
        self.daemon = True

    def run(self):
        while True:
            job_index, job_message = self.input_queue.get()
            try:
                self.output_queue.put(generate_response(job_index, job_message))
            finally:
                # Unblock input_queue.join() even if the call blew up.
                self.input_queue.task_done()
101
+
102
# "Generate All": fan every row out across a pool of worker threads.
if generate_all:
    jobs = Queue()
    results = Queue()

    workers = [WorkerThread(jobs, results) for _ in range(num_concurrent_calls)]

    # BUG FIX: the loop variable was named `worker`, shadowing the
    # module-level worker() function with a thread object.
    for worker_thread in workers:
        worker_thread.start()

    for i in range(row_count):
        message = st.session_state.get(f"message{i}", "")
        jobs.put((i, message))

    # Block until every job has been task_done()'d by a worker thread.
    jobs.join()

    # Each worker puts its result before calling task_done(), so once
    # join() returns the results queue is complete; draining with empty()
    # here is race-free.
    while not results.empty():
        i, response, prompt_tokens, response_tokens, word_count = results.get()
        st.session_state[f"response{i}"] = response
        st.session_state[f"prompt_tokens{i}"] = prompt_tokens
        st.session_state[f"response_tokens{i}"] = response_tokens
        st.session_state[f"word_count{i}"] = word_count
123
+
124
def create_download_link(text, filename):
    """Return an HTML ``<a>`` tag that downloads *text* as *filename*
    via an inline base64 ``data:`` URI (no server round-trip needed)."""
    b64 = base64.b64encode(text.encode()).decode()
    # BUG FIX: the download name and link label ignored the `filename`
    # argument; use it so the caller-chosen name actually applies.
    href = f'<a href="data:file/txt;base64,{b64}" download="{filename}">Download {filename}</a>'
    return href
128
+
129
# Render one note/message/response row per index.
for i in range(row_count):
    if show_notes:
        st.text_input(f"Note {i + 1}:", key=f"note{i}", value=st.session_state.get(f"note{i}", ""))
    col1, col2 = st.columns(2)

    with col1:
        message = st.text_area(f"Message {i + 1}:", key=f"message{i}", value=st.session_state.get(f"message{i}", ""))

        # Only generate when no response is stored yet for this row.
        if st.button(f"Generate Response {i + 1}") and not st.session_state.get(f"response{i}", ""):
            # BUG FIX: generate_response returns a 5-tuple whose first item
            # is the row index; the old 4-name unpack raised ValueError.
            _, response, prompt_tokens, response_tokens, word_count = generate_response(i, message)
            st.session_state[f"response{i}"] = response
            st.session_state[f"prompt_tokens{i}"] = prompt_tokens
            st.session_state[f"response_tokens{i}"] = response_tokens
            st.session_state[f"word_count{i}"] = word_count

    with col2:
        st.text_area(f"Response {i + 1}:", value=st.session_state.get(f"response{i}", ""))
        st.write(f"Tokens: {st.session_state.get(f'prompt_tokens{i}', 0)} / {st.session_state.get(f'response_tokens{i}', 0)} + Words: {st.session_state.get(f'word_count{i}', 0)}")
147
+
148
# Assemble every response (prefixed by its note when notes are shown)
# into one text document and render it as a download link.
if show_notes:
    sections = [
        f"{st.session_state.get(f'note{i}', '')}\n{st.session_state.get(f'response{i}', '')}"
        for i in range(row_count)
    ]
else:
    sections = [st.session_state.get(f"response{i}", "") for i in range(row_count)]
responses_text = "\n\n".join(sections)
download_filename = "GPT-4 Responses.txt"
download_link = create_download_link(responses_text, download_filename)
st.markdown(download_link, unsafe_allow_html=True)