Spaces:
Running
Running
import os | |
from flask import * | |
from routes.helpers import configFile | |
from routes import * | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Application setup
app = Flask(__name__)
VERSION = '1.0 build123'
# Emit raw UTF-8 in JSON responses instead of \uXXXX escapes.
# Flask < 2.3 reads the JSON_AS_ASCII config key. Flask 2.3 removed that key
# in favour of the JSON provider's ``ensure_ascii`` attribute (``app.json``,
# added in 2.2). Set both so the behaviour holds across Flask versions.
app.config['JSON_AS_ASCII'] = False
if getattr(app, "json", None) is not None:
    app.json.ensure_ascii = False
# Error pages.
# Each handler renders its template and also returns the matching HTTP status.
# A bare render_template() return would be sent to the client as 200 OK.
def ratelimit_handler(e):
    """Render the rate-limit page with status 429 (Too Many Requests)."""
    return render_template('ratelimit.html'), 429


def forbidden_handler(e):
    """Render the forbidden page with status 403."""
    return render_template('forbidden.html'), 403


def notfound_handler(e):
    """Render the not-found page with status 404."""
    return render_template('notfound.html'), 404


def methodnotallowed_handler(e):
    """Render the method-not-allowed page with status 405."""
    return render_template('methodnotallowed.html'), 405


def internalservererror_handler(e):
    """Render the internal-server-error page with status 500."""
    # NOTE(review): the template name reads "interval..." and is likely a typo
    # for "internalservererror.html". It is kept as-is because the templates
    # directory is not visible here; confirm the file name on disk.
    return render_template('intervalservererror.html'), 500


def badgateway_handler(e):
    """Render the bad-gateway page with status 502."""
    return render_template('badgateway.html'), 502
# Placeholder routes
def emptyPath():
    """Return an empty dict (Flask serialises it as ``{}``)."""
    nothing = dict()
    return nothing
def emptyApiWA(path):
    """Fallback for unknown API methods.

    ``path`` is accepted (so the function can back a catch-all route) but is
    not used. Always returns the same error payload with error_code 100.
    """
    return dict(
        status="error",
        error_code=100,
        error_details="No method like that found",
    )
# Site icon
def favicon():
    """Serve ``favicon.ico`` from the ``static`` directory under the app root."""
    static_dir = os.path.join(app.root_path, 'static')
    return send_from_directory(
        static_dir,
        'favicon.ico',
        mimetype='image/vnd.microsoft.icon',
    )
###############
# SITE ROUTES
def index():
    """Render the landing page template ``index.html``."""
    return render_template('index.html')


def systemInfo():
    """Delegate to ``siteRoutes.systemInfo`` and return its result unchanged."""
    return siteRoutes.systemInfo()
###############
# YT SOUND API
# Thin wrappers: each forwards the current Flask ``request`` to ``ytApi``.
def search():
    """Forward the request to ``ytApi.search``."""
    return ytApi.search(request)


def getFull():
    """Forward the request to ``ytApi.getFull``."""
    return ytApi.getFull(request)


def getPreview():
    """Forward the request to ``ytApi.getPreview``."""
    return ytApi.getPreview(request)
###############
# JOKES API
# Thin wrappers: each forwards the current Flask ``request`` to ``jokes``.
def getJoke():
    """Forward the request to ``jokes.getJoke``."""
    return jokes.getJoke(request)


def getJokesSources():
    """Forward the request to ``jokes.getSources``."""
    return jokes.getSources(request)
###############
# HOLIDAYS API (deliberately left undocumented)
def getHolidays():
    """Forward the current Flask request to ``holidays.getHolidays``."""
    return holidays.getHolidays(request)
###############
# OSU API
# Thin wrappers: each forwards the current Flask ``request`` to ``osuApi``.
def findSong():
    """Forward the request to ``osuApi.findSong``."""
    return osuApi.findSong(request)


def getBeatmap():
    """Forward the request to ``osuApi.getBeatmap``."""
    return osuApi.getBeatmap(request)


def getBMPreview():
    """Forward the request to ``osuApi.getPreview``."""
    return osuApi.getPreview(request)


def getBMFull():
    """Forward the request to ``osuApi.getFull``."""
    return osuApi.getFull(request)
###############
# LOAD MODELS
# Both checkpoints are loaded once, at import time, and shared by the
# analysis endpoints below.
SENTIMENT_CHECKPOINT = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
TOXICITY_CHECKPOINT = "EIStakovskii/xlm_roberta_base_multilingual_toxicity_classifier_plus"


def _load_classifier(checkpoint):
    """Return ``(tokenizer, model)`` for a sequence-classification checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return tokenizer, model


sa_t, sa_m = _load_classifier(SENTIMENT_CHECKPOINT)
tc_t, tc_m = _load_classifier(TOXICITY_CHECKPOINT)
##############
# ANALYZE DATA API
def sentimentAnalys():
    """Classify the request's ``text`` as negative, neutral or positive.

    Reads ``text`` with ``helpers.getFromRequest`` and runs it through the
    sentiment tokenizer/model loaded at import time (``sa_t`` / ``sa_m``).

    Returns:
        On success: ``{"status": "pass", "predicted_sentiment": label}``, where
        ``label`` is taken from ``sa_m.config.id2label``.
        On failure: ``({"status": "error", "details": {...}}, 400)``. The
        error_code is 101 when no text is supplied and 123 for any other
        failure; this matches ``toxicityAnalys``.
    """
    try:
        # torch is already required: return_tensors="pt" produces torch tensors.
        import torch

        text = helpers.getFromRequest(request, "text")
        if not text:
            return {"status": "error", "details": {"error_code": 101, "error_details": "No text provided"}}, 400
        # Truncate long inputs so they are classified rather than failing
        # inside the model. 512 assumes the XLM-RoBERTa-base position limit;
        # confirm against the checkpoint's config.
        inputs = sa_t(text, return_tensors="pt", truncation=True, max_length=512)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = sa_m(**inputs).logits
        predicted_sentiment_index = logits.argmax(dim=1).item()
        predicted_sentiment = sa_m.config.id2label[predicted_sentiment_index]
        return {"status": "pass", "predicted_sentiment": predicted_sentiment}
    except Exception as e:
        # Report the failure as one line (newlines replaced by " | ").
        return {"status": "error", "details": {"error_code": 123, "error_details": str(e).replace("\n", " | ")}}, 400
def toxicityAnalys():
    """Report whether the request's ``text`` is classified as toxic.

    Reads ``text`` with ``helpers.getFromRequest`` and runs it through the
    toxicity tokenizer/model loaded at import time (``tc_t`` / ``tc_m``).

    Returns:
        On success: ``{"status": "pass", "toxicity": bool}``. The value is
        ``True`` when the model's top label is ``"LABEL_1"``.
        On failure: ``({"status": "error", "details": {...}}, 400)``. The
        error_code is 101 when no text is supplied and 123 for any other
        failure.
    """
    try:
        # torch is already required: return_tensors="pt" produces torch tensors.
        import torch

        text = helpers.getFromRequest(request, "text")
        if not text:
            return {"status": "error", "details": {"error_code": 101, "error_details": "No text provided"}}, 400
        # Truncate long inputs so they are classified rather than failing
        # inside the model. 512 assumes the XLM-RoBERTa-base position limit;
        # confirm against the checkpoint's config.
        inputs = tc_t(text, return_tensors="pt", truncation=True, max_length=512)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = tc_m(**inputs).logits
        predicted_class = logits.argmax(dim=1).item()
        # NOTE(review): assumes LABEL_1 is this checkpoint's "toxic" class; confirm.
        is_toxic = str(tc_m.config.id2label[predicted_class]) == "LABEL_1"
        return {"status": "pass", "toxicity": is_toxic}
    except Exception as e:
        # Report the failure as one line (newlines replaced by " | ").
        return {"status": "error", "details": {"error_code": 123, "error_details": str(e).replace("\n", " | ")}}, 400
if __name__ == "__main__":
    # Use the standard-library json explicitly. Otherwise `json` here is
    # whatever the wildcard imports at the top happen to export.
    import json

    config = configFile()
    # Stamp the build version into the config file, rewriting it in full.
    with open(config['config-path'], "w", encoding="utf-8") as outfile:
        config['buildVersion'] = VERSION
        json.dump(config, outfile)
    # Substitute the version into the OpenAPI spec in place.
    # NOTE(review): this consumes the $VERSION_VARIABLE$ placeholder. After the
    # first run, later VERSION changes are no longer written into the spec.
    with open(config['openapi-yaml-path'], "r+", encoding="utf-8") as outfile:
        info = outfile.read()
        outfile.seek(0)
        outfile.write(info.replace('$VERSION_VARIABLE$', VERSION))
        outfile.truncate()
    app.run(host="0.0.0.0", port=7860)