File size: 2,989 Bytes
4f591e5 5ef0691 b800d76 4f591e5 75fb515 4f591e5 b92108d 0daa81d 78b4c1f ed8bf25 b92108d 4f591e5 a5cd8cb 4f591e5 78b4c1f b800d76 5ef0691 80af8fd 46a92f6 b800d76 5ef0691 4f591e5 5ef0691 78b4c1f 4f591e5 33ff5ca 4f591e5 78b4c1f 4f591e5 78b4c1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
from flask import render_template
from transformers import LongformerTokenizer, AutoTokenizer, AutoModelForCausalLM, LongformerConfig
from models.model import *
from utils.util_func import *
import os
from werkzeug.middleware.proxy_fix import ProxyFix
# Flask extension singletons; they are bound to the app inside create_app().
db = SQLAlchemy()
login_manager = LoginManager()

# Model state, populated once by load_models(); MODELS_LOADED guards re-loading.
MODELS_LOADED = False
LONGFORMER_TOKENIZER = LONGFORMER_MODEL = None
QWEN_TOKENIZER = QWEN_MODEL = None

# Further process-wide state. None of these are referenced in this module;
# presumably they are read/written elsewhere (e.g. by the views) — verify.
MODEL_SESSION = None
TOKEN_STORE = {}
LOGGED = False
def load_models():
    """Load the Longformer scorer and the Qwen causal LM into module globals.

    Populates LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER and
    QWEN_MODEL, then sets MODELS_LOADED so subsequent calls return
    immediately. Models are placed on CUDA when available, otherwise on CPU.
    """
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
    if MODELS_LOADED:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE", device)

    # --- Longformer scorer -------------------------------------------------
    LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    # from_pretrained is a classmethod: call it on the class. Instantiating the
    # class from a config first would only build (and then discard) a randomly
    # initialised model before the pretrained weights are loaded.
    LONGFORMER_MODEL = CustomLongformerForSequenceClassification.from_pretrained('SFM2001/LongFormerScorer')
    LONGFORMER_MODEL = LONGFORMER_MODEL.to(device)
    LONGFORMER_MODEL.eval()

    # --- Qwen causal LM ----------------------------------------------------
    model_name = 'Qwen/Qwen3-1.7B'
    QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token for batched generation.
    QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
    # Load directly in float16 and place the model explicitly with .to(device).
    # Combining device_map="auto" (accelerate dispatch) with a later .to() fails
    # when accelerate splits/offloads the model, and .half() after
    # torch_dtype=float16 would be a no-op.
    QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    QWEN_MODEL = QWEN_MODEL.to(device)
    QWEN_MODEL.eval()

    MODELS_LOADED = True
def create_app():
    """Build and configure the Flask application.

    Seeds RNGs via set_seed(42), loads the ML models, wires up SQLAlchemy and
    Flask-Login, registers the view blueprints and a catch-all error handler,
    and creates the database tables.

    Returns:
        The configured Flask app.
    """
    from werkzeug.exceptions import HTTPException

    set_seed(42)
    app = Flask(__name__)
    # The session-signing key must be secret: anyone who knows it can forge a
    # signed session cookie (including Flask-Login's stored user id). Read it
    # from the environment; the literal fallback is only fit for local dev.
    app.secret_key = os.environ.get("SECRET_KEY", "super-secret")
    # Trust exactly one reverse proxy for X-Forwarded-For / X-Forwarded-Proto.
    app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1)
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///users.db'
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    # NOTE(review): CSRF protection is disabled — confirm this is intentional.
    app.config['WTF_CSRF_ENABLED'] = False
    # app.config['SESSION_COOKIE_SECURE'] = True
    app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'

    load_models()
    db.init_app(app)
    login_manager.login_view = 'auth.login'
    login_manager.init_app(app)

    @login_manager.user_loader
    def load_user(user_id):
        """Return the User with primary key *user_id*, or None if it is malformed."""
        try:
            pk = int(user_id)
        except (TypeError, ValueError):
            # Flask-Login expects None (not an exception) for an invalid id.
            return None
        # `User` is bound by the `from database import ...` further down in
        # create_app; this closure only runs on requests, after that import.
        return User.query.get(pk)

    with app.app_context():
        from views import auth_bp, dashboard_bp, infer_bp, about_bp, error_bp
        for blueprint in (auth_bp, dashboard_bp, infer_bp, about_bp, error_bp):
            app.register_blueprint(blueprint)

        @app.errorhandler(Exception)
        def handle_all_exceptions(e):
            """Render error.html for any exception raised while handling a request."""
            if isinstance(e, HTTPException) and (e.code is None or e.code < 400):
                # Non-error HTTP exceptions (e.g. werkzeug's routing redirects)
                # are complete responses themselves; rendering them as an error
                # page would drop headers such as Location.
                return e
            code = getattr(e, 'code', 500)
            if not isinstance(code, int) or not 400 <= code <= 599:
                # Some non-HTTP exceptions carry an unrelated `code` attribute
                # (possibly a string or None) that is not a valid HTTP status.
                code = 500
            if not isinstance(e, HTTPException):
                # Registering a handler for Exception suppresses Flask's own
                # traceback logging, so log genuine crashes here.
                app.logger.error("Unhandled exception", exc_info=e)
            error_message = str(e) if hasattr(e, 'description') else "Something went wrong."
            return render_template('error.html', code=code, error_message=error_message), code

        # Imported so the model classes are defined before create_all()
        # (presumably this is what registers their tables); History is not
        # otherwise referenced here.
        from database import User, History
        db.create_all()
    return app