File size: 2,636 Bytes
2b392a5
 
6a95174
 
 
 
 
 
 
 
db6e2f8
2b392a5
0975142
e97e55a
9c814df
f3d38c9
0bf16f9
0975142
0bf16f9
3038e6a
0bf16f9
 
0975142
db6e2f8
a380c9e
0bf16f9
0975142
 
a380c9e
0bf16f9
0975142
a380c9e
0975142
 
a380c9e
0bf16f9
 
 
a380c9e
0bf16f9
 
a380c9e
0bf16f9
a380c9e
d22d434
0bf16f9
 
 
a380c9e
 
 
 
0bf16f9
a380c9e
0bf16f9
d22d434
a380c9e
db6e2f8
 
 
 
2b392a5
a380c9e
0bf16f9
db6e2f8
a380c9e
 
 
 
0bf16f9
db6e2f8
 
d22d434
a380c9e
 
 
0bf16f9
 
d22d434
a380c9e
 
 
 
0bf16f9
a380c9e
 
db6e2f8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from flask import Flask, jsonify, request
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from huggingface_hub import HfFolder
import os

access_token = os.getenv("HF_ACCESS_TOKEN")

# Authenticate with Hugging Face

HfFolder.save_token(access_token)

# Load the fine-tuned model and tokenizer
tokenizer_path = "gpt2"
small_model_path = "hunthinn/movie_title_gpt2_small"
medium_model_path = "hunthinn/movie_title_gpt2_medium"
distill_model_path = "hunthinn/movie_title_gpt2_distill"

tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
small_model = GPT2LMHeadModel.from_pretrained(small_model_path)
distill_model = GPT2LMHeadModel.from_pretrained(distill_model_path)
medium_model = GPT2LMHeadModel.from_pretrained(medium_model_path)

tokenizer.pad_token = tokenizer.eos_token


def infer_title_small(input):
    if input:
        input_text = "Q: " + input + " A:"
        input_ids = tokenizer.encode(input_text, return_tensors="pt")
        output = small_model.generate(input_ids, max_length=50, num_return_sequences=1)
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        response = response.split("A:")
        return response[-1]


def infer_title_medium(input):
    if input:
        input_text = "Q: " + input + " A:"
        input_ids = tokenizer.encode(input_text, return_tensors="pt")
        output = medium_model.generate(input_ids, max_length=50, num_return_sequences=1)
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        response = response.split("A:")
        return response[-1]


def infer_title_distill(input):
    if input:
        input_text = "Q: " + input + " A:"
        input_ids = tokenizer.encode(input_text, return_tensors="pt")
        output = distill_model.generate(
            input_ids, max_length=50, num_return_sequences=1
        )
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        response = response.split("A:")
        return response[-1]


app = Flask(__name__)


@app.route("/")
def endpoint():
    
    return jsonify({"output": "add small, medium, or distill to use different model"})


@app.route("/small/<input>")
def small_model_endpoint(input):

    output = infer_title_small(input)
    return jsonify({"output": output})


@app.route("/distill/<input>")
def distill_model_endpoint(input):

    output = infer_title_distill(input)
    return jsonify({"output": output})


@app.route("/medium/<input>")
def medium_model_endpoint(input):

    output = infer_title_medium(input)
    return jsonify({"output": output})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)