File size: 2,989 Bytes
4f591e5
 
 
 
 
 
 
5ef0691
b800d76
4f591e5
 
 
 
 
 
 
 
 
 
75fb515
 
4f591e5
 
 
 
 
b92108d
0daa81d
 
 
 
 
78b4c1f
 
 
 
ed8bf25
b92108d
4f591e5
a5cd8cb
4f591e5
 
78b4c1f
b800d76
 
5ef0691
 
80af8fd
46a92f6
b800d76
5ef0691
4f591e5
 
5ef0691
78b4c1f
 
 
4f591e5
 
33ff5ca
4f591e5
 
 
 
 
78b4c1f
 
 
 
 
 
 
4f591e5
78b4c1f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
from flask import render_template
from transformers import LongformerTokenizer, AutoTokenizer, AutoModelForCausalLM, LongformerConfig
from models.model import *
from utils.util_func import *
import os 
from werkzeug.middleware.proxy_fix import ProxyFix

# Extension singletons: created unbound here and attached to the app inside
# create_app() via db.init_app(app) / login_manager.init_app(app).
db, login_manager = SQLAlchemy(), LoginManager()

# Model handles start empty; load_models() populates the tokenizer/model
# globals below exactly once and then flips MODELS_LOADED to True.
MODELS_LOADED = False
LONGFORMER_TOKENIZER = LONGFORMER_MODEL = None
QWEN_TOKENIZER = QWEN_MODEL = None
# NOTE(review): the three globals below are not referenced in this file;
# presumably they are shared state for the views — confirm before changing.
MODEL_SESSION = None
TOKEN_STORE = {}
LOGGED = False

def load_models():
    """Load the Longformer scorer and the Qwen causal LM into module globals.

    Safe to call repeatedly: once ``MODELS_LOADED`` is True, later calls
    return immediately without reloading anything.

    Side effects:
        Assigns LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER,
        QWEN_MODEL and finally sets MODELS_LOADED to True. Model weights are
        fetched through ``from_pretrained`` (Hugging Face Hub names).
    """
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
    if MODELS_LOADED:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE", device)

    # Tokenizers are not placed on a device; the former device='auto'
    # keyword had no effect and has been dropped.
    LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    # from_pretrained is a classmethod: call it on the class directly.
    # Previously a full model was built from the local
    # Longformer_checkpoint/config.json and immediately thrown away; that
    # config was never applied to the loaded model (the checkpoint's own
    # config is used), so the wasted construction is removed.
    LONGFORMER_MODEL = CustomLongformerForSequenceClassification.from_pretrained('SFM2001/LongFormerScorer')
    LONGFORMER_MODEL = LONGFORMER_MODEL.to(device)
    LONGFORMER_MODEL.eval()

    model_name = 'Qwen/Qwen3-1.7B'
    QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
    # Load onto a single explicit device, matching the Longformer handling.
    # Mixing device_map="auto" (accelerate dispatch hooks) with a later
    # .to(device) breaks when accelerate splits or offloads the model.
    # torch_dtype already yields fp16 weights, so the extra .half() is gone.
    # NOTE(review): fp16 on a CPU-only host is slow / op-limited; consider
    # float32 there if CPU inference is a supported deployment.
    QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    QWEN_MODEL = QWEN_MODEL.to(device)
    QWEN_MODEL.eval()
    MODELS_LOADED = True

def create_app():
    """Build and configure the Flask application.

    Seeds RNGs, applies session/DB configuration, loads the ML models into
    module globals, binds Flask-SQLAlchemy and Flask-Login, registers the
    blueprints plus a catch-all error handler, and creates the DB tables.

    Returns:
        Flask: the configured application instance.
    """
    from werkzeug.exceptions import HTTPException

    set_seed(42)
    app = Flask(__name__)
    # Prefer a key supplied via the environment; the literal is only a
    # development fallback (anyone who knows it can forge session cookies).
    app.secret_key = os.environ.get("SECRET_KEY", "super-secret")
    # Trust one proxy hop for X-Forwarded-For and X-Forwarded-Proto.
    app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1)
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///users.db'
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    # NOTE(review): CSRF protection is disabled app-wide; re-enable once the
    # forms render CSRF tokens.
    app.config['WTF_CSRF_ENABLED'] = False
    # app.config['SESSION_COOKIE_SECURE'] = True
    app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
    load_models()
    db.init_app(app)
    login_manager.login_view = 'auth.login'
    login_manager.init_app(app)

    @login_manager.user_loader
    def load_user(user_id):
        """Return the User for a session-stored id, or None if it is unusable."""
        # Flask-Login expects None (not an exception) for an invalid id,
        # e.g. a tampered or stale session value.
        try:
            pk = int(user_id)
        except (TypeError, ValueError):
            return None
        # ``User`` is bound by the import near the end of create_app, which
        # runs before any request can invoke this loader.
        return db.session.get(User, pk)

    with app.app_context():
        from views import auth_bp, dashboard_bp, infer_bp, about_bp, error_bp
        app.register_blueprint(auth_bp)
        app.register_blueprint(dashboard_bp)
        app.register_blueprint(infer_bp)
        app.register_blueprint(about_bp)
        app.register_blueprint(error_bp)

        @app.errorhandler(Exception)
        def handle_all_exceptions(e):
            """Render error.html for any exception not handled more specifically."""
            if isinstance(e, HTTPException):
                # Genuine HTTP errors: show their status and description.
                code = e.code or 500
                error_message = str(e)
            else:
                # Arbitrary exceptions may carry an unrelated ``code`` attribute,
                # so always answer 500. Registering this handler stops Flask from
                # logging the error itself, so record the traceback here and keep
                # internals out of the page.
                app.logger.error("Unhandled exception", exc_info=e)
                code = 500
                error_message = "Something went wrong."
            return render_template('error.html', code=code, error_message=error_message), code

        # Importing the models registers their tables before create_all().
        from database import User, History
        db.create_all()

    return app