Spaces:
Runtime error
Runtime error
MSP RAJA
commited on
Commit
·
f784d15
1
Parent(s):
526af86
updated
Browse files- app.py +81 -0
- config.py +9 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from flask import Flask, request, jsonify
|
3 |
+
import os
|
4 |
+
from wtforms import Form, StringField
|
5 |
+
from wtforms.validators import DataRequired
|
6 |
+
from config import model_ckpt, pipe, labels
|
7 |
+
|
8 |
+
app = Flask(__name__)
|
9 |
+
|
10 |
+
# # configure logging
|
11 |
+
# logging.basicConfig(
|
12 |
+
# filename='app.log',
|
13 |
+
# level=logging.INFO,
|
14 |
+
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
15 |
+
# )
|
16 |
+
# logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
class PredictForm(Form):
|
19 |
+
text = StringField('text', [DataRequired()])
|
20 |
+
|
21 |
+
def predict(text: str) -> dict:
|
22 |
+
"""
|
23 |
+
Compute predictions for text.
|
24 |
+
:param text: str : The text to be analyzed.
|
25 |
+
:return: dict : A dictionary of predicted language and its score
|
26 |
+
"""
|
27 |
+
try:
|
28 |
+
preds = pipe(text, return_all_scores=True, truncation=True, max_length=128)
|
29 |
+
if preds:
|
30 |
+
pred = preds[0]
|
31 |
+
pred = sorted(pred, key=lambda x: x['score'], reverse=True)
|
32 |
+
return {labels.get(p["label"],p["label"]): float(p["score"]) for p in pred[:1]}
|
33 |
+
else:
|
34 |
+
return {}
|
35 |
+
except Exception as e:
|
36 |
+
logger.error("Error processing request: %s", str(e))
|
37 |
+
return {'error': str(e)}, 500
|
38 |
+
|
39 |
+
@app.route('/language', methods=['POST'])
|
40 |
+
def predict_language():
|
41 |
+
"""
|
42 |
+
A Language Prediction API which accepts 'text' as input and return the language of text along with score
|
43 |
+
---
|
44 |
+
parameters:
|
45 |
+
- in: body
|
46 |
+
name: text
|
47 |
+
schema:
|
48 |
+
type: string
|
49 |
+
required: true
|
50 |
+
description: The text to be analyzed
|
51 |
+
responses:
|
52 |
+
200:
|
53 |
+
description: A JSON object containing the language and its score
|
54 |
+
schema:
|
55 |
+
type: object
|
56 |
+
400:
|
57 |
+
description: Invalid request
|
58 |
+
500:
|
59 |
+
description: Internal server error
|
60 |
+
"""
|
61 |
+
# form = PredictForm(request.form)
|
62 |
+
# if form.validate():
|
63 |
+
text = request.json['text']
|
64 |
+
if not text:
|
65 |
+
return jsonify({'error': 'Empty text provided'}), 400
|
66 |
+
|
67 |
+
result = predict(text)
|
68 |
+
if result:
|
69 |
+
return jsonify(result)
|
70 |
+
else:
|
71 |
+
return jsonify({'error': 'No predictions found'}), 400
|
72 |
+
# else:
|
73 |
+
# return jsonify({'error': 'Invalid input provided'}), 400
|
74 |
+
|
75 |
+
if __name__ == '__main__':
|
76 |
+
log_file = 'app.log'
|
77 |
+
logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
78 |
+
logger = logging.getLogger(__name__)
|
79 |
+
logger.info("Running the app...")
|
80 |
+
app.run()
|
81 |
+
|
config.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import pipeline
|
2 |
+
|
3 |
+
model_ckpt = "papluca/xlm-roberta-base-language-detection"
|
4 |
+
pipe = pipeline("text-classification", model=model_ckpt)
|
5 |
+
|
6 |
+
labels = {"ar" : "Arabic", "bg" : "Bulgarian", "de" : "German", "el" : "Modern Greek",
|
7 |
+
"en" : "English", "es" : "Spanish", "fr" : "French", "hi" : "Hindi", "it" : "Italian",
|
8 |
+
"ja" : "Japanese", "nl" : "Dutch", "pl" : "Polish", "pt" : "Portuguese", "ru" : "Russian",
|
9 |
+
"sw" : "Swahili", "th" : "Thai", "tr" : "Turkish", "ur" : "Urdu", "vi" : "Vietnamese", "zh" : "Chinese"}
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.13.1
|
2 |
+
WTForms==3.0.1
|
3 |
+
transformers==4.25.1
|
4 |
+
Flask==2.2.2
|