# SimpleAES / create_app.py
# SFM2001 · gix · 75fb515
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
from flask import render_template
from transformers import LongformerTokenizer, AutoTokenizer, AutoModelForCausalLM, LongformerConfig
from models.model import *
from utils.util_func import *
import os
from werkzeug.middleware.proxy_fix import ProxyFix
# Flask-SQLAlchemy extension; bound to the app in create_app() via db.init_app().
db: SQLAlchemy = SQLAlchemy()
# Flask-Login manager; configured and bound to the app in create_app().
login_manager: LoginManager = LoginManager()
# Process-wide model cache, filled in once by load_models(); the flag
# makes repeated load_models() calls no-ops.
MODELS_LOADED = False
LONGFORMER_TOKENIZER = LONGFORMER_MODEL = None
QWEN_TOKENIZER = QWEN_MODEL = None
# NOTE(review): the globals below are not read or written anywhere in this
# file; presumably other modules use them — confirm before changing.
MODEL_SESSION = None
TOKEN_STORE: dict = {}
LOGGED = False
def load_models():
    """Load the Longformer scorer and the Qwen causal LM into module globals.

    On the first call this populates LONGFORMER_TOKENIZER, LONGFORMER_MODEL,
    QWEN_TOKENIZER and QWEN_MODEL and sets MODELS_LOADED; subsequent calls
    return immediately without reloading anything.

    Both models are placed on CUDA when it is available, otherwise on CPU.
    """
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
    if MODELS_LOADED:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE", device)

    # Tokenizers are device-agnostic, so no device argument is passed.
    LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    # from_pretrained is a classmethod: call it on the class directly instead of
    # first building (and immediately discarding) a randomly initialised model
    # from a local config that the loaded weights never used.
    LONGFORMER_MODEL = CustomLongformerForSequenceClassification.from_pretrained('SFM2001/LongFormerScorer')
    LONGFORMER_MODEL = LONGFORMER_MODEL.to(device)
    LONGFORMER_MODEL.eval()

    model_name = 'Qwen/Qwen3-1.7B'
    QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
    # Half precision only on GPU: fp16 on CPU is slow and, on older PyTorch
    # builds, unsupported for some ops.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    # No device_map="auto": it may shard the model across several GPUs via
    # accelerate hooks, after which the explicit .to(device) below raises.
    QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    QWEN_MODEL = QWEN_MODEL.to(device)
    QWEN_MODEL.eval()

    MODELS_LOADED = True
def create_app():
    """Build and configure the Flask application.

    Seeds RNGs, configures the app (secret key, proxy handling, SQLite
    database, session cookie), loads the ML models, wires up Flask-SQLAlchemy
    and Flask-Login, registers the view blueprints and a catch-all error
    handler, and creates any missing database tables.

    Returns:
        The configured ``Flask`` application.
    """
    # Deferred import, like the factory's other in-function imports below.
    from werkzeug.exceptions import HTTPException

    set_seed(42)
    app = Flask(__name__)
    # Prefer a secret from the environment. The hard-coded fallback keeps
    # existing deployments (and already-signed session cookies) working, but it
    # is public in source control and should be overridden in production.
    app.secret_key = os.environ.get("FLASK_SECRET_KEY", "super-secret")
    # Trust one reverse-proxy hop for the client address and scheme.
    app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1)
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///users.db'
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    # NOTE(review): CSRF protection is disabled for all forms — confirm intended.
    app.config['WTF_CSRF_ENABLED'] = False
    # app.config['SESSION_COOKIE_SECURE'] = True  # enable when served only over HTTPS
    app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'

    load_models()

    db.init_app(app)
    login_manager.login_view = 'auth.login'
    login_manager.init_app(app)

    @login_manager.user_loader
    def load_user(user_id):
        """Return the User for a session-stored id, or None if the id is malformed."""
        try:
            uid = int(user_id)
        except (TypeError, ValueError):
            # Flask-Login expects None (not an exception) for an invalid id.
            return None
        # ``User`` is bound by the import inside the app context below, which
        # runs before any request can invoke this loader.
        return User.query.get(uid)

    with app.app_context():
        from views import auth_bp, dashboard_bp, infer_bp, about_bp, error_bp
        for blueprint in (auth_bp, dashboard_bp, infer_bp, about_bp, error_bp):
            app.register_blueprint(blueprint)

        @app.errorhandler(Exception)
        def handle_all_exceptions(e):
            """Render error.html for any exception that reaches Flask."""
            if isinstance(e, HTTPException):
                code = e.code or 500
                error_message = str(e)
            else:
                # Non-HTTP exceptions may carry a ``code`` attribute that is not
                # an HTTP status (e.g. SQLAlchemy's None/str error codes), which
                # previously yielded a 200 or an invalid status. Always answer
                # 500, keep internals out of the page, and log the traceback,
                # which Flask does not do once a handler consumes the error.
                app.logger.error("Unhandled exception", exc_info=e)
                code = 500
                error_message = "Something went wrong."
            return render_template('error.html', code=code, error_message=error_message), code

        # Importing the database module defines the models so create_all()
        # can see their tables.
        from database import User, History
        db.create_all()
    return app