|
from flask import Flask |
|
from flask_sqlalchemy import SQLAlchemy |
|
from flask_login import LoginManager |
|
from flask import render_template |
|
from transformers import LongformerTokenizer, AutoTokenizer, AutoModelForCausalLM, LongformerConfig |
|
from models.model import * |
|
from utils.util_func import * |
|
import os |
|
from werkzeug.middleware.proxy_fix import ProxyFix |
|
|
|
# Extension singletons: created unbound at import time and attached to the
# application later, inside create_app(), via their init_app() methods.
db = SQLAlchemy()
login_manager = LoginManager()

# Process-wide model state. load_models() fills in the tokenizers/models and
# flips MODELS_LOADED to True once it has finished.
MODELS_LOADED = False
LONGFORMER_TOKENIZER = None
LONGFORMER_MODEL = None
QWEN_TOKENIZER = None
QWEN_MODEL = None

# The remaining globals are not assigned by any function visible in this part
# of the file; presumably other modules read/write them -- TODO confirm.
MODEL_SESSION = None
TOKEN_STORE = {}
LOGGED = False
|
|
|
def load_models() -> None:
    """Load the Longformer scorer and the Qwen causal LM into module globals.

    On the first call this populates ``LONGFORMER_TOKENIZER``,
    ``LONGFORMER_MODEL``, ``QWEN_TOKENIZER`` and ``QWEN_MODEL`` and then sets
    ``MODELS_LOADED`` to ``True``. Once ``MODELS_LOADED`` is true, further
    calls return immediately without touching any of the globals.

    Both models are moved to CUDA when ``torch.cuda.is_available()`` reports
    a GPU, otherwise to the CPU.
    """
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL

    if MODELS_LOADED:
        return

    # Imported locally so ``torch`` is resolved explicitly instead of relying
    # on one of the file's star imports happening to re-export it.
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE", device)

    LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096', device='auto')
    config = LongformerConfig.from_json_file("Longformer_checkpoint/config.json")
    # NOTE(review): from_pretrained is invoked on an instance built from the
    # local config. If this is the standard Hugging Face classmethod, that
    # instance (and the local config) is discarded in favour of the hub
    # checkpoint -- confirm whether `config=config` was intended.
    LONGFORMER_MODEL = CustomLongformerForSequenceClassification(config).from_pretrained('SFM2001/LongFormerScorer')
    LONGFORMER_MODEL = LONGFORMER_MODEL.to(device)
    LONGFORMER_MODEL.eval()

    model_name = 'Qwen/Qwen3-1.7B'
    QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name, device='auto')
    # Reuse the end-of-sequence token for padding.
    QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
    # The weights are loaded directly in float16 (so no extra .half() pass is
    # needed) and placed with a single explicit .to(device). The previous
    # combination of device_map="auto" followed by .to(device) mixed two
    # placement mechanisms, which conflicts once the model has been
    # dispatched across more than one device.
    QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    QWEN_MODEL = QWEN_MODEL.to(device)

    MODELS_LOADED = True
|
|
|
def create_app() -> Flask:
    """Build and return the configured Flask application.

    Steps, in order: seed the RNGs via ``set_seed(42)``, create the app and
    apply its configuration, load the ML models, bind the SQLAlchemy and
    Flask-Login extensions, register the blueprints and a catch-all error
    handler, and create the database tables.

    The session secret is read from the ``SECRET_KEY`` environment variable
    when it is set and non-empty; otherwise the previous built-in value is
    used so existing deployments keep working.

    Returns:
        The fully configured ``Flask`` application.
    """
    # werkzeug is already a dependency of this file (ProxyFix).
    from werkzeug.exceptions import HTTPException

    set_seed(42)

    app = Flask(__name__)
    # A hard-coded secret lets anyone with the source forge session cookies,
    # so prefer a value from the environment. The literal remains only as a
    # backward-compatible fallback.
    app.secret_key = os.environ.get("SECRET_KEY") or "super-secret"
    # Trust one proxy hop for X-Forwarded-For and X-Forwarded-Proto.
    app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1)
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///users.db'
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    # NOTE(review): CSRF protection is switched off here. Left unchanged
    # because the forms/templates are not visible from this file -- confirm
    # this is intentional.
    app.config['WTF_CSRF_ENABLED'] = False
    app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'

    load_models()

    db.init_app(app)
    login_manager.login_view = 'auth.login'
    login_manager.init_app(app)

    @login_manager.user_loader
    def load_user(user_id):
        """Return the user for the stored session id, or None if it is invalid."""
        # Flask-Login expects None (not an exception) for an unusable id;
        # raising here would surface as an error page instead of a login
        # redirect.
        try:
            primary_key = int(user_id)
        except (TypeError, ValueError):
            return None
        return User.query.get(primary_key)

    with app.app_context():
        from views import auth_bp, dashboard_bp, infer_bp, about_bp, error_bp

        for blueprint in (auth_bp, dashboard_bp, infer_bp, about_bp, error_bp):
            app.register_blueprint(blueprint)

        @app.errorhandler(Exception)
        def handle_all_exceptions(e):
            """Render error.html for HTTP errors and for unexpected exceptions."""
            if isinstance(e, HTTPException):
                # HTTPException.code may be None for bare instances.
                code = e.code or 500
                error_message = str(e)
            else:
                # Arbitrary exceptions can carry a non-HTTP ``code`` attribute
                # (for example a string), so never use it as a status code.
                # Log the traceback, which would otherwise be lost, and show
                # a generic message instead of internal details.
                app.logger.error("Unhandled exception", exc_info=e)
                code = 500
                error_message = "Something went wrong."
            return render_template('error.html', code=code, error_message=error_message), code

        # Imported before create_all(); History is not referenced directly
        # here -- presumably imported so its table is registered as well
        # (confirm against the database module).
        from database import User, History

        db.create_all()

    return app