Spaces:
Sleeping
Sleeping
Vinitha3699
commited on
Commit
•
4c095e1
1
Parent(s):
5791a97
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import mysql.connector
|
3 |
+
import bcrypt
|
4 |
+
import datetime
|
5 |
+
import re
|
6 |
+
import json
|
7 |
+
import torch
|
8 |
+
import pytz
|
9 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
10 |
+
|
11 |
+
# STreamlit:
|
12 |
+
|
13 |
+
col1, col2 = st.columns([1, 2])
|
14 |
+
|
15 |
+
with col1:
|
16 |
+
st.image("Screenshot 2024-06-29 154050.png", width=150)
|
17 |
+
|
18 |
+
with col2:
|
19 |
+
st.markdown('''<h1 style="color:#B22222;">Guvi-GPT Text Generator!</h1>''', unsafe_allow_html=True)
|
20 |
+
|
21 |
+
# Sql Connection:
|
22 |
+
|
23 |
+
connection = mysql.connector.connect(
|
24 |
+
host = "gateway01.ap-southeast-1.prod.aws.tidbcloud.com",
|
25 |
+
port = 4000,
|
26 |
+
user = "2ZyEJgr7d9Z3zUr.root",
|
27 |
+
password = "Q970rGTYSeCyFYk2")
|
28 |
+
|
29 |
+
mycursor = connection.cursor(buffered=True)
|
30 |
+
|
31 |
+
mycursor.execute("CREATE DATABASE IF NOT EXISTS Guvi_GPT")
|
32 |
+
mycursor.execute('USE Guvi_GPT')
|
33 |
+
|
34 |
+
mycursor.execute('''CREATE TABLE IF NOT EXISTS User_data
|
35 |
+
(id INT AUTO_INCREMENT PRIMARY KEY,
|
36 |
+
username VARCHAR(50) UNIQUE NOT NULL,
|
37 |
+
password VARCHAR(255) NOT NULL,
|
38 |
+
email VARCHAR(255) UNIQUE NOT NULL,
|
39 |
+
registered_date TIMESTAMP,
|
40 |
+
last_login TIMESTAMP)''')
|
41 |
+
|
42 |
+
def username_exists(username):
|
43 |
+
mycursor.execute("SELECT * FROM User_data WHERE username = %s", (username,))
|
44 |
+
return mycursor.fetchone() is not None
|
45 |
+
|
46 |
+
def email_exists(email):
|
47 |
+
mycursor.execute("SELECT * FROM User_data WHERE email = %s", (email,))
|
48 |
+
return mycursor.fetchone() is not None
|
49 |
+
|
50 |
+
def is_valid_email(email):
|
51 |
+
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
52 |
+
return re.match(pattern, email) is not None
|
53 |
+
|
54 |
+
def create_user(username, password, email, registered_date):
|
55 |
+
if username_exists(username):
|
56 |
+
return 'username_exists'
|
57 |
+
|
58 |
+
if email_exists(email):
|
59 |
+
return 'email_exists'
|
60 |
+
|
61 |
+
hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
|
62 |
+
mycursor.execute(
|
63 |
+
"INSERT INTO User_data (username, password, email, registered_date) VALUES (%s, %s, %s, %s)",
|
64 |
+
(username, hashed_password, email, registered_date)
|
65 |
+
)
|
66 |
+
connection.commit()
|
67 |
+
return 'success'
|
68 |
+
|
69 |
+
def verify_user(username, password):
|
70 |
+
mycursor.execute("SELECT password FROM User_data WHERE username = %s", (username,))
|
71 |
+
record = mycursor.fetchone()
|
72 |
+
if record and bcrypt.checkpw(password.encode('utf-8'), record[0].encode('utf-8')):
|
73 |
+
mycursor.execute("UPDATE User_data SET last_login = %s WHERE username = %s", (datetime.datetime.now(pytz.timezone('Asia/Kolkata')), username))
|
74 |
+
connection.commit()
|
75 |
+
return True
|
76 |
+
return False
|
77 |
+
|
78 |
+
def reset_password(username, new_password):
|
79 |
+
hashed_password = bcrypt.hashpw(new_password.encode('utf-8'), bcrypt.gensalt())
|
80 |
+
mycursor.execute(
|
81 |
+
"UPDATE User_data SET password = %s WHERE username = %s",
|
82 |
+
(hashed_password, username)
|
83 |
+
)
|
84 |
+
connection.commit()
|
85 |
+
|
86 |
+
# Load the fine-tuned model and tokenizer
|
87 |
+
model_name_or_path = "fine_tuned_model"
|
88 |
+
model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
|
89 |
+
|
90 |
+
token_name_or_path = "fine_tuned_model"
|
91 |
+
tokenizer = GPT2Tokenizer.from_pretrained(token_name_or_path)
|
92 |
+
|
93 |
+
# Set the pad_token to eos_token if it's not already set
|
94 |
+
if tokenizer.pad_token is None:
|
95 |
+
tokenizer.pad_token = tokenizer.eos_token
|
96 |
+
|
97 |
+
# Move the model to GPU if available
|
98 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
99 |
+
model.to(device)
|
100 |
+
|
101 |
+
# Define the text generation function
|
102 |
+
def generate_text(model, tokenizer, seed_text, max_length=100, temperature=1.0, num_return_sequences=1):
|
103 |
+
# Tokenize the input text with padding
|
104 |
+
inputs = tokenizer(seed_text, return_tensors='pt', padding=True, truncation=True)
|
105 |
+
|
106 |
+
input_ids = inputs['input_ids'].to(device)
|
107 |
+
attention_mask = inputs['attention_mask'].to(device)
|
108 |
+
|
109 |
+
# Generate text
|
110 |
+
with torch.no_grad():
|
111 |
+
output = model.generate(
|
112 |
+
input_ids,
|
113 |
+
attention_mask=attention_mask,
|
114 |
+
max_length=max_length,
|
115 |
+
temperature=temperature,
|
116 |
+
num_return_sequences=num_return_sequences,
|
117 |
+
do_sample=True,
|
118 |
+
top_k=50,
|
119 |
+
top_p=0.01,
|
120 |
+
pad_token_id=tokenizer.eos_token_id # Ensure padding token is set to eos_token_id
|
121 |
+
)
|
122 |
+
|
123 |
+
# Decode the generated text
|
124 |
+
generated_texts = []
|
125 |
+
for i in range(num_return_sequences):
|
126 |
+
generated_text = tokenizer.decode(output[i], skip_special_tokens=True)
|
127 |
+
generated_texts.append(generated_text)
|
128 |
+
|
129 |
+
return generated_texts
|
130 |
+
|
131 |
+
# Session state management
|
132 |
+
if 'sign_up_successful' not in st.session_state:
|
133 |
+
st.session_state.sign_up_successful = False
|
134 |
+
if 'login_successful' not in st.session_state:
|
135 |
+
st.session_state.login_successful = False
|
136 |
+
if 'reset_password' not in st.session_state:
|
137 |
+
st.session_state.reset_password = False
|
138 |
+
if 'username' not in st.session_state:
|
139 |
+
st.session_state.username = ''
|
140 |
+
if 'current_page' not in st.session_state:
|
141 |
+
st.session_state.current_page = 'login'
|
142 |
+
|
143 |
+
def home_page():
|
144 |
+
st.markdown(f"""# <span style="color:#006400">Welcome, {st.session_state.username}!</span>""", unsafe_allow_html=True)
|
145 |
+
st.info('''**Disclaimer:** This application utilizes a GPT-based model for generating responses. While generating responses, it may produce errors or inaccuracies. So, Users are encouraged to verify any critical information independently.
|
146 |
+
**Most importantly, This model is not affiliated with or endorsed by Official GUVI EdTech Company.**''')
|
147 |
+
|
148 |
+
# Text generation section
|
149 |
+
seed_text = st.text_input("**Enter text:**")
|
150 |
+
max_length = st.number_input("**Length of Words**", min_value=10, max_value=500, value=100)
|
151 |
+
|
152 |
+
if st.button("Generate"):
|
153 |
+
with st.spinner("Generating..."):
|
154 |
+
generated_texts = generate_text(model, tokenizer, seed_text, max_length, temperature=0.000001, num_return_sequences=1)
|
155 |
+
for i, text in enumerate(generated_texts):
|
156 |
+
st.write(f"**Text Generated :**\n{text}\n")
|
157 |
+
|
158 |
+
|
159 |
+
def login():
|
160 |
+
st.subheader(':green[**Login**]')
|
161 |
+
with st.form(key='login', clear_on_submit=True):
|
162 |
+
username = st.text_input(label='Username', placeholder='Enter Username')
|
163 |
+
password = st.text_input(label='Password', placeholder='Enter Password', type='password')
|
164 |
+
if st.form_submit_button('**Login**'):
|
165 |
+
if not username or not password:
|
166 |
+
st.error("Enter all the Credentials")
|
167 |
+
elif verify_user(username, password):
|
168 |
+
st.session_state.login_successful = True
|
169 |
+
st.session_state.username = username
|
170 |
+
st.session_state.current_page = 'home'
|
171 |
+
st.rerun()
|
172 |
+
else:
|
173 |
+
st.error("Incorrect Username or Password. Sign Up!, if you don't have an account")
|
174 |
+
if not st.session_state.login_successful:
|
175 |
+
c1, c2 = st.columns(2)
|
176 |
+
with c1:
|
177 |
+
st.write(":red[New user?]")
|
178 |
+
if st.button('**Sign Up**'):
|
179 |
+
st.session_state.current_page = 'sign_up'
|
180 |
+
st.rerun()
|
181 |
+
with c2:
|
182 |
+
st.write(":red[Forgot Password?]")
|
183 |
+
if st.button('**Reset Password**'):
|
184 |
+
st.session_state.current_page = 'reset_password'
|
185 |
+
st.rerun()
|
186 |
+
|
187 |
+
def signup():
|
188 |
+
st.subheader(':red[**Sign Up**]')
|
189 |
+
with st.form(key='signup', clear_on_submit=True):
|
190 |
+
email = st.text_input(label='Email', placeholder='Enter Your Email')
|
191 |
+
username = st.text_input(label='Username', placeholder='Enter Your Username')
|
192 |
+
password = st.text_input(label='Password', placeholder='Enter Your Password', type='password')
|
193 |
+
re_password = st.text_input(label='Confirm Password', placeholder='Confirm Your Password', type='password')
|
194 |
+
registered_date = datetime.datetime.now(pytz.timezone('Asia/Kolkata'))
|
195 |
+
|
196 |
+
if st.form_submit_button('**Sign Up**'):
|
197 |
+
if not email or not username or not password or not re_password:
|
198 |
+
st.error("Enter all the Credentials")
|
199 |
+
elif not is_valid_email(email):
|
200 |
+
st.error("Enter a valid email address")
|
201 |
+
elif len(password) <= 3:
|
202 |
+
st.error("Password too short")
|
203 |
+
elif password != re_password:
|
204 |
+
st.error("Passwords do not match! Please Re-enter")
|
205 |
+
else:
|
206 |
+
result = create_user(username, password, email, registered_date)
|
207 |
+
if result == 'username_exists':
|
208 |
+
st.error("Username already registered! Retry Login.")
|
209 |
+
elif result == 'email_exists':
|
210 |
+
st.error("Email already registered. Retry Login.")
|
211 |
+
elif result == 'success':
|
212 |
+
st.success(f"Username {username} has been successfully created! Kindly login.")
|
213 |
+
st.session_state.sign_up_successful = True
|
214 |
+
else:
|
215 |
+
st.error("Failed to create user. Try again later.")
|
216 |
+
|
217 |
+
if st.session_state.sign_up_successful:
|
218 |
+
if st.button('**Login**'):
|
219 |
+
st.session_state.current_page = 'login'
|
220 |
+
st.rerun()
|
221 |
+
|
222 |
+
def reset_password_page():
|
223 |
+
st.subheader(':bee[Reset Password]')
|
224 |
+
with st.form(key='reset_password', clear_on_submit=True):
|
225 |
+
username = st.text_input(label='Username', value='')
|
226 |
+
new_password = st.text_input(label='New Password', type='password')
|
227 |
+
re_password = st.text_input(label='Confirm New Password', type='password')
|
228 |
+
|
229 |
+
if st.form_submit_button('Reset Password'):
|
230 |
+
if not username:
|
231 |
+
st.error("Enter your username.")
|
232 |
+
elif not username_exists(username):
|
233 |
+
st.error("Username not found. Enter a valid username")
|
234 |
+
elif not new_password or not re_password:
|
235 |
+
st.error("Enter all the Credentials.")
|
236 |
+
elif len(new_password) <= 3:
|
237 |
+
st.error("Password too short")
|
238 |
+
elif new_password != re_password:
|
239 |
+
st.error("Passwords do not match. Please Re-enter")
|
240 |
+
else:
|
241 |
+
reset_password(username, new_password)
|
242 |
+
st.success("Password has been Resetted successfully! Kindly Login")
|
243 |
+
st.session_state.current_page = 'login'
|
244 |
+
|
245 |
+
st.write('**Get back to Login page!**')
|
246 |
+
if st.button('Login'):
|
247 |
+
st.session_state.current_page = 'login'
|
248 |
+
st.rerun()
|
249 |
+
|
250 |
+
|
251 |
+
# Display appropriate page based on session state
|
252 |
+
if st.session_state.current_page == 'home':
|
253 |
+
home_page()
|
254 |
+
elif st.session_state.current_page == 'login':
|
255 |
+
login()
|
256 |
+
elif st.session_state.current_page == 'sign_up':
|
257 |
+
signup()
|
258 |
+
elif st.session_state.current_page == 'reset_password':
|
259 |
+
reset_password_page()
|