Spaces:
Runtime error
Runtime error
add distill endpoint
Browse files
app.py
CHANGED
@@ -13,7 +13,7 @@ HfFolder.save_token(access_token)
|
|
13 |
tokenizer_path = "gpt2"
|
14 |
small_model_path = "hunthinn/movie_title_gpt2_small"
|
15 |
medium_model_path = "hunthinn/movie_title_gpt2_medium"
|
16 |
-
|
17 |
|
18 |
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
|
19 |
small_model = GPT2LMHeadModel.from_pretrained(small_model_path)
|
@@ -42,7 +42,7 @@ def infer_title_medium(input):
|
|
42 |
response = response.split("A:")
|
43 |
return response[-1]
|
44 |
|
45 |
-
|
46 |
def infer_title_distill(input):
|
47 |
if input:
|
48 |
input_text = "Q: " + input + " A:"
|
@@ -53,7 +53,7 @@ def infer_title_distill(input):
|
|
53 |
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
54 |
response = response.split("A:")
|
55 |
return response[-1]
|
56 |
-
|
57 |
|
58 |
app = Flask(__name__)
|
59 |
|
@@ -70,13 +70,13 @@ def small_model_endpoint(input):
|
|
70 |
output = infer_title_small(input)
|
71 |
return jsonify({"output": output})
|
72 |
|
73 |
-
|
74 |
@app.route("/distill/<input>")
|
75 |
def distill_model_endpoint(input):
|
76 |
|
77 |
output = infer_title_distill(input)
|
78 |
return jsonify({"output": output})
|
79 |
-
|
80 |
|
81 |
@app.route("/medium/<input>")
|
82 |
def medium_model_endpoint(input):
|
|
|
13 |
tokenizer_path = "gpt2"
|
14 |
small_model_path = "hunthinn/movie_title_gpt2_small"
|
15 |
medium_model_path = "hunthinn/movie_title_gpt2_medium"
|
16 |
+
distill_model_path = "hunthinn/movie_title_gpt2"
|
17 |
|
18 |
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
|
19 |
small_model = GPT2LMHeadModel.from_pretrained(small_model_path)
|
|
|
42 |
response = response.split("A:")
|
43 |
return response[-1]
|
44 |
|
45 |
+
|
46 |
def infer_title_distill(input):
|
47 |
if input:
|
48 |
input_text = "Q: " + input + " A:"
|
|
|
53 |
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
54 |
response = response.split("A:")
|
55 |
return response[-1]
|
56 |
+
|
57 |
|
58 |
app = Flask(__name__)
|
59 |
|
|
|
70 |
output = infer_title_small(input)
|
71 |
return jsonify({"output": output})
|
72 |
|
73 |
+
|
74 |
@app.route("/distill/<input>")
|
75 |
def distill_model_endpoint(input):
|
76 |
|
77 |
output = infer_title_distill(input)
|
78 |
return jsonify({"output": output})
|
79 |
+
|
80 |
|
81 |
@app.route("/medium/<input>")
|
82 |
def medium_model_endpoint(input):
|