123ABC123456 commited on
Commit
13a8ecd
1 Parent(s): c5529b3

src/chatbots/gptjbot.py

Browse files

from transformers import GPTJForCausalLM, AutoTokenizer
import os
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class DialoGPT:
def __init__(
self,
model_name: str = "EleutherAI/gpt-j-6B",
local_path="./models/gpt-j-6B",
):
if not os.path.exists(local_path):
GPTJForCausalLM.from_pretrained(model_name).save_pretrained(
local_path,
revision="float16",
torch_dtype=torch.float16,
)
AutoTokenizer.from_pretrained(model_name).save_pretrained(local_path)

self.model = GPTJForCausalLM.from_pretrained(
local_path,
revision="float16",
torch_dtype=torch.float16,
)
self.tokenizer = AutoTokenizer.from_pretrained(local_path)

def __call__(self, inputs: str) -> str:
input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(device)
generated_ids = self.model.to(device).generate(
input_ids, do_sample=True, temperature=0.9, max_length=200
)
generated_text = self.tokenizer.decode(generated_ids[0])

return generated_text

def run(self):
while True:
user_input = input("User: ")
print("Bot:", self(user_input))


if __name__ == "__main__":
bot = DialoGPT()
bot.run()

Files changed (1) hide show
  1. src/chatbots/dialogpt.py +33 -0
src/chatbots/dialogpt.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from:
3
+ https://www.machinecurve.com/index.php/2021/03/16/easy-chatbot-with-dialogpt-machine-learning-and-huggingface-transformers/
4
+ """
5
+
6
+
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ import os
9
+
10
+
11
+ class DialoGPT:
12
+ def __init__(
13
+ self,
14
+ model_name: str ='microsoft/DialoGPT-large',
15
+ ):
16
+ if not os.path.exists('./models/dialogpt'):
17
+ AutoModelForCausalLM.from_pretrained(model_name).save_pretrained('./models/dialogpt')
18
+ AutoTokenizer.from_pretrained(model_name).save_pretrained('./models/dialogpt')
19
+
20
+ self.model = AutoModelForCausalLM.from_pretrained('./models/dialogpt')
21
+ self.tokenizer = AutoTokenizer.from_pretrained('./models/dialogpt')
22
+
23
+ def __call__(self, inputs: str) -> str:
24
+ inputs_tokenized = self.tokenizer.encode(inputs+ self.tokenizer.eos_token, return_tensors='pt')
25
+ reply_ids = self.model.generate(inputs_tokenized, max_length=1250, pad_token_id=self.tokenizer.eos_token_id)
26
+ reply = self.tokenizer.decode(reply_ids[:, inputs_tokenized.shape[-1]:][0], skip_special_tokens=True)
27
+
28
+ return reply
29
+
30
+ def run(self):
31
+ while True:
32
+ user_input = input("User: ")
33
+ print("Bot:", self(user_input))