zaidmehdi commited on
Commit
7ecfa8c
·
1 Parent(s): c1192ba

replace flask api with gradio app

Browse files
Files changed (1) hide show
  1. src/main.py +13 -21
src/main.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
2
  import pickle
3
 
4
- from flask import Flask, request, jsonify
5
  from transformers import AutoModel, AutoTokenizer
6
 
7
  from .utils import extract_hidden_state
8
 
9
- app = Flask(__name__)
10
 
11
  models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
12
  model_file = os.path.join(models_dir, 'logistic_regression.pkl')
@@ -21,26 +20,19 @@ model_name = "moussaKam/AraBART"
21
  tokenizer = AutoTokenizer.from_pretrained(model_name)
22
  language_model = AutoModel.from_pretrained(model_name)
23
 
 
 
 
 
 
24
 
25
- @app.route("/classify", methods=["POST"])
26
- def classify_arabic_dialect():
27
- try:
28
- data = request.json
29
- text = data.get("text")
30
- if not text:
31
- return jsonify({"error": "No text has been received"}), 400
32
-
33
- text_embeddings = extract_hidden_state(text, tokenizer, language_model)
34
- predicted_class = model.predict(text_embeddings)[0]
35
-
36
- return jsonify({"class": predicted_class}), 200
37
- except Exception as e:
38
- return jsonify({"error": str(e)}), 500
39
-
40
-
41
- def main():
42
- app.run(host="0.0.0.0", port=5000)
43
 
44
 
45
  if __name__ == "__main__":
46
- main()
 
 
1
  import os
2
  import pickle
3
 
4
+ import gradio as gr
5
  from transformers import AutoModel, AutoTokenizer
6
 
7
  from .utils import extract_hidden_state
8
 
 
9
 
10
  models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
11
  model_file = os.path.join(models_dir, 'logistic_regression.pkl')
 
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  language_model = AutoModel.from_pretrained(model_name)
22
 
23
+ def classify_arabic_dialect(text):
24
+ text_embeddings = extract_hidden_state(text, tokenizer, language_model)
25
+ predicted_class = model.predict(text_embeddings)[0]
26
+
27
+ return predicted_class
28
 
29
+ demo = gr.Interface(
30
+ fn=classify_arabic_dialect,
31
+ inputs=["text"],
32
+ outputs=["text"],
33
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  if __name__ == "__main__":
37
+ demo.launch()
38
+