hunthinn commited on
Commit
d22d434
1 Parent(s): 9c814df

add distill endpoint

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -13,7 +13,7 @@ HfFolder.save_token(access_token)
13
  tokenizer_path = "gpt2"
14
  small_model_path = "hunthinn/movie_title_gpt2_small"
15
  medium_model_path = "hunthinn/movie_title_gpt2_medium"
16
- # distill_model_path = "hunthinn/movie_title_gpt2"
17
 
18
  tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
19
  small_model = GPT2LMHeadModel.from_pretrained(small_model_path)
@@ -42,7 +42,7 @@ def infer_title_medium(input):
42
  response = response.split("A:")
43
  return response[-1]
44
 
45
- """
46
  def infer_title_distill(input):
47
  if input:
48
  input_text = "Q: " + input + " A:"
@@ -53,7 +53,7 @@ def infer_title_distill(input):
53
  response = tokenizer.decode(output[0], skip_special_tokens=True)
54
  response = response.split("A:")
55
  return response[-1]
56
- """
57
 
58
  app = Flask(__name__)
59
 
@@ -70,13 +70,13 @@ def small_model_endpoint(input):
70
  output = infer_title_small(input)
71
  return jsonify({"output": output})
72
 
73
- '''
74
  @app.route("/distill/<input>")
75
  def distill_model_endpoint(input):
76
 
77
  output = infer_title_distill(input)
78
  return jsonify({"output": output})
79
- '''
80
 
81
  @app.route("/medium/<input>")
82
  def medium_model_endpoint(input):
 
13
  tokenizer_path = "gpt2"
14
  small_model_path = "hunthinn/movie_title_gpt2_small"
15
  medium_model_path = "hunthinn/movie_title_gpt2_medium"
16
+ distill_model_path = "hunthinn/movie_title_gpt2"
17
 
18
  tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
19
  small_model = GPT2LMHeadModel.from_pretrained(small_model_path)
 
42
  response = response.split("A:")
43
  return response[-1]
44
 
45
+
46
  def infer_title_distill(input):
47
  if input:
48
  input_text = "Q: " + input + " A:"
 
53
  response = tokenizer.decode(output[0], skip_special_tokens=True)
54
  response = response.split("A:")
55
  return response[-1]
56
+
57
 
58
  app = Flask(__name__)
59
 
 
70
  output = infer_title_small(input)
71
  return jsonify({"output": output})
72
 
73
+
74
  @app.route("/distill/<input>")
75
  def distill_model_endpoint(input):
76
 
77
  output = infer_title_distill(input)
78
  return jsonify({"output": output})
79
+
80
 
81
  @app.route("/medium/<input>")
82
  def medium_model_endpoint(input):