BeveledCube committed on
Commit 207c16a · verified · 1 Parent(s): 638158f

Update main.py

Files changed (1)
  1. main.py +40 -20
main.py CHANGED
@@ -5,7 +5,7 @@ from fastapi import FastAPI
 
 import os
 
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
+from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer
 import torch
 
 app = FastAPI()
@@ -15,15 +15,14 @@ name = "microsoft/DialoGPT-small"
 # microsoft/DialoGPT-medium
 # microsoft/DialoGPT-large
 
-# PygmalionAI/pygmalion-350m
-# PygmalionAI/pygmalion-1.3b
-# PygmalionAI/pygmalion-6b
-
 # mistralai/Mixtral-8x7B-Instruct-v0.1
 
 # Load the Hugging Face GPT-2 model and tokenizer
-model = GPT2LMHeadModel.from_pretrained(name)
-tokenizer = GPT2Tokenizer.from_pretrained(name)
+model = AutoModelForCausalLM.from_pretrained(name)
+tokenizer = AutoTokenizer.from_pretrained(name)
+
+gpt2model = GPT2LMHeadModel.from_pretrained(name)
+gpt2tokenizer = GPT2Tokenizer.from_pretrained(name)
 
 class req(BaseModel):
     prompt: str
@@ -38,16 +37,37 @@ def read_root(data: req):
     print("Prompt:", data.prompt)
     print("Length:", data.length)
 
-    input_text = data.prompt
-
-    # Tokenize the input text
-    input_ids = tokenizer.encode(input_text, return_tensors="pt")
-
-    # Generate output using the model
-    output_ids = model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
-    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-    answer_data = { "answer": generated_text }
-    print("Answer:", generated_text)
-
-    return answer_data
+    if name == "microsoft/DialoGPT-small" or name == "microsoft/DialoGPT-medium" or name == "microsoft/DialoGPT-large":
+        # tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
+        # model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
+
+        step = 1
+
+        # encode the new user input, add the eos_token and return a tensor in Pytorch
+        new_user_input_ids = tokenizer.encode(data.prompt + tokenizer.eos_token, return_tensors='pt')
+
+        # append the new user input tokens to the chat history
+        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
+
+        # generated a response while limiting the total chat history to 1000 tokens,
+        chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
+
+        generated_text = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+        answer_data = { "answer": generated_text }
+        print("Answer:", generated_text)
+
+        return answer_data
+    else:
+        input_text = data.prompt
+
+        # Tokenize the input text
+        input_ids = gpt2tokenizer.encode(input_text, return_tensors="pt")
+
+        # Generate output using the model
+        output_ids = model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
+        generated_text = gpt2tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        answer_data = { "answer": generated_text }
+        print("Answer:", generated_text)
+
+        return answer_data
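
Note on the DialoGPT branch as committed: `step = 1` means the `torch.cat([chat_history_ids, ...])` path runs on the very first request, but `chat_history_ids` has not been assigned yet at that point, so the handler would raise a NameError. Below is a minimal self-contained sketch of how this branch could be made to run, keeping chat history as module-level state initialized to None. The route decorator, the `length` default, and the `chat_history_ids` global are assumptions added for illustration; they are not shown in this diff.

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

app = FastAPI()

name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

# Assumption: keep chat history in a module-level variable; the commit itself
# never initializes chat_history_ids before reading it.
chat_history_ids = None

class Req(BaseModel):
    prompt: str
    length: int = 100  # default is an assumption; the diff does not show it

@app.post("/")  # route and HTTP method are assumptions; the decorator is not in the diff
def read_root(data: Req):
    global chat_history_ids

    # encode the new user input and append the end-of-sequence token
    new_user_input_ids = tokenizer.encode(data.prompt + tokenizer.eos_token, return_tensors="pt")

    # append to the running chat history only if a previous turn exists
    if chat_history_ids is not None:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # generate a response while limiting the total chat history to 1000 tokens
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    # decode only the newly generated tokens (everything after the input)
    generated_text = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    return {"answer": generated_text}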
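For completeness, a hedged usage sketch for the endpoint. The diff does not show the route decorator, so the path, HTTP method, and port here are assumptions (a default uvicorn setup serving the handler at "/").

import requests

# Assumes `uvicorn main:app` is running locally on the default port 8000
# and that read_root is mounted at the root path with POST.
resp = requests.post("http://localhost:8000/", json={"prompt": "Hello there!", "length": 60})
print(resp.json()["answer"])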