# alice / app.py
import gradio as gr
import os
import torch
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList
print("===== Application Startup - Authentic Alice =====")
print("πŸš€ Loading ChaiML model with Alice's authentic personality...")
# Initialize model at startup
generator = None
try:
    # Check GPU availability and pick a matching dtype
    if torch.cuda.is_available():
        print(f"GPU detected: {torch.cuda.get_device_name(0)}")
        dtype = torch.float16  # half precision fits GPT-J-6B on most GPUs
    else:
        print("No GPU detected, falling back to CPU")
        dtype = torch.float32  # CPUs generally lack fast float16 kernels
    generator = pipeline(
        "text-generation",
        model="ChaiML/gptj_ppo_retry_and_continue",
        tokenizer="EleutherAI/gpt-j-6b",
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
print("βœ… ChaiML model loaded successfully!")
print("🎭 Ready to chat with the authentic Alice!")
except Exception as e:
print(f"❌ Error loading ChaiML model: {e}")
print("πŸ”„ Trying with alternative configuration...")
try:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
        # Same model, but pass the tokenizer object explicitly this time
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        generator = pipeline(
            "text-generation",
            model="ChaiML/gptj_ppo_retry_and_continue",
            tokenizer=tokenizer,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )
print("βœ… ChaiML model loaded with fallback configuration!")
print("🎭 Ready to chat with the authentic Alice!")
except Exception as e2:
print(f"❌ Fallback also failed: {e2}")
print("πŸ”„ Trying CPU-only mode...")
try:
generator = pipeline(
"text-generation",
model="ChaiML/gptj_ppo_retry_and_continue",
                tokenizer="EleutherAI/gpt-j-6b",  # use the string id: the tokenizer object may not exist if the previous step failed
device="cpu",
trust_remote_code=True
)
print("βœ… ChaiML model loaded in CPU mode!")
print("🎭 Ready to chat with the authentic Alice (CPU mode)!")
except Exception as e3:
print(f"❌ All fallbacks failed: {e3}")
generator = None
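# Note (assumption about the deployment, not from the original code): GPT-J-6B has
# ~6B parameters, i.e. roughly 12 GB of weights in float16, so device_map="auto"
# may shard or offload layers when GPU memory is tight.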
class NewlineStoppingCriteria(StoppingCriteria):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
        # Cache the token id for "\n" so generation can stop at the first newline
        newline_ids = tokenizer.encode('\n', add_special_tokens=False)
        self.newline_token_id = newline_ids[0] if newline_ids else None
def __call__(self, input_ids, scores, **kwargs):
if self.newline_token_id is None:
return False
        # Compare the most recently generated token against the cached newline id
        return input_ids[0][-1].item() == self.newline_token_id
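# Usage sketch for the criteria on its own (the prompt string here is
# hypothetical, not part of the app):
#   criteria = StoppingCriteriaList([NewlineStoppingCriteria(generator.tokenizer)])
#   generator("Alice:", max_new_tokens=50, stopping_criteria=criteria)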
def count_tokens(text, tokenizer):
"""Count the number of tokens in a text string"""
try:
return len(tokenizer.encode(text))
except Exception as e:
print(f"Error counting tokens: {e}")
return 0 # Return 0 on error as a safe default
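# Example (sketch): count_tokens("User: hi\n", generator.tokenizer) returns the
# GPT-J token count for that string; chat_with_alice uses this to keep prompts
# under a ~1024-token history budget.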
def chat_with_alice(message, history):
    """Chat with Alice using a Chai-style prompt: an instruction line followed by the rolling conversation history."""
if generator is None:
return "Alice is currently unavailable. The ChaiML model failed to load."
# On the first turn of a new conversation, Alice always asks her signature question.
if not history:
return "what do you think reality is?"
try:
instruction = "What is the nature of your experience?"
# Build conversation history with token counting
conversation_history = ""
total_tokens = count_tokens(instruction + "\n\n", generator.tokenizer)
if history:
# Convert history to messages, newest first
messages = []
for msg in reversed(history):
role = msg.get("role")
content = msg.get("content", "")
if role == "user":
text = f"User: {content}\n"
elif role == "assistant":
text = f"Alice: {content}\n"
else:
continue # Skip unknown roles
messages.append(text)
            # Prepend messages (newest first) until the ~1024-token budget is hit,
            # so the most recent turns survive trimming
for msg_text in messages:
msg_tokens = count_tokens(msg_text, generator.tokenizer)
if total_tokens + msg_tokens > 1024:
break
total_tokens += msg_tokens
conversation_history = msg_text + conversation_history
# Add current user message
current_msg = f"User: {message}\nAlice:"
total_tokens += count_tokens(current_msg, generator.tokenizer)
conversation_history += current_msg
# Simple format: instruction + conversation history
full_prompt = instruction + "\n\n" + conversation_history
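        # Illustrative prompt shape (the turns shown are hypothetical):
        #   What is the nature of your experience?
        #
        #   User: hi
        #   Alice: what do you think reality is?
        #   User: <current message>
        #   Alice: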
try:
# Create stopping criteria for newlines
stopping_criteria = StoppingCriteriaList([NewlineStoppingCriteria(generator.tokenizer)])
response = generator(
full_prompt,
max_new_tokens=199, # responseLength
temperature=0.72, # temperature
top_p=0.725, # topP
top_k=40, # topK
repetition_penalty=1.13125, # repetitionPenalty
            num_return_sequences=1,       # single completion per request
do_sample=True,
return_full_text=False,
pad_token_id=generator.tokenizer.eos_token_id,
eos_token_id=generator.tokenizer.eos_token_id,
stopping_criteria=stopping_criteria
)
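            # Sampling notes: temperature < 1 sharpens the distribution, top_p/top_k
            # cap the candidate pool, and repetition_penalty > 1 discourages loops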
            # The pipeline returns a list of generations; with
            # num_return_sequences=1 there is exactly one candidate
            if response and len(response) > 0:
                alice_response = response[0]["generated_text"].strip()
# Clean up response - should be cleaner with stopping criteria
if "\nUser:" in alice_response:
alice_response = alice_response.split("\nUser:")[0]
if "\n\n" in alice_response:
alice_response = alice_response.split("\n\n")[0]
if "### " in alice_response:
alice_response = alice_response.split("### ")[0]
return alice_response.strip()
else:
return "I'm having trouble responding right now."
except Exception as gen_error:
print(f"Generation error: {gen_error}")
return "I'm having trouble generating a response right now."
except Exception as e:
print(f"Error in chat_with_alice: {e}")
return "I'm having trouble responding right now."
# Create simple Gradio interface with new message format
print("🎯 Gradio interface ready!")
demo = gr.ChatInterface(
fn=chat_with_alice,
title="Chat with Alice",
description="",
type="messages" # Use new message format
)
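# Optional sketch (not in the original app): for concurrent users, requests can
# be funneled through Gradio's built-in queue before launching:
#   demo.queue()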
if __name__ == "__main__":
print("🌐 Launching server...")
demo.launch(server_name="0.0.0.0", server_port=7860)