# LiteDungeon — Hugging Face Spaces app
# (page-header residue captured with the source: "Spaces / Runtime error")
from asyncio import constants  # NOTE(review): unused in this file — confirm before removing
import gradio as gr
import requests
import os
import re
import random
from words import *

# GPT-J-6B hosted Inference API endpoint
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"
# Cap on tokens generated per turn (keeps replies short and requests cheap)
MAX_NEW_TOKENS = 25
# Few-shot framing prepended to every request so GPT-J continues the
# "computer:" / "player:" transcript format.
basePrompt = """
The following session was recorded from a text adventure game.
computer: you are an adventurer exploring the darkest dungeon
player: enter dungeon
"""

# Opening line shown in the story box before the player's first command.
default_story = "computer: you are standing in front of a dark dungeon.\n"
def fallbackResponse():
    """Return a canned event line, used when the inference API is unavailable.

    Returns:
        str: a one-line event naming a random monster.
    """
    # `monsters` comes from the star-import of the local `words` module.
    return "You are attacked by a {monster}!".format(monster=random.choice(monsters))
def continue_story(prompt, story):
    """Advance the adventure one turn.

    Sends the running transcript plus the player's command to the GPT-J
    inference API and appends the model's reply (or a canned fallback) to
    the story.

    Args:
        prompt: the player's command, e.g. "look around".
        story: transcript so far, alternating "player:"/"computer:" lines.

    Returns:
        str: the updated transcript, ending with the new "computer:" line.
    """
    full_prompt = basePrompt + story + "player:" + str(prompt) + "\ncomputer:"
    print(f"*****Inside desc_generate - Prompt is :{full_prompt}")
    payload = {
        "inputs": full_prompt,
        "parameters": {
            "top_p": 0.9,
            "temperature": 1.1,
            "max_new_tokens": MAX_NEW_TOKENS,
            "return_full_text": False,
        },
    }
    # Network failures or non-JSON bodies previously crashed the click
    # handler; route them into the existing error/fallback branch instead.
    try:
        response = requests.post(API_URL, json=payload)
        output = response.json()
    except (requests.RequestException, ValueError):
        output = {"error": "request failed"}
    print(f"If there was an error? Reason is : {output}")
    if "error" in output:
        # Inference API unavailable (cold start, rate limit, ...) — fall back.
        print("using fallback description method!")
        reply = fallbackResponse()
    else:
        print("generated text was", output[0]['generated_text'])
        reply = output[0]['generated_text']
    reply = reply.strip()
    # Keep only the first line of the generation; anything after a newline
    # tends to be the model hallucinating further turns.
    if "\n" in reply:
        reply = reply[:reply.find('\n')]
    print("which was trimmed to", reply)
    # Bound the context window: keep only the last 6 transcript lines.
    story_lines = story.split("\n")
    if len(story_lines) > 6:
        story = "\n".join(story_lines[-6:])
    return story + "player:" + prompt + "\ncomputer: " + reply + "\n"
# --- Gradio UI -------------------------------------------------------------
# One story textbox (read back as state on every click), one command input,
# and a submit button wired to continue_story.
demo = gr.Blocks()
with demo:
    gr.Markdown("<h1><center>LiteDungeon</center></h1>")
    gr.Markdown(
        "<div>Create a text adventure, using GPT-J</div>"
    )
    with gr.Row():
        output_story = gr.Textbox(value=default_story, label="story", lines=7)
    with gr.Row():
        input_command = gr.Textbox(label="input", placeholder="look around")
    with gr.Row():
        b0 = gr.Button("Submit")
    # The story textbox is both an input (current transcript) and the output
    # (updated transcript), giving a simple feedback loop without State.
    b0.click(continue_story, inputs=[input_command, output_story], outputs=[output_story])
demo.launch(enable_queue=True, debug=True)