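# ThatGPT: a Gradio chat demo built on GPT-2 Large in which every generated reply
# is steered to contain the word "that".
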
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# Load GPT-2 Large and its tokenizer; replies are generated greedily, one token at a time.
model_name = "gpt2-large"
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
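
# Note: inference runs on the CPU as written; moving the model and the input_ids
# built in botsay() to a GPU would make each reply noticeably faster.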


def botsay(user_input):
    """Generate one reply token by token, steering the output so it contains "that"."""
    prompt = "This is a conversation between Human and AI bot. AI's name is ThatGPT."
    gen_tokens = ""
    new_token = ""
    new_token_id = None
    num_candidates = 6         # how many top tokens to consider when forcing "that"
    limit = 128                # maximum prompt length in tokens
    cont = True
    last_appended = False      # True once " that" has been appended manually
    cnt = 0
    disable_repeat_length = 5  # length of the suffix checked for repetition
    disable_repeat_count = 2   # how many occurrences of that suffix are tolerated

    while cont:
        cnt += 1
        input_ids = tokenizer(prompt + user_input + "\nAI:" + gen_tokens,
                              return_tensors="pt").input_ids
        # input_ids has shape (1, seq_len); check the sequence length, not the batch dimension.
        if input_ids.shape[1] > limit:
            return "⚠️sorry length limit. please reload the browser."

        with torch.no_grad():
            outs = model(input_ids=input_ids)
        topk = torch.topk(outs.logits.squeeze()[-1, :], k=num_candidates + 1).indices

        # Choose which GPT-2 BPE id for the word to look for next:
        # 326 is " that" (leading space), 5562 is "that" (no leading space).
        if new_token == "that":
            that_id = 326
        elif new_token == " that":
            that_id = -1       # "that" was just emitted; don't force it again immediately
        elif new_token[-1:] == " ":
            that_id = 5562
        else:
            that_id = 326

        # Don't force "that" right after the bot's own name or after a manual append.
        if "thatGPT" in gen_tokens[-12:]:
            that_id = -1
        if last_appended:
            that_id = -1

        # Prefer "that" whenever it appears among the top candidates; otherwise take the argmax.
        if that_id in topk:
            new_token_id = that_id
        else:
            new_token_id = torch.argmax(outs.logits.squeeze()[-1, :])
        new_token = tokenizer.decode(new_token_id)

        prev_tokens = gen_tokens
        gen_tokens += new_token

        # Crude repetition guard: if the last few characters keep repeating, back off
        # and pick a random token from the top candidates instead.
        if cnt > 10 and disable_repeat_count < gen_tokens.count(gen_tokens[-disable_repeat_length:]):
            gen_tokens = prev_tokens
            new_token = tokenizer.decode(topk[torch.randint(5, (1, 1)).item()])
            gen_tokens += new_token

        # 50256 is <|endoftext|> and 198 is "\n": the model wants to stop.
        if new_token_id == 50256 or new_token_id == 198 or new_token == "<|endoftext|>":
            if "that" not in gen_tokens:
                # Force the gimmick: strip newlines and periods, append " that",
                # and keep generating so the model can close the sentence itself.
                gen_tokens = gen_tokens.replace("\n", "").replace(".", "")
                gen_tokens += " that"
                last_appended = True
            else:
                cont = False

    return gen_tokens.replace("<br>", "").replace("AI:", "").replace("\xa0", "")


    

import gradio as gr


def add_text(history, text):
    # Record the user's message; the bot's slot is filled in later by bot().
    history = history + [(text, None)]
    return history, ""


def bot(history):
    # Flatten the chat history into a "Human:/AI:" transcript for the model.
    serial_history = ""
    for h in history:
        serial_history += "\nHuman:" + h[0]
        if h[1] is None:
            break
        serial_history += "\nAI:" + h[1].replace("<br>", "")

    response = botsay(serial_history)
    history[-1][1] = response
    return history

with gr.Blocks() as demo:
    gr.Markdown("# ThatGPT - AI always replies with \"that\" -")
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)

    with gr.Row():
        with gr.Column(scale=0.85):
            txt = gr.Textbox(
                show_label=False,
                placeholder="AI always replies with \"that\". It may take more than ten seconds.",
            ).style(container=False)
        
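    # Submitting the textbox first records the message and clears the box (add_text),
    # then generates the bot's reply (bot).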
    txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
        bot, chatbot, chatbot
    )

demo.launch()