BeveledCube commited on
Commit
1bf6c59
1 Parent(s): 5292931

Fixed GPT2 and made the API POST and GET

Browse files
Files changed (3) hide show
  1. main.py +1 -1
  2. models/gpt2.py +3 -3
  3. templates/index.html +1 -1
main.py CHANGED
@@ -14,7 +14,7 @@ def read_root():
14
  def test_route():
15
  return "This is a test route."
16
 
17
- @app.route("/api", methods=["POST"])
18
  def receive_data():
19
  data = request.get_json()
20
  print("Prompt:", data["prompt"])
 
14
  def test_route():
15
  return "This is a test route."
16
 
17
+ @app.route("/api")
18
  def receive_data():
19
  data = request.get_json()
20
  print("Prompt:", data["prompt"])
models/gpt2.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
 
3
  # https://www.youtube.com/watch?v=irjYqV6EebU
4
 
@@ -8,8 +8,8 @@ def load():
8
  global model
9
  global tokenizer
10
 
11
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
 
14
  def generate(input_text):
15
  # Tokenize the input text
 
1
+ from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
2
 
3
  # https://www.youtube.com/watch?v=irjYqV6EebU
4
 
 
8
  global model
9
  global tokenizer
10
 
11
+ model = TFGPT2LMHeadModel.from_pretrained(model_name)
12
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
13
 
14
  def generate(input_text):
15
  # Tokenize the input text
templates/index.html CHANGED
@@ -32,7 +32,7 @@
32
 
33
  <body>
34
  <h1 class="text">Hello there!</h1>
35
- <span class="text">For the API use a POST request</span>
36
  </body>
37
 
38
  </html>
 
32
 
33
  <body>
34
  <h1 class="text">Hello there!</h1>
35
+ <span class="text">For the API use a GET request to /API</span>
36
  </body>
37
 
38
  </html>