"""Flask service exposing three fine-tuned GPT-2 movie-title models.

Endpoints:
    /                     -> usage hint
    /small/<user_input>   -> title from the small model
    /medium/<user_input>  -> title from the medium model
    /distill/<user_input> -> title from the distilled model
"""

import os

from flask import Flask, jsonify, request
from huggingface_hub import HfFolder
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Authenticate with the Hugging Face Hub so the fine-tuned model repos can
# be downloaded. Guarded: save_token(None) raises when the env var is unset.
access_token = os.getenv("HF_ACCESS_TOKEN")
if access_token:
    HfFolder.save_token(access_token)

# One shared GPT-2 tokenizer; three fine-tuned model variants.
tokenizer_path = "gpt2"
small_model_path = "hunthinn/movie_title_gpt2_small"
medium_model_path = "hunthinn/movie_title_gpt2_medium"
distill_model_path = "hunthinn/movie_title_gpt2_distill"

tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
small_model = GPT2LMHeadModel.from_pretrained(small_model_path)
distill_model = GPT2LMHeadModel.from_pretrained(distill_model_path)
medium_model = GPT2LMHeadModel.from_pretrained(medium_model_path)

# GPT-2 ships without a pad token; reuse EOS so generate() can pad.
tokenizer.pad_token = tokenizer.eos_token


def _infer_title(model, text):
    """Generate a movie title for *text* using *model*.

    Builds the "Q: ... A:" prompt the models were fine-tuned on and returns
    the text after the final "A:" marker. Returns None when *text* is falsy
    (preserves the original implicit-None behavior of the per-model helpers).
    """
    if not text:
        return None
    prompt = "Q: " + text + " A:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(
        input_ids,
        max_length=50,
        num_return_sequences=1,
        # Explicit pad token silences the per-call GPT-2 warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response.split("A:")[-1]


def infer_title_small(input):
    """Title from the small model; see _infer_title."""
    return _infer_title(small_model, input)


def infer_title_medium(input):
    """Title from the medium model; see _infer_title."""
    return _infer_title(medium_model, input)


def infer_title_distill(input):
    """Title from the distilled model; see _infer_title."""
    return _infer_title(distill_model, input)


app = Flask(__name__)


@app.route("/")
def endpoint():
    """Usage hint for the service root."""
    return jsonify({"output": "add small, medium, or distill to use different model"})


# NOTE(fix): the original routes had no URL variable, so Flask raised a
# TypeError on every request (the view functions require an argument).
# The <user_input> converter binds the path segment to the parameter.
@app.route("/small/<user_input>")
def small_model_endpoint(user_input):
    """Serve a title generated by the small model."""
    output = infer_title_small(user_input)
    return jsonify({"output": output})


@app.route("/distill/<user_input>")
def distill_model_endpoint(user_input):
    """Serve a title generated by the distilled model."""
    output = infer_title_distill(user_input)
    return jsonify({"output": output})


@app.route("/medium/<user_input>")
def medium_model_endpoint(user_input):
    """Serve a title generated by the medium model."""
    output = infer_title_medium(user_input)
    return jsonify({"output": output})


if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional HF Spaces port.
    app.run(host="0.0.0.0", port=7860)