# lang_id_flask / app.py
# Last commit: a30b891 "fixed output" (author: MSP RAJA), file size 2.51 kB.
# (Header recovered from a file-viewer page; the viewer's UI labels were dropped.)
import logging
from flask import Flask, request, jsonify
import os
from wtforms import Form, StringField
from wtforms.validators import DataRequired
from config import model_ckpt, pipe, labels, THRESHOLD
app = Flask(__name__)
# Module-level logger so any code logging through ``logger`` works however the
# app is started. Previously ``logger`` was bound only inside the ``__main__``
# guard, so serving the app via a WSGI server or ``flask run`` raised NameError
# on the first log call.
logger = logging.getLogger(__name__)
class PredictForm(Form):
    """WTForms form holding a single required ``text`` string field."""

    # NOTE(review): not referenced anywhere in this file; the /language route
    # reads the JSON body directly. Confirm whether it is still needed.
    text = StringField(label='text', validators=[DataRequired()])
def predict(text: str) -> "dict | tuple[dict, int]":
    """
    Compute the top language prediction for text.

    Runs the ``pipe`` classifier from ``config`` and keeps only the
    highest-scoring candidate. Its label is translated through ``labels``,
    falling back to the raw label when there is no mapping.

    :param text: str : The text to be analyzed.
    :return: One of
        - ``{language: score}`` for the top candidate when its score is
          strictly greater than ``THRESHOLD``;
        - ``{'error': "Prediction score below threshold"}`` otherwise;
        - ``{}`` when the pipeline yields no candidates;
        - ``({'error': message}, 500)`` when the pipeline or post-processing
          raises, so the HTTP layer can answer with status 500.
    """
    # Fetch the logger here. The module-global ``logger`` is only bound when the
    # file runs as a script, so under a WSGI server / ``flask run`` the global
    # lookup raised NameError, even inside the except handler.
    log = logging.getLogger(__name__)
    try:
        preds = pipe(text, return_all_scores=True, truncation=True, max_length=128)
        # preds[0] is treated as the list of {"label", "score"} candidates for
        # this text. No output, or no candidates, means no prediction (this
        # used to surface as an IndexError reported as a 500).
        if not preds or not preds[0]:
            return {}
        # Only the best candidate is reported, so max() is enough. Like the
        # stable descending sort it replaces, ties go to the earliest candidate.
        best = max(preds[0], key=lambda p: p["score"])
        if best["score"] > THRESHOLD:
            return {labels.get(best["label"], best["label"]): float(best["score"])}
        log.error("Prediction score below threshold. text: %s, score: %s", text, best["score"])
        return {'error': "Prediction score below threshold"}
    except Exception as e:  # service boundary: report instead of crashing the request
        # exception() also records the traceback, which error() dropped.
        log.exception("Error processing request: %s", e)
        return {'error': str(e)}, 500
@app.route('/language', methods=['POST'])
def predict_language():
    """
    A Language Prediction API which accepts 'text' as input and returns the language of the text along with its score.
    ---
    parameters:
      - in: body
        name: body
        required: true
        description: JSON object carrying the text to be analyzed
        schema:
          type: object
          required:
            - text
          properties:
            text:
              type: string
    responses:
      200:
        description: A JSON object mapping the predicted language to its score
        schema:
          type: object
      400:
        description: Invalid request (missing, empty or non-string text), no predictions found, or prediction score below threshold
      500:
        description: Internal server error
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so it is reported as a 400 below rather than a 415/500.
    payload = request.get_json(silent=True)
    text = payload.get('text') if isinstance(payload, dict) else None
    if not isinstance(text, str) or not text.strip():
        return jsonify({'error': 'Empty text provided'}), 400

    result = predict(text)

    # predict() reports internal failures as a (body, status) tuple. Passing
    # that tuple straight to jsonify() would serialise it as a JSON array with
    # status 200, so unpack it and keep the intended status.
    if isinstance(result, tuple):
        body, status = result
        return jsonify(body), status
    if not result:
        return jsonify({'error': 'No predictions found'}), 400
    # A dict carrying 'error' is the below-threshold case, documented as a 400.
    if 'error' in result:
        return jsonify(result), 400
    return jsonify(result)
if __name__ == '__main__':
    # Script entry point: configure file logging, then start Flask's built-in server.
    log_file = 'app.log'
    logging.basicConfig(
        level=logging.INFO,
        filename=log_file,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    # Bind the module-global ``logger`` name and announce startup.
    logger = logging.getLogger(__name__)
    logger.info("Running the app...")
    app.run()