CogwiseAI committed
Commit
322c3d1
1 Parent(s): e7dc722

Update app.py

Files changed (1):
  1. app.py +255 -2
app.py CHANGED
@@ -1,5 +1,258 @@
  import streamlit as st
+ import uuid
+ import sys
+ import requests
+ from peft import *
+ import bitsandbytes as bnb
  import pandas as pd
- import pickle
+ import torch
+ import torch.nn as nn
+ import transformers
+ from datasets import load_dataset
+ from huggingface_hub import notebook_login
+ from peft import (
+     LoraConfig,
+     PeftConfig,
+     PeftModel,
+     get_peft_model,
+     prepare_model_for_kbit_training,
+ )
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+ )

- model = pickle.load(open('model_saved.pkl', 'rb'))
+
+ USER_ICON = "images/user-icon.png"
+ AI_ICON = "images/ai-icon.png"
+ MAX_HISTORY_LENGTH = 5
+
+ if 'user_id' in st.session_state:
+     user_id = st.session_state['user_id']
+ else:
+     user_id = str(uuid.uuid4())
+     st.session_state['user_id'] = user_id
+
+ if 'chat_history' not in st.session_state:
+     st.session_state['chat_history'] = []
+
+ if "chats" not in st.session_state:
+     st.session_state.chats = [
+         {
+             'id': 0,
+             'question': '',
+             'answer': ''
+         }
+     ]
+
+ if "questions" not in st.session_state:
+     st.session_state.questions = []
+
+ if "answers" not in st.session_state:
+     st.session_state.answers = []
+
+ if "input" not in st.session_state:
+     st.session_state.input = ""
+
+ st.markdown("""
+ <style>
+ .block-container {
+     padding-top: 32px;
+     padding-bottom: 32px;
+     padding-left: 0;
+     padding-right: 0;
+ }
+ .element-container img {
+     background-color: #000000;
+ }
+
+ .main-header {
+     font-size: 24px;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ def write_top_bar():
+     col1, col2, col3 = st.columns([1, 10, 2])
+     with col1:
+         st.image(AI_ICON, use_column_width='always')
+     with col2:
+         header = "Cogwise Intelligent Assistant"
+         st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
+     with col3:
+         clear = st.button("Clear Chat")
+     return clear
+
+ clear = write_top_bar()
+
+ if clear:
+     st.session_state.questions = []
+     st.session_state.answers = []
+     st.session_state.input = ""
+     st.session_state["chat_history"] = []
+
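+ # handle_input is wired up as the on_change callback of the text_input at the
+ # bottom of the file: it records the question, generates an answer, and
+ # clears the input box.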
+ def handle_input():
+     input = st.session_state.input
+     question_with_id = {
+         'question': input,
+         'id': len(st.session_state.questions)
+     }
+     st.session_state.questions.append(question_with_id)
+
+     chat_history = st.session_state["chat_history"]
+     if len(chat_history) == MAX_HISTORY_LENGTH:
+         # Drop the oldest turn in place so the bound applies to the stored list.
+         chat_history.pop(0)
+
+     # api_url = "https://9pl792yjf9.execute-api.us-east-1.amazonaws.com/beta/chatcogwise"
+     # api_request_data = {"question": input, "session": user_id}
+     # api_response = requests.post(api_url, json=api_request_data)
+     # result = api_response.json()
+
+     # answer = result['answer']
+     # !pip install -Uqqq pip --progress-bar off
+     # !pip install -qqq bitsandbytes==0.39.0
+     # !pip install -qqq torch==2.0.1 --progress-bar off
+     # !pip install -qqq -U git+https://github.com/huggingface/transformers.git@e03a9cc --progress-bar off
+     # !pip install -qqq -U git+https://github.com/huggingface/peft.git@42a184f --progress-bar off
+     # !pip install -qqq -U git+https://github.com/huggingface/accelerate.git@c9fbb71 --progress-bar off
+     # !pip install -qqq datasets==2.12.0 --progress-bar off
+     # !pip install -qqq loralib==0.1.1 --progress-bar off
+     # !pip install einops
+
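+     # (The pinned installs above are leftovers from the training notebook;
+     # they are comments here and are never executed by the app.)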
+     import os
+     # from pprint import pprint
+     # import json
+
+     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+     # notebook_login()
+     # 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
+
+     # """### Load dataset"""
+
+     from datasets import load_dataset
+
+     dataset_name = "nisaar/Lawyer_GPT_India"
+     # dataset_name = "patrick11434/TEST_LLM_DATASET"
+     dataset = load_dataset(dataset_name, split="train")
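+     # NOTE: `dataset` is loaded but never referenced again below.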
+
+     # """## Load adapters from the Hub
+     # You can also directly load adapters from the Hub using the commands below:
+     # """
+
+     # change peft_model_id
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     )
+
+     peft_model_id = "nisaar/falcon7b-Indian_Law_150Prompts"
+     config = PeftConfig.from_pretrained(peft_model_id)
+     model = AutoModelForCausalLM.from_pretrained(
+         config.base_model_name_or_path,
+         return_dict=True,
+         quantization_config=bnb_config,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+     tokenizer.pad_token = tokenizer.eos_token
+
+     model = PeftModel.from_pretrained(model, peft_model_id)
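+     # `model` is now the 4-bit NF4 Falcon-7B base with the LoRA adapter from
+     # "nisaar/falcon7b-Indian_Law_150Prompts" attached on top.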
+
+     # """## Inference
+     # You can then directly use the trained model or the model that you have
+     # loaded from the 🤗 Hub for inference as you would do it usually in `transformers`.
+     # """
+
+     generation_config = model.generation_config
+     generation_config.max_new_tokens = 200
+     generation_config.temperature = 1
+     generation_config.top_p = 0.7
+     generation_config.num_return_sequences = 1
+     generation_config.pad_token_id = tokenizer.eos_token_id
+     generation_config.eos_token_id = tokenizer.eos_token_id
+
+     DEVICE = "cuda:0"
+
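+     # NOTE: unless the checkpoint's generation config enables sampling
+     # (do_sample=True), transformers decodes greedily and the temperature
+     # and top_p values above have no effect.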
+     # Commented out IPython magic to ensure Python compatibility.
+     # %%time
+     # prompt = f"""
+     # <human>: Who appoints the Chief Justice of India?
+     # <assistant>:
+     # """.strip()
+     #
+     # encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+     # with torch.inference_mode():
+     #     outputs = model.generate(
+     #         input_ids=encoding.input_ids,
+     #         generation_config=generation_config,
+     #     )
+     # print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+     def generate_response(question: str) -> str:
+         prompt = f"""
+ <human>: {question}
+ <assistant>:
+ """.strip()
+         encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+         with torch.inference_mode():
+             outputs = model.generate(
+                 input_ids=encoding.input_ids,
+                 attention_mask=encoding.attention_mask,
+                 generation_config=generation_config,
+             )
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         assistant_start = '<assistant>:'
+         response_start = response.find(assistant_start)
+         return response[response_start + len(assistant_start):].strip()
+
+     # prompt = "Debate the merits and demerits of introducing simultaneous elections in India?"
+     prompt = input
+     answer = generate_response(prompt)
+     print(answer)
+
+     # answer = 'Yes'
+     chat_history.append((input, answer))
+
+     st.session_state.answers.append({
+         'answer': answer,
+         'id': len(st.session_state.questions)
+     })
+     st.session_state.input = ""
+
+ def write_user_message(md):
+     col1, col2 = st.columns([1, 12])
+     with col1:
+         st.image(USER_ICON, use_column_width='always')
+     with col2:
+         st.warning(md['question'])
+
+ def render_answer(answer):
+     col1, col2 = st.columns([1, 12])
+     with col1:
+         st.image(AI_ICON, use_column_width='always')
+     with col2:
+         st.info(answer)
+
+ def write_chat_message(md, q):
+     chat = st.container()
+     with chat:
+         render_answer(md['answer'])
+
+ with st.container():
+     for (q, a) in zip(st.session_state.questions, st.session_state.answers):
+         write_user_message(q)
+         write_chat_message(a, q)
+
+ st.markdown('---')
+ input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
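
One refactor worth flagging: as committed, handle_input rebuilds the quantized base model, tokenizer, and adapter on every single message. A minimal sketch of hoisting the load behind Streamlit's resource cache, assuming the same model IDs as the diff; `load_model` is a hypothetical helper name, and `st.cache_resource` is the stable caching API in Streamlit 1.18+ (older releases expose the same behavior as `st.experimental_singleton`):

    import streamlit as st
    import torch
    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    @st.cache_resource  # one load per server process, shared across reruns and sessions
    def load_model(peft_model_id: str = "nisaar/falcon7b-Indian_Law_150Prompts"):
        # Same 4-bit NF4 quantization settings as in the commit above.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        config = PeftConfig.from_pretrained(peft_model_id)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        # Attach the LoRA adapter and hand both pieces back to the app.
        return PeftModel.from_pretrained(model, peft_model_id), tokenizer

    model, tokenizer = load_model()

With this in place, handle_input would call load_model() instead of re-running the whole pipeline, so only the first question pays the load cost. As before, the app is launched with `streamlit run app.py`.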