zaidmehdi commited on
Commit
6893bb5
1 Parent(s): f838a8b

taking value of predicted_class from list output of log reg prediction

Browse files
Files changed (1) hide show
  1. src/main.py +3 -3
src/main.py CHANGED
@@ -22,10 +22,10 @@ def classify_arabic_dialect():
22
  text = data.get("text")
23
  if not text:
24
  return jsonify({"error": "No text has been received"}), 400
25
-
26
  text_embeddings = extract_hidden_state(text, tokenizer, language_model)
27
- predicted_class = model.predict(text_embeddings)
28
-
29
  return jsonify({"class": predicted_class}), 200
30
  except Exception as e:
31
  return jsonify({"error": str(e)}), 500
 
22
  text = data.get("text")
23
  if not text:
24
  return jsonify({"error": "No text has been received"}), 400
25
+
26
  text_embeddings = extract_hidden_state(text, tokenizer, language_model)
27
+ predicted_class = model.predict(text_embeddings)[0]
28
+
29
  return jsonify({"class": predicted_class}), 200
30
  except Exception as e:
31
  return jsonify({"error": str(e)}), 500