File size: 3,505 Bytes
e14c3a3
 
 
b5897ad
 
e14c3a3
 
 
 
 
 
 
513e175
 
e14c3a3
 
 
 
 
 
 
 
b41103a
b5897ad
b41103a
b5897ad
e4b4477
 
b5897ad
e4b4477
b5897ad
 
 
e14c3a3
b5897ad
b41103a
b5897ad
 
 
 
e14c3a3
d8535e1
 
 
 
 
e14c3a3
d0806a8
b41103a
e4b4477
 
 
d0806a8
e4b4477
 
 
 
 
b5897ad
 
d0806a8
 
 
b5897ad
 
d0806a8
b5897ad
 
 
 
 
 
 
d0806a8
 
 
b5897ad
8a28d56
d0806a8
b41103a
b5897ad
62ca94b
b5897ad
b05c828
e14c3a3
 
b5897ad
e14c3a3
b5897ad
fbcfd77
 
d0806a8
e14c3a3
 
b5897ad
 
e2abd20
 
 
 
 
 
 
e14c3a3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
from gradio.inputs import Textbox, Slider

import requests

# Template: static copy handed to gr.Interface (title, description, article, theme).
title = "A conversation with some NPC in a Tavern 🍻"
description = ""
# HTML passed as the interface's `article`; it documents each input of chat().
# The <img> attributes are separated by whitespace only (a comma between
# attributes is not valid HTML).
article = """
<p> If you liked don't forget to 💖 the project 🥰 </p>
<h2> Parameters: </h2>
<ul>
    <li><i>message</i>: what you want to say to the NPC.</li>
    <li><i>npc_name</i>: name of the NPC.</li>
    <li><i>npc_prompt</i>: prompt of the NPC, we can modify it to see if results are better.</li>
    <li><i>top_p</i>:  control how deterministic the model is in generating a response.</li>
    <li><i>temperature</i>: (sampling temperature) higher values means the model will take more risks.</li>
    <li><i>max_new_tokens</i>: Max number of tokens in generation.</li>
</ul>
<img src='http://www.simoninithomas.com/test/gandalf.jpg' alt="Gandalf"/>"""
# Theme name passed to gr.Interface.
theme="huggingface"


# Builds the prompt from what previously happened
def build_prompt(conversation, context, interlocutor_names):
    """Return the text prompt for the conversation so far.

    The result is ``context`` followed by a newline, then, for every
    ``(player_msg, npc_msg)`` pair in ``conversation``, two lines of the form
    ``"\\n- <name>:<message>"``: first the player's (``interlocutor_names[0]``),
    then the NPC's (``interlocutor_names[1]``).

    Args:
        conversation: iterable of ``(player_msg, npc_msg)`` string pairs.
        context: text placed at the very start of the prompt.
        interlocutor_names: sequence whose first two items are the player
            and NPC display names.

    Returns:
        The assembled prompt string.
    """
    exchanges = [
        "\n- " + interlocutor_names[0] + ":" + player_msg
        + "\n- " + interlocutor_names[1] + ":" + npc_msg
        for player_msg, npc_msg in conversation
    ]
    return context + "\n" + "".join(exchanges)

# Extract the NPC's reply from the raw model output
def clean_chat_output(txt, prompt, interlocutor_names):
    """Return only the NPC's reply from the generated text.

    Every occurrence of ``prompt`` is removed from ``txt``; the result is then
    cut at the first ``"\\n- <player name>"`` marker, i.e. where the model
    starts writing the player's next line. If that marker is absent, the
    whole remaining text is returned unchanged.

    Args:
        txt: raw text produced by the model.
        prompt: prompt that was sent to the model (stripped from ``txt``).
        interlocutor_names: sequence whose first item is the player's name.

    Returns:
        The reply text preceding the player's next turn.
    """
    delimiter = "\n- " + interlocutor_names[0]
    output = txt.replace(prompt, '')
    # partition() yields the full string when the delimiter is missing, unlike
    # slicing with find(), whose -1 result would drop the last character.
    reply, _, _ = output.partition(delimiter)
    return reply

# GPT-J-6B API
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"
def query(payload, timeout=120):
    """POST ``payload`` as JSON to ``API_URL`` and return the decoded JSON reply.

    Args:
        payload: JSON-serialisable request body.
        timeout: seconds to wait for the server before giving up. Without a
            timeout, ``requests`` would wait indefinitely on a stalled
            connection and block the caller.

    Returns:
        The response body parsed from JSON. The HTTP status is not checked,
        so whatever JSON the server sends back is returned as-is.

    Raises:
        requests.exceptions.Timeout: if the server does not answer in time.
    """
    response = requests.post(API_URL, json=payload, timeout=timeout)
    return response.json()

def chat(message, npc_name, initial_prompt, top_p, temperature, max_new_tokens, history=None):
    """Send the player's message to the model and record the NPC's reply.

    Builds a prompt from ``initial_prompt`` plus the previous exchanges and the
    new ``message``, queries the text-generation API, trims the generated text
    down to the NPC's reply, and appends ``(message, reply)`` to ``history``.

    Args:
        message: what the player says.
        npc_name: name used for the NPC's lines in the prompt.
        initial_prompt: context text placed at the start of the prompt.
        top_p: forwarded as the ``top_p`` generation parameter.
        temperature: forwarded as the ``temperature`` generation parameter.
        max_new_tokens: forwarded as the ``max_new_tokens`` generation parameter.
        history: list of ``(player_msg, npc_msg)`` pairs from earlier turns;
            ``None`` starts a new conversation. A provided list is extended
            in place, and only once a reply has been obtained.

    Returns:
        ``(history, history)``: the updated history twice, matching the two
        outputs ("chatbot", "state") the interface is declared with.

    Raises:
        RuntimeError: if the API response is not a non-empty list.
    """
    # Create the list per call; a mutable default would be shared by every
    # call that omits the argument.
    if history is None:
        history = []
    interlocutor_names = ["Player", npc_name]

    print("message", message)
    print("npc_name", npc_name)
    print("initial_prompt", initial_prompt)
    print("top_p", top_p)
    print("temperature", temperature)
    print("max_new_tokens", max_new_tokens)
    print("history", history)

    # The pending turn carries an empty NPC message so the prompt ends on
    # "- <npc_name>:" for the model to complete. It is kept out of `history`
    # until the request succeeds, so a failure leaves no dangling entry.
    conversation = list(history) + [(message, "")]

    # Build the prompt
    prompt = build_prompt(conversation, initial_prompt, interlocutor_names)

    # Build JSON
    json_req = {
        "inputs": prompt,
        "parameters": {
            "top_p": top_p,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
        },
    }

    # Get the output
    output = query(json_req)
    # The code below expects a list of generations; anything else (for
    # instance a JSON object) is reported explicitly rather than failing
    # with an opaque KeyError/IndexError on the lookup.
    if not isinstance(output, list) or not output:
        raise RuntimeError("Unexpected response from the inference API: %r" % (output,))
    output = output[0]['generated_text']
    print("output", output)

    answer = clean_chat_output(output, prompt, interlocutor_names)
    print("response", answer)
    history.append((message, answer))
    return history, history


#io = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B")

# Input components, listed in the positional order of chat()'s parameters;
# the trailing "state" entry supplies chat()'s `history` argument.
chat_inputs = [
    Textbox(label="message"),
    Textbox(label="npc_name"),
    Textbox(label="initial_prompt"),
    Slider(minimum=0.5, maximum=1, step=0.05, default=0.9, label="top_p"),
    Slider(minimum=0.5, maximum=1.5, step=0.1, default=1.1, label="temperature"),
    Slider(minimum=20, maximum=250, step=10, default=50, label="max_new_tokens"),
    "state",
]
# chat() returns the history twice: once for the chatbot view, once for the state.
chat_outputs = ["chatbot", "state"]

iface = gr.Interface(
    fn=chat,
    inputs=chat_inputs,
    outputs=chat_outputs,
    #examples="",
    allow_screenshot=True,
    allow_flagging=True,
    title=title,
    article=article,
    theme=theme,
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()