MSP RAJA commited on
Commit
f784d15
1 Parent(s): 526af86
Files changed (3) hide show
  1. app.py +81 -0
  2. config.py +9 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from flask import Flask, request, jsonify
3
+ import os
4
+ from wtforms import Form, StringField
5
+ from wtforms.validators import DataRequired
6
+ from config import model_ckpt, pipe, labels
7
+
8
+ app = Flask(__name__)
9
+
10
+ # # configure logging
11
+ # logging.basicConfig(
12
+ # filename='app.log',
13
+ # level=logging.INFO,
14
+ # format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
15
+ # )
16
+ # logger = logging.getLogger(__name__)
17
+
18
+ class PredictForm(Form):
19
+ text = StringField('text', [DataRequired()])
20
+
21
+ def predict(text: str) -> dict:
22
+ """
23
+ Compute predictions for text.
24
+ :param text: str : The text to be analyzed.
25
+ :return: dict : A dictionary of predicted language and its score
26
+ """
27
+ try:
28
+ preds = pipe(text, return_all_scores=True, truncation=True, max_length=128)
29
+ if preds:
30
+ pred = preds[0]
31
+ pred = sorted(pred, key=lambda x: x['score'], reverse=True)
32
+ return {labels.get(p["label"],p["label"]): float(p["score"]) for p in pred[:1]}
33
+ else:
34
+ return {}
35
+ except Exception as e:
36
+ logger.error("Error processing request: %s", str(e))
37
+ return {'error': str(e)}, 500
38
+
39
+ @app.route('/language', methods=['POST'])
40
+ def predict_language():
41
+ """
42
+ A Language Prediction API which accepts 'text' as input and return the language of text along with score
43
+ ---
44
+ parameters:
45
+ - in: body
46
+ name: text
47
+ schema:
48
+ type: string
49
+ required: true
50
+ description: The text to be analyzed
51
+ responses:
52
+ 200:
53
+ description: A JSON object containing the language and its score
54
+ schema:
55
+ type: object
56
+ 400:
57
+ description: Invalid request
58
+ 500:
59
+ description: Internal server error
60
+ """
61
+ # form = PredictForm(request.form)
62
+ # if form.validate():
63
+ text = request.json['text']
64
+ if not text:
65
+ return jsonify({'error': 'Empty text provided'}), 400
66
+
67
+ result = predict(text)
68
+ if result:
69
+ return jsonify(result)
70
+ else:
71
+ return jsonify({'error': 'No predictions found'}), 400
72
+ # else:
73
+ # return jsonify({'error': 'Invalid input provided'}), 400
74
+
75
+ if __name__ == '__main__':
76
+ log_file = 'app.log'
77
+ logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
78
+ logger = logging.getLogger(__name__)
79
+ logger.info("Running the app...")
80
+ app.run()
81
+
config.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+
3
+ model_ckpt = "papluca/xlm-roberta-base-language-detection"
4
+ pipe = pipeline("text-classification", model=model_ckpt)
5
+
6
+ labels = {"ar" : "Arabic", "bg" : "Bulgarian", "de" : "German", "el" : "Modern Greek",
7
+ "en" : "English", "es" : "Spanish", "fr" : "French", "hi" : "Hindi", "it" : "Italian",
8
+ "ja" : "Japanese", "nl" : "Dutch", "pl" : "Polish", "pt" : "Portuguese", "ru" : "Russian",
9
+ "sw" : "Swahili", "th" : "Thai", "tr" : "Turkish", "ur" : "Urdu", "vi" : "Vietnamese", "zh" : "Chinese"}
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==1.13.1
2
+ WTForms==3.0.1
3
+ transformers==4.25.1
4
+ Flask==2.2.2