| # Template for API integration script for {{phase_name}} (using Flask example) | |
| from flask import Flask, request, jsonify | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| import torch # Example PyTorch | |
app = Flask(__name__)

# --- Model and Tokenizer Loading ---
# Both identifiers below are template placeholders; adjust them per deployment.
model_name = "models/fine_tuned_model"  # Replace with your actual model path
# Replace with the tokenizer used for training, likely the base model tokenizer.
# NOTE(review): if the tokenizer was saved alongside the fine-tuned weights,
# loading it from `model_name` avoids a vocabulary mismatch — confirm.
tokenizer_name = "bert-base-uncased"

# Loading is best-effort: on any failure both globals are set to None so the
# app still starts; request handlers are expected to check for None.
try:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # Evaluation mode (affects layers such as dropout) for inference.
    model.eval()
    # Announce success only once every setup step above has completed.
    print("Model and tokenizer loaded successfully.")
except Exception:
    # logger.exception records the full traceback, which a plain print of the
    # exception message would lose.
    app.logger.exception("Error loading model or tokenizer")
    tokenizer = None
    model = None
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the ``text`` field of a JSON request body.

    Expects a POST body such as ``{"text": "some sentence"}``.

    Returns:
        200 with ``{"prediction": <label>, "class_id": <int>}`` on success.
        400 with ``{"error": ...}`` when the body is not a JSON object or
        ``text`` is missing, empty, or not a string.
        500 with ``{"error": ...}`` when the model/tokenizer is not loaded or
        inference fails.
    """
    if tokenizer is None or model is None:
        return jsonify({"error": "Model or tokenizer not loaded."}), 500

    # silent=True yields None (instead of raising) for a missing or malformed
    # JSON body, so bad client input is reported as 400 rather than 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    text = data.get('text')
    if not text:
        return jsonify({"error": "No text input provided."}), 400
    # A list would be tokenized as a batch and a number would fail inside the
    # tokenizer; reject both up front as client errors.
    if not isinstance(text, str):
        return jsonify({"error": "'text' must be a string."}), 400

    # Only the inference step is guarded; failures here are server errors.
    try:
        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")  # Tokenize input text
        with torch.no_grad():  # Inference only: no gradient tracking
            logits = model(**inputs).logits
        predicted_class_id = torch.argmax(logits, dim=-1).item()  # Get predicted class
    except Exception:
        # logger.exception keeps the traceback for debugging.
        app.logger.exception("Prediction error")
        return jsonify({"error": "Error during prediction."}), 500

    # --- Map class ID to label (if applicable) ---
    # Example for binary classification (class 0 and 1).
    # NOTE(review): model.config.id2label may hold the real label names for the
    # fine-tuned model — consider using it instead of this hard-coded list.
    labels = ["Negative", "Positive"]  # Replace with your actual labels
    if predicted_class_id < len(labels):
        predicted_label = labels[predicted_class_id]
    else:
        predicted_label = f"Class {predicted_class_id}"
    return jsonify({"prediction": predicted_label, "class_id": predicted_class_id})
@app.route('/', methods=['GET'])
def health_check():
    """Report that the HTTP server is up and answering requests.

    This handler does not inspect the model or tokenizer, so it signals
    server liveness only, not model readiness.
    """
    body = {"status": "API is healthy"}
    status_code = 200
    return jsonify(body), status_code
if __name__ == '__main__':
    # Start Flask's built-in server on every network interface (0.0.0.0),
    # port 5000, with debug mode off.
    # NOTE(review): for production traffic a dedicated WSGI server is the
    # usual choice over app.run — confirm the deployment target.
    app.run(host='0.0.0.0', port=5000, debug=False)