gojiteji committed on
Commit a84fac5 (1 parent: 2ffae82)

Create app.py

Files changed (1): app.py (+104, −0)
app.py ADDED
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

modelname = "EleutherAI/gpt-neo-2.7B"
config = AutoConfig.from_pretrained(modelname)
tokenizer = AutoTokenizer.from_pretrained(modelname)
model = AutoModelForCausalLM.from_pretrained(modelname, config=config).to("cuda")


def botsay(user_input):
    prompt = "This is a conversation between Human and AI bot. AI's name is ThatGPT."
    new_token_id = None
    gen_tokens = ""
    new_token = ""
    j = 6                      # search the top j+1 logits for a usable "that"
    limit = 128                # hard cap on prompt length, in tokens
    cont = True
    last_appended = False      # True once a closing " that" has been forced in
    cnt = 0
    disable_repeat_length = 5  # length of the suffix checked for loops
    disable_repeat_count = 2   # repeats of that suffix tolerated before resampling
    while cont:
        cnt += 1
        input_ids = tokenizer(prompt + user_input + "\nAI:" + gen_tokens,
                              return_tensors="pt").input_ids
        # input_ids has shape (1, seq_len); len() would give the batch size, 1.
        length = input_ids.shape[1]
        if length > limit:
            gen_tokens = "⚠️sorry length limit. please reload the browser."
            return gen_tokens
        outs = model(input_ids=input_ids.to("cuda"))
        topk = torch.topk(outs.logits.squeeze()[-1, :], k=j + 1).indices
        # Choose which surface form of "that" to steer toward: in the GPT-2 BPE
        # vocabulary GPT-Neo uses, 326 is " that" and 5562 is "that" with no
        # leading space; -1 disables steering for this step.
        if new_token == "that":
            that_id = 326
        elif new_token == " that":
            that_id = -1
        elif new_token[-1:] == " ":
            that_id = 5562
        else:
            that_id = 326

        if "thatGPT" in gen_tokens[-12:]:
            that_id = -1
        if last_appended:
            that_id = -1
        if that_id in topk:
            new_token_id = that_id
        else:
            new_token_id = torch.argmax(outs.logits.squeeze()[-1, :])
        new_token = tokenizer.decode(new_token_id)
        prev_tokens = gen_tokens
        gen_tokens += new_token
        # Anti-loop guard: if the last few characters keep repeating, back the
        # token out and resample uniformly from the top candidates instead.
        if (cnt > 10) and (disable_repeat_count < gen_tokens.count(gen_tokens[-disable_repeat_length:])):
            gen_tokens = prev_tokens
            new_token = tokenizer.decode(topk[torch.randint(5, (1, 1)).item()])
            gen_tokens += new_token

        # 50256 is <|endoftext|> and 198 is "\n" in this vocabulary.
        if new_token_id == 50256 or new_token_id == 198 or new_token == "<|endoftext|>":
            if "that" not in gen_tokens:
                # Never let a reply end without a "that": strip the terminator
                # and force one in, then stop steering on later steps.
                gen_tokens = gen_tokens.replace("\n", "").replace(".", "")
                gen_tokens += " that"
                last_appended = True
            else:
                cont = False
    return gen_tokens.replace("<br>", "").replace("AI:", "").replace("\xa0", "")


import gradio as gr


def add_text(history, text):
    # Append the user's message with an empty bot slot, and clear the textbox.
    history = history + [(text, None)]
    return history, ""


def bot(history):
    # Flatten the chat history into the "Human:/AI:" transcript botsay expects.
    serial_history = ""
    for h in history:
        serial_history += "\nHuman:" + h[0]
        if h[1] is None:
            break
        serial_history += "\nAI:" + h[1].replace("<br>", "")

    response = botsay(serial_history)
    history[-1][1] = response
    return history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)

    with gr.Row():
        with gr.Column(scale=0.85):
            txt = gr.Textbox(
                show_label=False,
                placeholder="input text and press enter",
            ).style(container=False)

    txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
        bot, chatbot, chatbot
    )

demo.launch(debug=True, share=True)
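
For reference, a minimal smoke test of botsay outside the Gradio UI (a sketch, assuming the module-level model load above succeeded on a CUDA GPU; it uses the same "\nHuman:" framing that bot() builds from the chat history):

if __name__ == "__main__":
    # Hand-build the transcript framing that bot() would normally produce
    # from the Chatbot history before calling botsay().
    reply = botsay("\nHuman:Hello, who are you?")
    print(reply)

Note that running this as a Hugging Face Space would also require a requirements.txt next to app.py listing at least torch, transformers, and gradio.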