Th3BossC commited on
Commit
c5b13e5
1 Parent(s): 9aa7535
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ **/__pycache__/
2
+ .venv/
ChitChat/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask
2
+ from flask_sqlalchemy import SQLAlchemy
3
+ from flask_bcrypt import Bcrypt
4
+ from flask_cors import CORS
5
+ from ChitChat.config import Config
6
+
7
+
8
+ db = SQLAlchemy()
9
+ bcrypt = Bcrypt()
10
+
11
+ def create_app(config_class = Config):
12
+ app = Flask(__name__)
13
+ CORS(app)
14
+ app.config.from_object(Config)
15
+ db.init_app(app)
16
+ bcrypt.init_app(app)
17
+
18
+ from ChitChat.resources.routes import resources
19
+ app.register_blueprint(resources)
20
+
21
+ return app
ChitChat/common/__init__.py ADDED
File without changes
ChitChat/common/utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+ from flask import current_app
4
+ from ChitChat import db
5
+ from ChitChat.models import ChatHistory
6
+
7
+ model_name = 'Th3BossC/DialoGPT-medium-AICLUB_NITC'
8
+ default_model = 'microsoft/DialoGPT-medium'
9
+
10
+ model = AutoModelForCausalLM.from_pretrained(model_name)
11
+ tokenizer = AutoTokenizer.from_pretrained(default_model)
12
+ tokenizer.pad_token = tokenizer.eos_token
13
+
14
+
15
+ def getChatHistory(user):
16
+ if user.history is None:
17
+ return None
18
+ else:
19
+ return torch.load(user.history)
20
+
21
+ def saveChatHistory(user, chat_history_ids):
22
+ location = current_app.config.get('SAVE_FOLDER')
23
+ file_name = location + str(user.id) + '.pt'
24
+ if chat_history_ids.shape[-1] > 100:
25
+ user.history = None
26
+ db.session.commit()
27
+ else:
28
+ torch.save(chat_history_ids, file_name)
29
+ if user.history is None:
30
+ user.history = file_name
31
+ db.session.commit()
32
+
33
+
34
+ def conversation(user, userInput):
35
+ chat_history_ids = getChatHistory(user)
36
+ # print(chat_history_ids)
37
+ user_input_ids = tokenizer.encode(userInput + tokenizer.eos_token, return_tensors = "pt")
38
+ # print(user_input_ids)
39
+ bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
40
+ # print(bot_input_ids)
41
+ chat_history_ids = model.generate(
42
+ bot_input_ids,
43
+ max_length = 500,
44
+ no_repeat_ngram_size = 3,
45
+ do_sample = True,
46
+ top_k = 100,
47
+ top_p = 0.7,
48
+ temperature = 0.8
49
+ )
50
+ # print(f"chat_history_ids : {type(chat_history_ids)}")
51
+ saveChatHistory(user, chat_history_ids)
52
+ return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens = True)
ChitChat/config.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ class Config:
2
+ SECRET_KEY = '7a2b25ca707a5be465f9a8894f528999'
3
+ SQLALCHEMY_DATABASE_URI = 'sqlite:///site.db'
4
+ SAVE_FOLDER = 'ChitChat/common/files/'
ChitChat/models.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ChitChat import db
2
+ from datetime import datetime
3
+
4
+ class ChatHistory(db.Model):
5
+ id = db.Column(db.Integer, primary_key = True)
6
+ username = db.Column(db.String(10), unique = True, nullable = False)
7
+ email = db.Column(db.String(100), nullable = False)
8
+ password = db.Column(db.String(60), nullable = False)
9
+ date = db.Column(db.DateTime, nullable = False, default = datetime.utcnow)
10
+ history = db.Column(db.String(100), nullable = True)
11
+
12
+ def __repr__(self):
13
+ return f"ChatHistory({self.username}, {self.email})"
ChitChat/resources/__init__.py ADDED
File without changes
ChitChat/resources/routes.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Blueprint, request, current_app
2
+ from flask_restful import Api, Resource
3
+ from ChitChat.models import ChatHistory
4
+ from ChitChat import bcrypt, db
5
+ from ChitChat.common.utils import conversation
6
+
7
+
8
+ resources = Blueprint('resources', __name__)
9
+ api = Api(resources)
10
+
11
+
12
+ class UserLogin(Resource):
13
+ def post(self):
14
+ userInfo = request.json
15
+
16
+ user = ChatHistory.query.filter_by(email = userInfo['email']).first()
17
+
18
+ if user is None:
19
+ return {'status' : "Account doesn't exist", 'user_id' : -1}
20
+ elif bcrypt.check_password_hash(pw_hash = user.password, password = userInfo['password']):
21
+ return {'status' : "login successful", 'user_id' : user.id}
22
+ else:
23
+ return {'status' : "Invalid password", 'user_id' : -1}
24
+ api.add_resource(UserLogin, '/login/')
25
+
26
+ class RegisterUser(Resource):
27
+ def post(self):
28
+ userInfo = request.json
29
+
30
+ user = ChatHistory.query.filter_by(email = userInfo['email']).first()
31
+ if user is not None:
32
+ return {'status' : 'email already registered'}
33
+ else:
34
+ user = ChatHistory.query.filter_by(username = userInfo['username']).first()
35
+ if user is not None:
36
+ return {'status' : 'Username already exists'}
37
+
38
+ newUser = ChatHistory(username = userInfo['username'], email = userInfo['email'], password = bcrypt.generate_password_hash(password = userInfo['password']))
39
+ db.session.add(newUser)
40
+ db.session.commit()
41
+ return {'status' : 'User created successfully'}
42
+ api.add_resource(RegisterUser, '/register/')
43
+
44
+ class ChatBot(Resource):
45
+ def post(self, user_id):
46
+ user = ChatHistory.query.filter_by(id = user_id).first()
47
+ userInput = request.json['user']
48
+ if user is None:
49
+ return {'error' : "User doesn't exist"}, 400
50
+
51
+ reply = conversation(user, userInput)
52
+ return {'reply' : reply}
53
+ api.add_resource(ChatBot, '/chat/<int:user_id>')
54
+
55
+
56
+
57
+
58
+
DockerFile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM python:3.9
5
+
6
+ WORKDIR /code
7
+
8
+ COPY ./requirements.txt /code/requirements.txt
9
+
10
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
11
+
12
+ RUN useradd -m -u 1000 user
13
+ USER user
14
+ ENV HOME=/home/user \
15
+ PATH=/home/user/.local/bin:$PATH
16
+
17
+ WORKDIR $HOME/app
18
+
19
+
20
+ COPY --chown=user . $HOME/app
21
+
22
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ChitChat import create_app
2
+
3
+
4
+ app = create_app()
5
+
6
+ if __name__ == '__main__':
7
+ app.run(debug = True, port = 5000)
8
+
9
+ # if __name__ == '__main__':
10
+ # app.run(debug = False, host = "0.0.0.0", port = 7860)
instance/site.db ADDED
Binary file (12.3 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aniso8601==9.0.1
2
+ bcrypt==4.0.1
3
+ blinker==1.6.2
4
+ certifi==2023.5.7
5
+ charset-normalizer==3.1.0
6
+ click==8.1.3
7
+ colorama==0.4.6
8
+ filelock==3.12.2
9
+ Flask==2.3.2
10
+ Flask-Bcrypt==1.0.1
11
+ Flask-Cors==3.0.10
12
+ Flask-RESTful==0.3.10
13
+ Flask-SQLAlchemy==3.0.3
14
+ fsspec==2023.6.0
15
+ greenlet==2.0.2
16
+ huggingface-hub==0.15.1
17
+ idna==3.4
18
+ itsdangerous==2.1.2
19
+ Jinja2==3.1.2
20
+ MarkupSafe==2.1.3
21
+ mpmath==1.3.0
22
+ networkx==3.1
23
+ numpy==1.24.3
24
+ packaging==23.1
25
+ pytz==2023.3
26
+ PyYAML==6.0
27
+ regex==2023.6.3
28
+ requests==2.31.0
29
+ safetensors==0.3.1
30
+ six==1.16.0
31
+ SQLAlchemy==2.0.16
32
+ sympy==1.12
33
+ tokenizers==0.13.3
34
+ torch==2.0.1
35
+ tqdm==4.65.0
36
+ transformers==4.30.2
37
+ typing_extensions==4.6.3
38
+ urllib3==2.0.3
39
+ Werkzeug==2.3.6