funapi / app.py
imperialwool's picture
Update app.py
7b90833
raw
history blame
7.45 kB
import os
from flask import *
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from routes.helpers import checkSignature, configFile
from routes import *
from random import randint
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
# Application setup: Flask app, build version string and rate limiter.
app = Flask(__name__)
VERSION = '1.0 build117'
# Emit non-ASCII characters as-is in JSON responses instead of \u-escaping them.
# Flask < 2.3 reads the JSON_AS_ASCII config key; Flask >= 2.2 exposes a JSON
# provider on `app.json` whose `ensure_ascii` attribute replaces that key (and
# Flask 2.3+ ignores the config key entirely), so set both.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json'):
    app.json.ensure_ascii = False
# Default limit of 5 requests per minute, keyed on the client's remote address;
# counters are kept in process memory ("memory://").
limiter = Limiter(
    app=app,
    key_func=get_remote_address,
    default_limits=["5/minute"],
    storage_uri="memory://",
)
# Rate-limiter bypass filter.
@limiter.request_filter
def ip_whitelist():
    """Decide whether the current request skips rate limiting.

    The ``signature`` parameter is read from the form body for POST requests
    and from the query string otherwise, then validated with
    ``checkSignature``.

    Returns:
        True only when a signature is present and ``checkSignature`` accepts
        it; False otherwise.
    """
    source = request.form if request.method == 'POST' else request.args
    signature = source.get('signature')
    if not signature:
        return False
    try:
        return bool(checkSignature(signature))
    except Exception:
        # Fail closed: a signature that cannot be verified must never
        # exempt the request from the limiter.
        return False
# Error pages.
# Flask sends whatever status accompanies a handler's return value, and a bare
# rendered template defaults to "200 OK" -- so each handler returns its code.
@app.errorhandler(429)
def ratelimit_handler(e):
    """Render the rate-limit page with status 429."""
    return render_template('ratelimit.html'), 429


@app.errorhandler(403)
def forbidden_handler(e):
    """Render the forbidden page with status 403."""
    return render_template('forbidden.html'), 403


@app.errorhandler(404)
def notfound_handler(e):
    """Render the not-found page with status 404."""
    return render_template('notfound.html'), 404


@app.errorhandler(500)
def internal_error_handler(e):
    """Render the internal-server-error page with status 500."""
    return render_template('intervalservererror.html'), 500


@app.errorhandler(502)
def badgateway_handler(e):
    """Render the bad-gateway page with status 502."""
    return render_template('badgateway.html'), 502
# Bare API roots answer with an empty JSON object.
@app.route('/yt/api/v1', methods=['GET', 'POST'])
@app.route('/recognize/api/v1', methods=['GET', 'POST'])
@app.route('/osu/api/v1', methods=['GET', 'POST'])
def emptyPath():
    """Respond to a request for an API root with ``{}``."""
    return dict()
@app.route('/yt/api/v1/<path:path>', methods=['GET', 'POST'])
@app.route('/recognize/api/v1/<path:path>', methods=['GET', 'POST'])
@app.route('/osu/api/v1/<path:path>', methods=['GET', 'POST'])
def emptyApiWA(path):
    """Catch-all for sub-paths of the yt/recognize/osu APIs with no route.

    ``path`` receives the unmatched remainder of the URL; it is not used.
    Returns an error payload with error code 100.
    """
    return dict(
        status="error",
        error_code=100,
        error_details="No method like that found",
    )
# Site icon, exempt from rate limiting.
@app.route('/favicon.ico')
@limiter.exempt
def favicon():
    """Serve ``static/favicon.ico`` with the ``image/vnd.microsoft.icon`` type."""
    static_dir = os.path.join(app.root_path, 'static')
    return send_from_directory(
        static_dir,
        'favicon.ico',
        mimetype='image/vnd.microsoft.icon',
    )
# ---------------------------------------------------------------------------
# Site routes (each one is exempt from the rate limiter)
# ---------------------------------------------------------------------------
@app.route('/')
@limiter.exempt
def index():
    """Render the landing page template."""
    page = render_template('index.html')
    return page


@app.route('/signatures/api/v1/get', methods=['GET', 'POST'])
@limiter.exempt
def signatureGen():
    """Delegate to ``siteRoutes.signatureGen`` with the current request."""
    handler = siteRoutes.signatureGen
    return handler(request)


@app.route('/system-info/api/v1/get', methods=['GET', 'POST'])
@limiter.exempt
def systemInfo():
    """Delegate to ``siteRoutes.systemInfo`` (takes no request argument)."""
    handler = siteRoutes.systemInfo
    return handler()
# ---------------------------------------------------------------------------
# Recognize API
# ---------------------------------------------------------------------------
@app.route('/recognize/api/v1/voice', methods=['GET', 'POST'])
def recognizeVoice():
    """Delegate to ``recognizeApi.recognizeVoice`` with the current request."""
    handler = recognizeApi.recognizeVoice
    return handler(request)
# ---------------------------------------------------------------------------
# YT sound API: thin wrappers that forward the Flask request to ytApi
# ---------------------------------------------------------------------------
@app.route('/yt/api/v1/search', methods=['GET', 'POST'])
def search():
    """Delegate to ``ytApi.search``."""
    handler = ytApi.search
    return handler(request)


@app.route('/yt/api/v1/get-full', methods=['GET', 'POST'])
def getFull():
    """Delegate to ``ytApi.getFull``."""
    handler = ytApi.getFull
    return handler(request)


@app.route('/yt/api/v1/get-preview', methods=['GET', 'POST'])
def getPreview():
    """Delegate to ``ytApi.getPreview``."""
    handler = ytApi.getPreview
    return handler(request)
# ---------------------------------------------------------------------------
# osu! API: thin wrappers that forward the Flask request to osuApi
# ---------------------------------------------------------------------------
@app.route('/osu/api/v1/find-song', methods=['GET', 'POST'])
def findSong():
    """Delegate to ``osuApi.findSong``."""
    handler = osuApi.findSong
    return handler(request)


@app.route('/osu/api/v1/get-beatmap', methods=['GET', 'POST'])
def getBeatmap():
    """Delegate to ``osuApi.getBeatmap``."""
    handler = osuApi.getBeatmap
    return handler(request)


@app.route('/osu/api/v1/get-preview', methods=['GET', 'POST'])
def getBMPreview():
    """Delegate to ``osuApi.getPreview``."""
    handler = osuApi.getPreview
    return handler(request)


@app.route('/osu/api/v1/get-full', methods=['GET', 'POST'])
def getBMFull():
    """Delegate to ``osuApi.getFull``."""
    handler = osuApi.getFull
    return handler(request)
# ---------------------------------------------------------------------------
# Model loading: each (tokenizer, model) pair is loaded once, at import time,
# from its Hugging Face repository id.
# ---------------------------------------------------------------------------
_SENTIMENT_REPO = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
_TOXICITY_REPO = "EIStakovskii/xlm_roberta_base_multilingual_toxicity_classifier_plus"
_CHITCHAT_REPO = "cointegrated/rut5-small-chitchat"

# Sentiment classifier.
sa_t = AutoTokenizer.from_pretrained(_SENTIMENT_REPO)
sa_m = AutoModelForSequenceClassification.from_pretrained(_SENTIMENT_REPO)
# Toxicity classifier.
tc_t = AutoTokenizer.from_pretrained(_TOXICITY_REPO)
tc_m = AutoModelForSequenceClassification.from_pretrained(_TOXICITY_REPO)
# Seq2seq chit-chat generator.
chct_t = AutoTokenizer.from_pretrained(_CHITCHAT_REPO)
chct_m = AutoModelForSeq2SeqLM.from_pretrained(_CHITCHAT_REPO)
##############
# ANALYZE DATA API
# to understand which text is negative, positive or neutral
@app.route('/analyzeText/api/v1/sentiment', methods=['GET', 'POST'])
def sentimentAnalys():
    """Classify the sentiment of the ``text`` request parameter.

    ``text`` is looked up in the form body, then the query string, then a
    JSON object body.

    Returns:
        ``{"status": "pass", "predicted_sentiment": <label>}`` where the label
        comes from ``sa_m.config.id2label``. A missing text yields error code
        101 and any processing failure error code 123, both with HTTP 400.
    """
    try:
        text = request.form.get('text') or request.args.get('text') or ""
        if not text:
            # silent=True returns None (instead of raising) when the body is
            # absent, not JSON, or malformed.
            payload = request.get_json(silent=True)
            if isinstance(payload, dict):
                text = payload.get('text') or ""
        if not text:
            return {"status": "error", "details": {"error_code": 101, "error_details": "No text provided"}}, 400
        # Clip to 512 tokens (the XLM-RoBERTa-base input limit) so long texts
        # are classified on their prefix instead of failing inside the model.
        inputs = sa_t(text, return_tensors="pt", truncation=True, max_length=512)
        # Predict the sentiment; inference only, so skip building autograd state.
        with torch.no_grad():
            logits = sa_m(**inputs).logits
        predicted_sentiment_index = logits.argmax(dim=1).item()
        predicted_sentiment = sa_m.config.id2label[predicted_sentiment_index]
        return {"status": "pass", "predicted_sentiment": predicted_sentiment}
    except Exception as e:
        return {"status": "error", "details": {"error_code": 123, "error_details": str(e).replace("\n", " | ")}}, 400
@app.route('/analyzeText/api/v1/toxicity', methods=['GET', 'POST'])
def toxicityAnalys():
    """Report whether the ``text`` request parameter is classified as toxic.

    ``text`` is looked up in the form body, then the query string, then a
    JSON object body.

    Returns:
        ``{"status": "pass", "toxicity": bool}``. The value is True when the
        predicted label (from ``tc_m.config.id2label``) is ``"LABEL_1"``. A
        missing text yields error code 101 and any processing failure error
        code 123, both with HTTP 400.
    """
    try:
        text = request.form.get('text') or request.args.get('text') or ""
        if not text:
            # silent=True returns None (instead of raising) when the body is
            # absent, not JSON, or malformed.
            payload = request.get_json(silent=True)
            if isinstance(payload, dict):
                text = payload.get('text') or ""
        if not text:
            return {"status": "error", "details": {"error_code": 101, "error_details": "No text provided"}}, 400
        # Clip to 512 tokens (the XLM-RoBERTa-base input limit) so long texts
        # are classified on their prefix instead of failing inside the model.
        inputs = tc_t(text, return_tensors="pt", truncation=True, max_length=512)
        # Predict the toxicity class; inference only, so skip autograd state.
        with torch.no_grad():
            logits = tc_m(**inputs).logits
        predicted_class = logits.argmax(dim=1).item()
        is_toxic = str(tc_m.config.id2label[predicted_class]) == "LABEL_1"
        return {"status": "pass", "toxicity": is_toxic}
    except Exception as e:
        return {"status": "error", "details": {"error_code": 123, "error_details": str(e).replace("\n", " | ")}}, 400
@app.route('/analyzeText/api/v1/chitchat', methods=['POST'])
def chitchatRu():
    """Generate a conversational reply to ``text`` with the rut5 chit-chat model.

    ``text`` is looked up in the form body, then the query string, then a
    JSON object body.

    Returns:
        ``({"status": "pass", "answer": <decoded reply>}, 200)``. A missing
        text yields error code 101 and any processing failure error code 123,
        both with HTTP 400.
    """
    try:
        text = request.form.get('text') or request.args.get('text') or ""
        if not text:
            # silent=True returns None (instead of raising) when the body is
            # absent, not JSON, or malformed.
            payload = request.get_json(silent=True)
            if isinstance(payload, dict):
                text = payload.get('text') or ""
        if not text:
            return {"status": "error", "details": {"error_code": 101, "error_details": "No text provided"}}, 400
        inputs = chct_t.encode(text, padding=True, truncation=True, return_tensors="pt")
        generated_ids = chct_m.generate(
            input_ids=inputs,
            use_cache=False,
        )
        # Only one prompt is encoded, so the reply is the first generated row.
        answer = chct_t.decode(generated_ids[0], skip_special_tokens=True)
        return {"status": "pass", "answer": answer}, 200
    except Exception as e:
        return {"status": "error", "details": {"error_code": 123, "error_details": str(e).replace("\n", " | ")}}, 400
if __name__ == "__main__":
    # Stamp the current build version into the config file and the OpenAPI
    # spec, then start the server on 0.0.0.0:7860.
    import json  # stdlib json, rather than the name leaking in via `from flask import *`

    config = configFile()
    config['buildVersion'] = VERSION
    # Explicit UTF-8 so the files round-trip identically regardless of the
    # host's default locale encoding.
    with open(config['config-path'], "w", encoding="utf-8") as config_out:
        json.dump(config, config_out)
    with open(config['openapi-yaml-path'], "r+", encoding="utf-8") as spec:
        info = spec.read()
        # Rewrite in place: rewind, write the substituted text, drop any tail.
        spec.seek(0)
        spec.write(info.replace('$VERSION_VARIABLE$', VERSION))
        spec.truncate()
    app.run(host="0.0.0.0", port=7860)