# liteDungeon / app.py
# Logan Zoellner — commit d192b21: "lower max_new_tokens"
from asyncio import constants
import gradio as gr
import requests
import os
import re
import random
from words import *
# Hugging Face Inference API endpoint for EleutherAI's GPT-J-6B model.
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"

# Cap on tokens requested from the model per turn (sent as `max_new_tokens`).
MAX_NEW_TOKENS = 25

# Fixed preamble placed in front of every model prompt. It frames the text as a
# transcript of a text-adventure session, using "computer:" / "player:" turns.
basePrompt = (
    "\n"
    "The following session was recorded from a text adventure game.\n"
    "computer: you are an adventurer exploring the darkest dungeon\n"
    "player: enter dungeon\n"
)

# Initial contents of the story transcript box.
default_story = "computer: you are standing in front of a dark dungeon.\n"
def fallbackResponse():
    """Return a canned encounter line naming a randomly chosen monster.

    Serves as the computer's reply when the language model gives nothing
    usable. ``monsters`` is presumably provided by the ``words`` star import —
    confirm against that module.
    """
    monster = random.choice(monsters)
    return f"You are attacked by a {monster}!"
# Seconds to wait for the inference API before giving up and using the fallback.
_REQUEST_TIMEOUT_S = 30


def _query_model(prompt_text):
    """Send ``prompt_text`` to the GPT-J inference API; return the generated text.

    Returns None, so the caller can fall back to a canned reply, when any of
    these happen:
    - the request fails or times out;
    - the body is not JSON;
    - the API reports an ``"error"``;
    - the payload is not shaped like ``[{"generated_text": "..."}]``.
    """
    payload = {
        "inputs": prompt_text,
        "parameters": {
            "top_p": 0.9,
            "temperature": 1.1,
            "max_new_tokens": MAX_NEW_TOKENS,
            # Only the continuation is wanted, not the echoed prompt.
            "return_full_text": False,
        },
    }
    try:
        response = requests.post(API_URL, json=payload, timeout=_REQUEST_TIMEOUT_S)
        output = response.json()
    except (requests.RequestException, ValueError) as err:
        # Connection problems, timeouts, or a non-JSON body (e.g. an HTML
        # error page) must not crash the UI handler.
        print(f"inference request failed: {err!r}")
        return None

    if isinstance(output, dict) and "error" in output:
        print(f"inference API returned an error: {output['error']}")
        return None
    if (
        isinstance(output, list)
        and output
        and isinstance(output[0], dict)
        and isinstance(output[0].get("generated_text"), str)
    ):
        print("generated text was", output[0]["generated_text"])
        return output[0]["generated_text"]
    print(f"unexpected inference API response: {output!r}")
    return None


def _first_line(text):
    """Return the first line of ``text``, ignoring surrounding whitespace."""
    return text.strip().split("\n", 1)[0].strip()


def _trim_story(story, keep=6):
    """Keep only the last ``keep`` newline-separated pieces of ``story``.

    When ``story`` ends with a newline, the trailing empty piece counts toward
    ``keep``, so ``keep - 1`` complete lines survive.
    """
    return "\n".join(story.split("\n")[-keep:])


def continue_story(prompt, story):
    """Advance the adventure by one turn.

    Args:
        prompt: The player's command, taken from the input box.
        story: The current transcript, as ``computer:`` / ``player:`` lines.

    Returns:
        The trimmed transcript with two lines appended: the player's command
        and the computer's reply. The reply is the first line the model
        generated, or a canned encounter from ``fallbackResponse`` when the
        model is unavailable or produced nothing usable.
    """
    command = "" if prompt is None else str(prompt)
    story = story or ""
    # A transcript edited by hand may have lost its trailing newline; without
    # one, the next "player:" line would be glued onto the last story line.
    if not story.endswith("\n") and story:
        story += "\n"

    model_prompt = basePrompt + story + "player:" + command + "\ncomputer:"
    print("got prompt:\n\n", model_prompt)

    generated = _query_model(model_prompt)
    if generated is None:
        print("using fallback description method!")
        generated = fallbackResponse()

    # The model tends to keep writing further turns; keep only its first line.
    reply = _first_line(generated)
    if not reply:
        # Whitespace-only generation: keep the game moving instead of showing
        # an empty "computer:" line.
        reply = fallbackResponse()
    print("which was trimmed to", reply)

    # story ends with "\n" here (or is empty), so at most five complete lines
    # are kept; with the two lines appended below, the transcript stays at
    # seven lines.
    return _trim_story(story) + "player:" + command + "\ncomputer: " + reply + "\n"
# Single-page UI: a transcript box, a command box, and a Submit button. The
# button passes (command, transcript) to continue_story and writes the result
# back into the transcript box.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>LiteDungeon</center></h1>")
    gr.Markdown("<div>Create a text adventure, using GPT-J</div>")

    with gr.Row():
        output_story = gr.Textbox(value=default_story, label="story", lines=7)
    with gr.Row():
        input_command = gr.Textbox(label="input", placeholder="look around")
    with gr.Row():
        submit_button = gr.Button("Submit")

    submit_button.click(
        continue_story,
        inputs=[input_command, output_story],
        outputs=[output_story],
    )

demo.launch(enable_queue=True, debug=True)