|
|
|
import streamlit as st |
|
import mysql.connector |
|
import bcrypt |
|
import datetime |
|
import re |
|
import pytz |
|
import time |
|
|
|
|
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
import torch |
|
|
|
|
|
# Remote image used as the browser-tab icon.
icon = 'https://media.istockphoto.com/id/1413286466/vector/chat-bot-icon-robot-virtual-assistant-bot-vector-illustration.jpg?s=612x612&w=0&k=20&c=ZSG3eqGPDJgIgFUIuVxID64uVUF3eqM3LrrDWtaKses='

# Page title, icon and the text shown under the app menu's "about" entry.
st.set_page_config(
    page_title='GUVI - GPT',
    page_icon=icon,
    menu_items={"about": 'This streamlit application was developed by S.S.SHARMILA'},
)
|
|
|
|
|
# Connection settings. The host and password values are placeholders exactly
# as they appear in this file.
_DB_SETTINGS = {
    'host': 'host url',
    'port': 4000,
    'user': "4TZNT9dUm5s8BMa.root",
    'password': 'password',
}
mydb = mysql.connector.connect(**_DB_SETTINGS)

# One buffered cursor, used by every query helper in this module.
mycursor = mydb.cursor(buffered=True)

# Schema bootstrap, executed in order: create the database, select it, then
# create the users table if it is missing.
_SCHEMA_STATEMENTS = (
    "CREATE DATABASE IF NOT EXISTS GUVI_DB",
    'USE GUVI_DB',
    '''CREATE TABLE IF NOT EXISTS users (
        id INT AUTO_INCREMENT PRIMARY KEY,
        username VARCHAR(50) UNIQUE NOT NULL,
        password VARCHAR(255) NOT NULL,
        email VARCHAR(255) UNIQUE NOT NULL,
        registered_date TIMESTAMP,
        last_login TIMESTAMP
    );''',
)
for _statement in _SCHEMA_STATEMENTS:
    mycursor.execute(_statement)
|
|
|
|
|
def username_exists(username):
    """Return True when the users table has a row matching ``username``."""
    lookup = "SELECT * FROM users WHERE username = %s"
    mycursor.execute(lookup, (username,))
    row = mycursor.fetchone()
    return row is not None
|
|
|
|
|
def email_exists(email):
    """Return True when the users table has a row matching ``email``."""
    lookup = "SELECT * FROM users WHERE email = %s"
    mycursor.execute(lookup, (email,))
    row = mycursor.fetchone()
    return row is not None
|
|
|
|
|
def is_valid_email(email):
    """Return True when ``email`` looks like ``local@domain.tld``.

    The local part may contain letters, digits and ``._%+-``; the domain may
    contain letters, digits, dots and hyphens and must end in a dot followed
    by at least two letters.

    Args:
        email: Candidate address. Anything that is not a ``str`` is rejected.

    Returns:
        bool: True if the whole string matches the pattern, False otherwise.
    """
    if not isinstance(email, str):
        return False

    # fullmatch (rather than match + "$") so a trailing newline is rejected:
    # "$" is satisfied just before a final "\n".
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    return re.fullmatch(pattern, email) is not None
|
|
|
|
|
def create_user(username, password, email):
    """Insert a new user row with a bcrypt-hashed password.

    Args:
        username: Desired username.
        password: Plain-text password; only its bcrypt hash is stored.
        email: Email address for the account.

    Returns:
        str: ``'username_exists'`` or ``'email_exists'`` when the value is
        already taken, ``'success'`` once the row is committed, and
        ``'error'`` when the database rejects the insert.
    """
    if username_exists(username):
        return 'username_exists'

    if email_exists(email):
        return 'email_exists'

    # Keep the hash as text: the password column is VARCHAR and verify_user
    # reads the stored value back as a string.
    hashed_password = bcrypt.hashpw(
        password.encode('utf-8'), bcrypt.gensalt()
    ).decode('utf-8')
    registered_date = datetime.datetime.now(pytz.timezone('Asia/Kolkata'))

    try:
        mycursor.execute(
            "INSERT INTO users (username, password, email, registered_date) VALUES (%s, %s, %s, %s)",
            (username, hashed_password, email, registered_date)
        )
        mydb.commit()
    except mysql.connector.Error:
        # The existence checks above are not atomic with the insert, so a
        # concurrent sign-up (or any other rejected insert, e.g. an over-long
        # value) surfaces here. Undo the statement and report a generic
        # failure instead of letting the exception escape to the page.
        mydb.rollback()
        return 'error'

    return 'success'
|
|
|
|
|
def verify_user(username, password):
    """Check a username/password pair and record the login time on success.

    Args:
        username: Username to look up.
        password: Plain-text password to check against the stored bcrypt hash.

    Returns:
        bool: True when the password matches (``last_login`` is then updated
        and committed); False when the user is unknown, the password is
        wrong, or the stored hash cannot be used.
    """
    mycursor.execute("SELECT password FROM users WHERE username = %s", (username,))
    record = mycursor.fetchone()
    if record is None:
        return False

    # The stored hash normally comes back as str; accept bytes/bytearray too
    # so a driver that returns raw bytes does not break the check.
    stored_hash = record[0]
    if isinstance(stored_hash, str):
        stored_hash = stored_hash.encode('utf-8')
    elif isinstance(stored_hash, (bytes, bytearray)):
        stored_hash = bytes(stored_hash)
    else:
        return False

    try:
        password_ok = bcrypt.checkpw(password.encode('utf-8'), stored_hash)
    except ValueError:
        # bcrypt raises ValueError for a value it cannot parse as a hash;
        # treat that as a failed login rather than crashing the page.
        password_ok = False

    if not password_ok:
        return False

    mycursor.execute(
        "UPDATE users SET last_login = %s WHERE username = %s",
        (datetime.datetime.now(pytz.timezone('Asia/Kolkata')), username)
    )
    mydb.commit()
    return True
|
|
|
|
|
def reset_password(username, new_password):
    """Replace the stored password hash for ``username``.

    Args:
        username: Account whose password is being replaced.
        new_password: New plain-text password; only its bcrypt hash is stored.

    Returns:
        bool: True when the UPDATE affected at least one row, False when no
        row matched ``username``.
    """
    # Keep the hash as text: the password column is VARCHAR and verify_user
    # reads the stored value back as a string.
    hashed_password = bcrypt.hashpw(
        new_password.encode('utf-8'), bcrypt.gensalt()
    ).decode('utf-8')

    mycursor.execute(
        "UPDATE users SET password = %s WHERE username = %s",
        (hashed_password, username)
    )
    mydb.commit()

    # Report whether anything changed so a no-op is not silently a success.
    return mycursor.rowcount > 0
|
|
|
|
|
# Initial values for the session keys this app reads; a key is only written
# when the session does not have it yet.
_SESSION_DEFAULTS = {
    'sign_up_successful': False,
    'login_successful': False,
    'reset_password': False,
    'username': '',
    'current_page': 'login',
}
for _name, _initial in _SESSION_DEFAULTS.items():
    if _name not in st.session_state:
        st.session_state[_name] = _initial
|
|
|
|
|
|
|
def login():
    """Render the login form and the links to the sign-up and reset pages.

    A successful login stores the username in the session, switches the
    current page to ``'home'`` and reruns the script.
    """
    with st.form(key='login', clear_on_submit=True):
        st.subheader(':blue[**Login**]')
        st.write("Enter your username and password below.")

        username = st.text_input(label='Username', placeholder='Enter Your Username')
        password = st.text_input(label='Password', placeholder='Enter Your Password', type='password')

        submitted = st.form_submit_button('Login')
        if submitted:
            if not (username and password):
                st.error("Please fill out all fields.")
            elif verify_user(username, password):
                st.session_state.login_successful = True
                st.session_state.username = username
                st.session_state.current_page = 'home'
                st.rerun()
            else:
                st.error("Incorrect username or password. If you don't have an account, please sign up.")

    # The navigation buttons are only offered while nobody is logged in.
    if st.session_state.login_successful:
        return

    signup_col, reset_col = st.columns(2)
    with signup_col:
        st.write(":red[New user?]")
        if st.button('Sign Up'):
            st.session_state.current_page = 'sign_up'
            st.rerun()
    with reset_col:
        st.write(":red[Forgot Password?]")
        if st.button('Reset Password'):
            st.session_state.current_page = 'reset_password'
            st.rerun()
|
|
|
|
|
|
|
def signup():
    """Render the sign-up form and create the account when the input is valid.

    Validation runs in order: all fields present, email format, password
    longer than three characters, both passwords equal. Once an account has
    been created in this session, a button back to the login page is shown.
    """
    with st.form(key='signup', clear_on_submit=True):
        st.subheader(':blue[**Sign Up**]')
        st.write("Enter the required fields to create a new account.")

        email = st.text_input(label='Email', placeholder='Enter Your Email')
        username = st.text_input(label='Username', placeholder='Enter Your Username')
        password = st.text_input(label='Password', placeholder='Enter Your Password', type='password')
        re_password = st.text_input(label='Confirm Password', placeholder='Confirm Your Password', type='password')

        if st.form_submit_button('Sign Up'):
            if not all((email, username, password, re_password)):
                st.error("Please fill out all fields.")
            elif not is_valid_email(email):
                st.error("Please enter a valid email address.")
            elif len(password) < 4:
                st.error("Password too short")
            elif re_password != password:
                st.error("Passwords do not match. Please re-enter.")
            else:
                outcome = create_user(username, password, email)
                if outcome == 'success':
                    st.success(f"Username {username} created successfully! Please login.")
                    st.session_state.sign_up_successful = True
                else:
                    # Known failure codes get a specific message; anything
                    # else falls back to the generic one.
                    known_failures = {
                        'username_exists': "Username already registered. Please use a different username.",
                        'email_exists': "Email already registered. Please use a different email.",
                    }
                    st.error(known_failures.get(outcome, "Failed to create user. Please try again later."))

    if st.session_state.sign_up_successful:
        if st.button('Go to Login'):
            st.session_state.current_page = 'login'
            st.rerun()
|
|
|
|
|
|
|
def reset_password_page():
    """Render the reset-password form and a button back to the login page.

    On a valid submission the password is replaced and the current page is
    set to ``'login'``; the switch takes effect on the next script run.
    """
    # NOTE(review): the only check made before replacing a password is that
    # the username exists; nothing here confirms the requester owns the
    # account. Consider requiring a second factor (e.g. the registered email).
    with st.form(key='reset_password', clear_on_submit=True):
        st.subheader(':blue[Reset Password]')
        st.write("Enter your username and new password below.")

        username = st.text_input(label='Username', value='')
        new_password = st.text_input(label='New Password', type='password')
        re_password = st.text_input(label='Confirm New Password', type='password')

        submitted = st.form_submit_button('Reset Password')
        if submitted:
            if not username:
                st.error("Please enter your username.")
            elif not username_exists(username):
                st.error("Username not found. Please enter a valid username.")
            elif not (new_password and re_password):
                st.error("Please fill out all fields.")
            elif len(new_password) < 4:
                st.error("Password too short")
            elif re_password != new_password:
                st.error("Passwords do not match. Please re-enter.")
            else:
                reset_password(username, new_password)
                st.success("Password reset successfully. Please login with your new password.")
                st.session_state.current_page = 'login'

    st.write('Return to Login page')
    if st.button('Login'):
        st.session_state.current_page = 'login'
        st.rerun()
|
|
|
|
|
|
|
|
|
# Banner markup for the home page: an <h2> whose text is filled with an
# animated colour gradient. The embedded stylesheet also sets body-level
# layout and colours.
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {
font-family: Arial, sans-serif;
display: flex;
align-items: center;
height: 100vh;
margin: 0;
background-color: #0E1117;
color: #ffffff;
}
.title {
font-size: 3rem;
}
.animated {
display: inline-block;
background: linear-gradient(90deg, #ff5733, #33ff57, #3357ff, #ff33a1);
background-size: 400% 400%;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
animation: gradient 8s ease infinite;
}
@keyframes gradient {
0% {
background-position: 0% 50%;
}
50% {
background-position: 100% 50%;
}
100% {
background-position: 0% 50%;
}
}
</style>
</head>
<body>
<h2 class="title">
<span class="animated">GUVI GPT - Text Generator</span>
</h2>
</body>
</html>
"""
|
|
|
|
|
|
|
# Stylesheet for the scrolling disclaimer: .ticker-wrap clips its content and
# .ticker-item slides from right to left on a 15-second loop.
ticker_style = """
<style>
.ticker-wrap {
overflow: hidden;
position: relative;
box-sizing: border-box;
padding: 10px;
background-color: #0E1117;
color: #FAFAFA;
font-size: 18px;
white-space: nowrap;
}
.ticker-item {
display: inline-block;
padding-right: 30px;
animation: ticker-slide 15s linear infinite;
}
@keyframes ticker-slide {
0% {
transform: translateX(100%);
}
100% {
transform: translateX(-100%);
}
}
</style>
"""
|
|
|
|
|
# Animated image shown in the sidebar of the home page.
img = 'https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExOGttcXJnNXZ5azl5bnNrbDk1bnZ4eTJkeWNtbGJhM2I4ZDZheWplMiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/5k5vZwRFZR5aZeniqb/giphy.webp'

# Text for the scrolling ticker on the home page.
disclaimer_message = (
    "Disclaimer: Model data sourced from various web articles. "
    "Performance may vary based on data quality and relevance."
)
|
|
|
|
|
|
|
|
|
# Directory holding the fine-tuned GPT-2 weights and tokenizer files.
model_name_or_path = "./fine_tuned_model"

# Use the GPU when torch can see one; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource(show_spinner="Loading model...")
def _load_model_and_tokenizer(path):
    """Load the model and tokenizer from ``path`` and move the model to ``device``.

    Streamlit re-executes this script on every interaction; caching the
    loaded objects avoids reading the weights from disk on each rerun.

    Args:
        path: Directory (or model id) passed to ``from_pretrained``.

    Returns:
        tuple: ``(model, tokenizer)``, with the tokenizer's pad token set to
        its EOS token when it had none.
    """
    loaded_model = GPT2LMHeadModel.from_pretrained(path)
    loaded_tokenizer = GPT2Tokenizer.from_pretrained(path)

    # Padding needs a pad token; reuse EOS when the tokenizer defines none.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    loaded_model.to(device)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_name_or_path)
|
|
|
|
|
|
|
def generate_text(model, tokenizer, seed_text, max_length=100, temperature=0.01, num_return_sequences=1):
    """Generate a continuation of ``seed_text`` and stream it word by word.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        seed_text: Prompt to continue. An empty or whitespace-only prompt
            yields nothing.
        max_length: Upper bound on the total sequence length in tokens,
            prompt included. Raised to prompt length + 1 when the prompt
            alone is already that long.
        temperature: Sampling temperature passed to ``generate``.
        num_return_sequences: Number of sequences requested from
            ``generate``; only the first one is streamed.

    Yields:
        str: Each whitespace-separated word of the decoded text followed by a
        single space, with a 0.1 s pause after each word.
    """
    # Nothing to continue from; skip tokenizing and generation entirely.
    if not seed_text or not seed_text.strip():
        return

    inputs = tokenizer(seed_text, return_tensors='pt', padding=True, truncation=True)

    # Follow the model's own device rather than a module-level global.
    target_device = model.device
    input_ids = inputs['input_ids'].to(target_device)
    attention_mask = inputs['attention_mask'].to(target_device)

    # max_length covers the prompt as well, so a prompt at or above the limit
    # would leave no room to generate; always allow at least one new token.
    prompt_tokens = input_ids.shape[-1]
    total_length = max(max_length, prompt_tokens + 1)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=total_length,
            temperature=temperature,
            num_return_sequences=num_return_sequences,
            do_sample=True,
            top_k=50,
            top_p=0.1,
            pad_token_id=tokenizer.eos_token_id
        )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    for word in generated_text.split():
        yield word + " "
        # Short pause between words so the reply appears gradually.
        time.sleep(0.1)
|
|
|
|
|
|
|
def home_page():
    """Render the home page: sidebar controls, banner, ticker and chat.

    The sidebar greets the logged-in user, shows example prompts, a length
    slider and a logout button. The main area shows the banner and
    disclaimer ticker, replays the chat stored in the session, and streams a
    generated reply for each new prompt.
    """
    with st.sidebar:
        st.title(f"Welcome, {st.session_state.username}!")
        st.image(img, use_column_width=True)

        st.markdown('<br>', unsafe_allow_html=True)

        st.write("### Example Prompts")
        st.markdown(''' Guvi is an <br> Founders of guvi''', unsafe_allow_html=True)

        st.markdown('<br>', unsafe_allow_html=True)

        # Named max_words so the builtin max() is not shadowed. The value is
        # forwarded as generate_text's max_length.
        max_words = st.slider('Select MAX words', 10, 250)

        if st.button("Logout"):
            # Drop everything stored for this session, then send the user
            # back to the login page.
            st.session_state.clear()
            st.session_state.current_page = 'login'
            st.rerun()

    st.markdown(html_content, unsafe_allow_html=True)
    st.markdown(ticker_style, unsafe_allow_html=True)
    st.markdown(f'<div class="ticker-wrap"><div class="ticker-item">{disclaimer_message}</div></div>', unsafe_allow_html=True)

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the stored conversation; without this, earlier turns disappear
    # from the page on every rerun even though they are kept in the session.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("What up?"):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant"):
            response = st.write_stream(
                generate_text(
                    model,
                    tokenizer,
                    seed_text=prompt,
                    max_length=max_words,
                    temperature=0.01,
                    num_return_sequences=1,
                )
            )
        st.session_state.messages.append({"role": "assistant", "content": response})
|
|
|
|
|
|
|
|
|
|
|
# Map each page name to its render function and draw the current one. A name
# that is not listed renders nothing.
_PAGE_RENDERERS = {
    'home': home_page,
    'login': login,
    'sign_up': signup,
    'reset_password': reset_password_page,
}
_render_page = _PAGE_RENDERERS.get(st.session_state.current_page)
if _render_page is not None:
    _render_page()