# Commit: "add distill endpoint" (f3d38c9) by hunthinn
from flask import Flask, jsonify, request
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from huggingface_hub import HfFolder
import os
# Hugging Face access token for pulling the fine-tuned models; read from the
# environment so it is never hard-coded in source.
access_token = os.getenv("HF_ACCESS_TOKEN")
# Authenticate with Hugging Face only when a token is actually set —
# saving a None token would persist a useless credential (or fail outright).
if access_token:
    HfFolder.save_token(access_token)

# Load the shared tokenizer and the three fine-tuned model variants once at
# import time so every request reuses the same in-memory objects.
tokenizer_path = "gpt2"
small_model_path = "hunthinn/movie_title_gpt2_small"
medium_model_path = "hunthinn/movie_title_gpt2_medium"
distill_model_path = "hunthinn/movie_title_gpt2_distill"

tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
small_model = GPT2LMHeadModel.from_pretrained(small_model_path)
distill_model = GPT2LMHeadModel.from_pretrained(distill_model_path)
medium_model = GPT2LMHeadModel.from_pretrained(medium_model_path)
# GPT-2 ships without a pad token; reuse EOS so padding during generation works.
tokenizer.pad_token = tokenizer.eos_token
def _generate_title(model, prompt):
    """Run one Q/A-style generation pass and return the text after the last "A:".

    Shared implementation for the small/medium/distill endpoints. Returns ""
    for a falsy prompt instead of the implicit None the original duplicated
    functions produced (which callers would jsonify as null).
    """
    if not prompt:
        return ""
    input_text = "Q: " + prompt + " A:"
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    output = model.generate(input_ids, max_length=50, num_return_sequences=1)
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the final "A:".
    return response.split("A:")[-1]


def infer_title_small(input):
    """Generate a movie title with the fine-tuned small GPT-2 model."""
    # NOTE: `input` shadows the builtin, but the name is kept because Flask
    # binds it by keyword to the `<input>` route variable.
    return _generate_title(small_model, input)


def infer_title_medium(input):
    """Generate a movie title with the fine-tuned medium GPT-2 model."""
    return _generate_title(medium_model, input)


def infer_title_distill(input):
    """Generate a movie title with the distilled GPT-2 model."""
    return _generate_title(distill_model, input)
# Flask application serving the title-generation endpoints defined below.
app = Flask(__name__)
@app.route("/")
def endpoint():
return jsonify({"output": "add small, medium, or distill to use different model"})
@app.route("/small/<input>")
def small_model_endpoint(input):
output = infer_title_small(input)
return jsonify({"output": output})
@app.route("/distill/<input>")
def distill_model_endpoint(input):
output = infer_title_distill(input)
return jsonify({"output": output})
@app.route("/medium/<input>")
def medium_model_endpoint(input):
output = infer_title_medium(input)
return jsonify({"output": output})
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860)