from flask import Flask, jsonify, request
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the fine-tuned model and tokenizer
tokenizer_path = "gpt2"
model_path = "hunthinn/movie_title_gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
model = GPT2LMHeadModel.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token


def infer_title(prompt):
    """Generate a movie title by prompting the model in its "Q: ... A:" format."""
    if not prompt:
        return ""
    input_text = "Q: " + prompt + " A:"
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    output = model.generate(input_ids, max_length=50, num_return_sequences=1)
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text generated after the "A:" marker
    return response.split("A:")[-1]


app = Flask(__name__)


@app.route("/")
def endpoint():
    prompt = request.args.get("input")
    output = infer_title(prompt)
    return jsonify({"output": output})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
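
# Example request, assuming the server is running and reachable on localhost
# at the port configured in app.run above (the prompt value is a placeholder):
#   curl "http://localhost:7860/?input=YOUR_PROMPT"
# The endpoint responds with JSON of the form {"output": "<generated title>"}.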