"""Flask application factory and one-time ML model loading.

Defines the shared extension objects (``db``, ``login_manager``), the
module-level model/tokenizer globals filled in by ``load_models()``, and
``create_app()``, which configures the app, loads the models, registers the
blueprints and an app-wide error handler, and creates the database tables.
"""
from flask import Flask, render_template
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
from transformers import LongformerTokenizer, AutoTokenizer, AutoModelForCausalLM, LongformerConfig
from models.model import *
from utils.util_func import *
import os

# Explicit imports placed after the wildcard imports so they cannot be
# shadowed by them (torch was previously only available via a wildcard).
import torch
from werkzeug.exceptions import HTTPException
from werkzeug.middleware.proxy_fix import ProxyFix

db = SQLAlchemy()
login_manager = LoginManager()

# Populated once by load_models().
MODELS_LOADED = False
LONGFORMER_TOKENIZER = None
LONGFORMER_MODEL = None
QWEN_TOKENIZER = None
QWEN_MODEL = None
# Not referenced in this file; presumably shared state used by other modules — verify.
MODEL_SESSION = None
TOKEN_STORE = {}
LOGGED = False


def load_models():
    """Load the Longformer scorer and the Qwen causal LM into module globals.

    Runs at most once per process: after a successful load ``MODELS_LOADED``
    is True and later calls return immediately.
    """
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
    if MODELS_LOADED:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE", device)

    # Tokenizers have no device; the former device='auto' argument was ignored.
    LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    config = LongformerConfig.from_json_file("Longformer_checkpoint/config.json")
    # NOTE(review): this constructs a model from the local config and then calls
    # from_pretrained on that instance. If from_pretrained is the usual
    # transformers classmethod, the instance (and the local config) is discarded
    # and both config and weights come from the hub repo. Left as-is because the
    # class is defined outside this file — verify and simplify.
    LONGFORMER_MODEL = CustomLongformerForSequenceClassification(config).from_pretrained(
        'SFM2001/LongFormerScorer'
    )
    LONGFORMER_MODEL = LONGFORMER_MODEL.to(device)
    LONGFORMER_MODEL.eval()

    model_name = 'Qwen/Qwen3-1.7B'
    QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token.
    QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
    # device_map="auto" lets accelerate place the weights itself. The model is
    # deliberately NOT moved afterwards with .to(device): that would undo a
    # multi-device split and raises when layers were offloaded to CPU/disk.
    # .half() is redundant for parameters (torch_dtype is already float16) but
    # is kept because it also casts floating-point buffers, matching the
    # original numerics.
    QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float16,
    ).half()
    MODELS_LOADED = True


def create_app():
    """Build and configure the Flask application.

    Seeds RNGs with ``set_seed(42)``, loads the ML models, wires up
    SQLAlchemy and Flask-Login, registers the blueprints and a catch-all
    error handler, and creates any missing database tables.

    Returns:
        The configured ``Flask`` application.
    """
    set_seed(42)
    app = Flask(__name__)
    # Take the session-signing key from the environment when provided. The
    # hard-coded fallback is for local development only: anyone who knows it
    # can forge session cookies.
    app.secret_key = os.environ.get("FLASK_SECRET_KEY") or "super-secret"
    # Trust one proxy hop for X-Forwarded-For / X-Forwarded-Proto.
    app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1)
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///users.db'
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    # NOTE(review): CSRF protection is disabled for all forms — confirm this is intended.
    app.config['WTF_CSRF_ENABLED'] = False
    # Consider SESSION_COOKIE_SECURE = True when the app is served only over HTTPS.
    app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'

    load_models()

    db.init_app(app)
    login_manager.login_view = 'auth.login'
    login_manager.init_app(app)

    @login_manager.user_loader
    def load_user(user_id):
        """Flask-Login callback: return the User whose id is stored in the session."""
        # `User` is bound by the import inside the app context below, which runs
        # before any request can invoke this loader.
        return User.query.get(int(user_id))

    with app.app_context():
        # Imported here rather than at module top, presumably to avoid a
        # circular import with this module — verify.
        from views import auth_bp, dashboard_bp, infer_bp, about_bp, error_bp

        for blueprint in (auth_bp, dashboard_bp, infer_bp, about_bp, error_bp):
            app.register_blueprint(blueprint)

        @app.errorhandler(Exception)
        def handle_all_exceptions(e):
            """Render error.html for any exception that reaches Flask.

            HTTP errors keep their own status code and message. Any other
            exception is logged with its traceback and reported as a generic 500.
            """
            if isinstance(e, HTTPException):
                code = e.code or 500
                error_message = str(e)
            else:
                # A non-HTTP exception may carry an unrelated ``code`` attribute
                # (e.g. SQLAlchemy errors use None or short string codes). That
                # value must never become the HTTP status. Flask does not log
                # exceptions that a handler deals with, so log it here.
                app.logger.error("Unhandled exception", exc_info=e)
                code = 500
                error_message = "Something went wrong."
            return render_template('error.html', code=code, error_message=error_message), code

        # Presumably defines the models on `db`, so create_all() can see their tables.
        from database import User, History
        db.create_all()

    return app