trial run
Browse files- .gitignore +2 -0
- ChitChat/__init__.py +21 -0
- ChitChat/common/__init__.py +0 -0
- ChitChat/common/utils.py +52 -0
- ChitChat/config.py +4 -0
- ChitChat/models.py +13 -0
- ChitChat/resources/__init__.py +0 -0
- ChitChat/resources/routes.py +58 -0
- DockerFile +22 -0
- app.py +10 -0
- instance/site.db +0 -0
- requirements.txt +39 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
**/__pycache__/
|
2 |
+
.venv/
|
ChitChat/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask
|
2 |
+
from flask_sqlalchemy import SQLAlchemy
|
3 |
+
from flask_bcrypt import Bcrypt
|
4 |
+
from flask_cors import CORS
|
5 |
+
from ChitChat.config import Config
|
6 |
+
|
7 |
+
|
8 |
+
db = SQLAlchemy()
|
9 |
+
bcrypt = Bcrypt()
|
10 |
+
|
11 |
+
def create_app(config_class = Config):
|
12 |
+
app = Flask(__name__)
|
13 |
+
CORS(app)
|
14 |
+
app.config.from_object(Config)
|
15 |
+
db.init_app(app)
|
16 |
+
bcrypt.init_app(app)
|
17 |
+
|
18 |
+
from ChitChat.resources.routes import resources
|
19 |
+
app.register_blueprint(resources)
|
20 |
+
|
21 |
+
return app
|
ChitChat/common/__init__.py
ADDED
File without changes
|
ChitChat/common/utils.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
from flask import current_app
|
4 |
+
from ChitChat import db
|
5 |
+
from ChitChat.models import ChatHistory
|
6 |
+
|
7 |
+
model_name = 'Th3BossC/DialoGPT-medium-AICLUB_NITC'
|
8 |
+
default_model = 'microsoft/DialoGPT-medium'
|
9 |
+
|
10 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
11 |
+
tokenizer = AutoTokenizer.from_pretrained(default_model)
|
12 |
+
tokenizer.pad_token = tokenizer.eos_token
|
13 |
+
|
14 |
+
|
15 |
+
def getChatHistory(user):
|
16 |
+
if user.history is None:
|
17 |
+
return None
|
18 |
+
else:
|
19 |
+
return torch.load(user.history)
|
20 |
+
|
21 |
+
def saveChatHistory(user, chat_history_ids):
|
22 |
+
location = current_app.config.get('SAVE_FOLDER')
|
23 |
+
file_name = location + str(user.id) + '.pt'
|
24 |
+
if chat_history_ids.shape[-1] > 100:
|
25 |
+
user.history = None
|
26 |
+
db.session.commit()
|
27 |
+
else:
|
28 |
+
torch.save(chat_history_ids, file_name)
|
29 |
+
if user.history is None:
|
30 |
+
user.history = file_name
|
31 |
+
db.session.commit()
|
32 |
+
|
33 |
+
|
34 |
+
def conversation(user, userInput):
|
35 |
+
chat_history_ids = getChatHistory(user)
|
36 |
+
# print(chat_history_ids)
|
37 |
+
user_input_ids = tokenizer.encode(userInput + tokenizer.eos_token, return_tensors = "pt")
|
38 |
+
# print(user_input_ids)
|
39 |
+
bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
|
40 |
+
# print(bot_input_ids)
|
41 |
+
chat_history_ids = model.generate(
|
42 |
+
bot_input_ids,
|
43 |
+
max_length = 500,
|
44 |
+
no_repeat_ngram_size = 3,
|
45 |
+
do_sample = True,
|
46 |
+
top_k = 100,
|
47 |
+
top_p = 0.7,
|
48 |
+
temperature = 0.8
|
49 |
+
)
|
50 |
+
# print(f"chat_history_ids : {type(chat_history_ids)}")
|
51 |
+
saveChatHistory(user, chat_history_ids)
|
52 |
+
return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens = True)
|
ChitChat/config.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Config:
|
2 |
+
SECRET_KEY = '7a2b25ca707a5be465f9a8894f528999'
|
3 |
+
SQLALCHEMY_DATABASE_URI = 'sqlite:///site.db'
|
4 |
+
SAVE_FOLDER = 'ChitChat/common/files/'
|
ChitChat/models.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ChitChat import db
|
2 |
+
from datetime import datetime
|
3 |
+
|
4 |
+
class ChatHistory(db.Model):
|
5 |
+
id = db.Column(db.Integer, primary_key = True)
|
6 |
+
username = db.Column(db.String(10), unique = True, nullable = False)
|
7 |
+
email = db.Column(db.String(100), nullable = False)
|
8 |
+
password = db.Column(db.String(60), nullable = False)
|
9 |
+
date = db.Column(db.DateTime, nullable = False, default = datetime.utcnow)
|
10 |
+
history = db.Column(db.String(100), nullable = True)
|
11 |
+
|
12 |
+
def __repr__(self):
|
13 |
+
return f"ChatHistory({self.username}, {self.email})"
|
ChitChat/resources/__init__.py
ADDED
File without changes
|
ChitChat/resources/routes.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Blueprint, request, current_app
|
2 |
+
from flask_restful import Api, Resource
|
3 |
+
from ChitChat.models import ChatHistory
|
4 |
+
from ChitChat import bcrypt, db
|
5 |
+
from ChitChat.common.utils import conversation
|
6 |
+
|
7 |
+
|
8 |
+
resources = Blueprint('resources', __name__)
|
9 |
+
api = Api(resources)
|
10 |
+
|
11 |
+
|
12 |
+
class UserLogin(Resource):
|
13 |
+
def post(self):
|
14 |
+
userInfo = request.json
|
15 |
+
|
16 |
+
user = ChatHistory.query.filter_by(email = userInfo['email']).first()
|
17 |
+
|
18 |
+
if user is None:
|
19 |
+
return {'status' : "Account doesn't exist", 'user_id' : -1}
|
20 |
+
elif bcrypt.check_password_hash(pw_hash = user.password, password = userInfo['password']):
|
21 |
+
return {'status' : "login successful", 'user_id' : user.id}
|
22 |
+
else:
|
23 |
+
return {'status' : "Invalid password", 'user_id' : -1}
|
24 |
+
api.add_resource(UserLogin, '/login/')
|
25 |
+
|
26 |
+
class RegisterUser(Resource):
|
27 |
+
def post(self):
|
28 |
+
userInfo = request.json
|
29 |
+
|
30 |
+
user = ChatHistory.query.filter_by(email = userInfo['email']).first()
|
31 |
+
if user is not None:
|
32 |
+
return {'status' : 'email already registered'}
|
33 |
+
else:
|
34 |
+
user = ChatHistory.query.filter_by(username = userInfo['username']).first()
|
35 |
+
if user is not None:
|
36 |
+
return {'status' : 'Username already exists'}
|
37 |
+
|
38 |
+
newUser = ChatHistory(username = userInfo['username'], email = userInfo['email'], password = bcrypt.generate_password_hash(password = userInfo['password']))
|
39 |
+
db.session.add(newUser)
|
40 |
+
db.session.commit()
|
41 |
+
return {'status' : 'User created successfully'}
|
42 |
+
api.add_resource(RegisterUser, '/register/')
|
43 |
+
|
44 |
+
class ChatBot(Resource):
|
45 |
+
def post(self, user_id):
|
46 |
+
user = ChatHistory.query.filter_by(id = user_id).first()
|
47 |
+
userInput = request.json['user']
|
48 |
+
if user is None:
|
49 |
+
return {'error' : "User doesn't exist"}, 400
|
50 |
+
|
51 |
+
reply = conversation(user, userInput)
|
52 |
+
return {'reply' : reply}
|
53 |
+
api.add_resource(ChatBot, '/chat/<int:user_id>')
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
|
DockerFile
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
2 |
+
# you will also find guides on how best to write your Dockerfile
|
3 |
+
|
4 |
+
FROM python:3.9
|
5 |
+
|
6 |
+
WORKDIR /code
|
7 |
+
|
8 |
+
COPY ./requirements.txt /code/requirements.txt
|
9 |
+
|
10 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
11 |
+
|
12 |
+
RUN useradd -m -u 1000 user
|
13 |
+
USER user
|
14 |
+
ENV HOME=/home/user \
|
15 |
+
PATH=/home/user/.local/bin:$PATH
|
16 |
+
|
17 |
+
WORKDIR $HOME/app
|
18 |
+
|
19 |
+
|
20 |
+
COPY --chown=user . $HOME/app
|
21 |
+
|
22 |
+
CMD ["python", "app.py"]
|
app.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ChitChat import create_app
|
2 |
+
|
3 |
+
|
4 |
+
app = create_app()
|
5 |
+
|
6 |
+
if __name__ == '__main__':
|
7 |
+
app.run(debug = True, port = 5000)
|
8 |
+
|
9 |
+
# if __name__ == '__main__':
|
10 |
+
# app.run(debug = False, host = "0.0.0.0", port = 7860)
|
instance/site.db
ADDED
Binary file (12.3 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aniso8601==9.0.1
|
2 |
+
bcrypt==4.0.1
|
3 |
+
blinker==1.6.2
|
4 |
+
certifi==2023.5.7
|
5 |
+
charset-normalizer==3.1.0
|
6 |
+
click==8.1.3
|
7 |
+
colorama==0.4.6
|
8 |
+
filelock==3.12.2
|
9 |
+
Flask==2.3.2
|
10 |
+
Flask-Bcrypt==1.0.1
|
11 |
+
Flask-Cors==3.0.10
|
12 |
+
Flask-RESTful==0.3.10
|
13 |
+
Flask-SQLAlchemy==3.0.3
|
14 |
+
fsspec==2023.6.0
|
15 |
+
greenlet==2.0.2
|
16 |
+
huggingface-hub==0.15.1
|
17 |
+
idna==3.4
|
18 |
+
itsdangerous==2.1.2
|
19 |
+
Jinja2==3.1.2
|
20 |
+
MarkupSafe==2.1.3
|
21 |
+
mpmath==1.3.0
|
22 |
+
networkx==3.1
|
23 |
+
numpy==1.24.3
|
24 |
+
packaging==23.1
|
25 |
+
pytz==2023.3
|
26 |
+
PyYAML==6.0
|
27 |
+
regex==2023.6.3
|
28 |
+
requests==2.31.0
|
29 |
+
safetensors==0.3.1
|
30 |
+
six==1.16.0
|
31 |
+
SQLAlchemy==2.0.16
|
32 |
+
sympy==1.12
|
33 |
+
tokenizers==0.13.3
|
34 |
+
torch==2.0.1
|
35 |
+
tqdm==4.65.0
|
36 |
+
transformers==4.30.2
|
37 |
+
typing_extensions==4.6.3
|
38 |
+
urllib3==2.0.3
|
39 |
+
Werkzeug==2.3.6
|