hunthinn committed on
Commit
0bf16f9
1 Parent(s): 6a95174

add more endpoints

Files changed (1)
  1. app.py +44 -7
app.py CHANGED
@@ -11,32 +11,69 @@ HfFolder.save_token(access_token)
 
 # Load the fine-tuned model and tokenizer
 tokenizer_path = "gpt2"
-model_path = 'hunthinn/movie_title_gpt2'
+small_model_path = 'hunthinn/movie_title_gpt2'
+distill_model_path = 'hunthinn/movie_title_gpt2'
+medium_model_path = 'hunthinn/movie_title_gpt2'
+
 tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
-model = GPT2LMHeadModel.from_pretrained(model_path)
+small_model = GPT2LMHeadModel.from_pretrained(small_model_path)
+distill_model = GPT2LMHeadModel.from_pretrained(distill_model_path)
+medium_model = GPT2LMHeadModel.from_pretrained(medium_model_path)
+
 tokenizer.pad_token = tokenizer.eos_token
 
-def infer_title(input):
+def infer_title_small(input):
     if input:
         input_text = "Q: " + input + " A:"
         input_ids = tokenizer.encode(input_text, return_tensors='pt')
-        output = model.generate(input_ids, max_length=50, num_return_sequences=1)
+        output = small_model.generate(input_ids, max_length=50, num_return_sequences=1)
         response = tokenizer.decode(output[0], skip_special_tokens=True)
         response = response.split('A:')
         return response[-1]
 
-
+def infer_title_medium(input):
+    if input:
+        input_text = "Q: " + input + " A:"
+        input_ids = tokenizer.encode(input_text, return_tensors='pt')
+        output = medium_model.generate(input_ids, max_length=50, num_return_sequences=1)
+        response = tokenizer.decode(output[0], skip_special_tokens=True)
+        response = response.split('A:')
+        return response[-1]
+
+def infer_title_distill(input):
+    if input:
+        input_text = "Q: " + input + " A:"
+        input_ids = tokenizer.encode(input_text, return_tensors='pt')
+        output = distill_model.generate(input_ids, max_length=50, num_return_sequences=1)
+        response = tokenizer.decode(output[0], skip_special_tokens=True)
+        response = response.split('A:')
+        return response[-1]
+
 app = Flask(__name__)
 
 
 @app.route("/")
 def endpoint():
-    input = request.args.get("input")
 
-    output = infer_title(input)
+    return jsonify({"output": "add small, medium, or distill to use different model"})
 
+@app.route("/small")
+def small_model_endpoint():
+    input = request.args.get("input")
+    output = infer_title_small(input)
     return jsonify({"output": output})
 
+@app.route("/distill")
+def distill_model_endpoint():
+    input = request.args.get("input")
+    output = infer_title_distill(input)
+    return jsonify({"output": output})
+
+@app.route("/medium")
+def medium_model_endpoint():
+    input = request.args.get("input")
+    output = infer_title_medium(input)
+    return jsonify({"output": output})
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860)
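
A quick way to exercise the new routes, as a minimal sketch: it assumes the app above is running locally on port 7860 (per app.run) and that the requests package is installed; the route names and the input query parameter come from the diff itself, while the prompt text and BASE_URL are placeholders.

import requests

# Assumption: the Flask app above is running locally; for a deployed
# Space, replace BASE_URL with the Space's URL.
BASE_URL = "http://localhost:7860"

# The root endpoint now returns a usage hint instead of an inference result.
print(requests.get(f"{BASE_URL}/").json())

# Each model-specific endpoint reads the prompt from the `input` query parameter.
for route in ("small", "medium", "distill"):
    resp = requests.get(f"{BASE_URL}/{route}", params={"input": "a movie about a shark"})
    print(route, "->", resp.json()["output"])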