moin1234 committed
Commit 22bd20f · 1 Parent(s): 4405161

Create app.py

Files changed (1)
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
+ import gradio as gr
+ from datasets import load_dataset
+ from transformers import (
+     DataCollatorForLanguageModeling,
+     GPT2LMHeadModel,
+     GPT2Tokenizer,
+     Trainer,
+     TrainingArguments,
+ )
+
+ # Load the base conversational model and its tokenizer once.
+ tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium")
+ model = GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium")
+ # GPT-2 tokenizers ship without a pad token; reuse EOS so batched training works.
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Fine-tuning corpus: a plain-text file, one example per line.
+ dataset = load_dataset("text", data_files="SAMANXA.txt")
+
+ def tokenize_function(examples):
+     return tokenizer(examples["text"], truncation=True)
+
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+ training_args = TrainingArguments(
+     output_dir="./results",
+     num_train_epochs=3,
+     per_device_train_batch_size=16,
+     per_device_eval_batch_size=64,
+     warmup_steps=500,
+     weight_decay=0.01,
+     logging_dir="./logs",
+ )
+
+ # Pads each batch and copies input_ids into labels for causal-LM training.
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_datasets["train"],
+     data_collator=data_collator,
+ )
+
+ trainer.train()
+
+ def generate_response(message):
+     # Hard-coded greeting for the bot's trigger phrase.
+     if message == "hello SAMANGPT":
+         return "MOOOOIIIIN"
+
+     inputs = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
+     outputs = model.generate(inputs, max_length=1000, pad_token_id=tokenizer.eos_token_id)
+     # Decode only the newly generated tokens, skipping the echoed prompt.
+     response = tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)
+     return response
+
+ iface = gr.Interface(fn=generate_response, inputs="text", outputs="text")
+ iface.launch()
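
Once the app is running, the handler can be smoke-tested from another process with the Gradio client. A minimal sketch, not part of the commit, assuming the default local URL printed by iface.launch() and Gradio's default /predict endpoint for a single-function Interface:

from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # URL printed by iface.launch(); illustrative
result = client.predict("hello SAMANGPT", api_name="/predict")
print(result)  # expected: "MOOOOIIIIN"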