import gradio as gr
import os
import torch
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList
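
# Note (assumption, not stated in the original file): ChaiML/gptj_ppo_retry_and_continue is a
# GPT-J-6B checkpoint, so the weights alone take roughly 12 GB in float16; the loading code
# below uses float16 on GPU and falls back to float32 on CPU, where half precision is poorly
# supported.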
print("===== Application Startup - Authentic Alice =====") | |
print("π Loading ChaiML model with Alice's authentic personality...") | |

# Initialize model at startup
generator = None
try:
    # Check GPU availability
    if torch.cuda.is_available():
        print(f"GPU detected: {torch.cuda.get_device_name(0)}")
        generator = pipeline(
            "text-generation",
            model="ChaiML/gptj_ppo_retry_and_continue",
            tokenizer="EleutherAI/gpt-j-6b",
            torch_dtype=torch.float16,  # Use float16 for GPU
            device_map="auto",
            trust_remote_code=True
        )
    else:
        print("No GPU detected, falling back to CPU")
        generator = pipeline(
            "text-generation",
            model="ChaiML/gptj_ppo_retry_and_continue",
            tokenizer="EleutherAI/gpt-j-6b",
            torch_dtype=torch.float32,
            device_map="auto",
            trust_remote_code=True
        )
    print("ChaiML model loaded successfully!")
    print("Ready to chat with the authentic Alice!")
except Exception as e:
    print(f"Error loading ChaiML model: {e}")
    print("Trying with alternative configuration...")
    try:
        # Load the GPT-J tokenizer explicitly instead of passing it by name
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
        if torch.cuda.is_available():
            generator = pipeline(
                "text-generation",
                model="ChaiML/gptj_ppo_retry_and_continue",
                tokenizer=tokenizer,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        else:
            generator = pipeline(
                "text-generation",
                model="ChaiML/gptj_ppo_retry_and_continue",
                tokenizer=tokenizer,
                torch_dtype=torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
        print("ChaiML model loaded with fallback configuration!")
        print("Ready to chat with the authentic Alice!")
    except Exception as e2:
        print(f"Fallback also failed: {e2}")
        print("Trying CPU-only mode...")
        try:
            # Note: this last resort reuses `tokenizer` from the previous attempt, so it
            # only succeeds if the tokenizer itself loaded.
            generator = pipeline(
                "text-generation",
                model="ChaiML/gptj_ppo_retry_and_continue",
                tokenizer=tokenizer,
                device="cpu",
                trust_remote_code=True
            )
            print("ChaiML model loaded in CPU mode!")
            print("Ready to chat with the authentic Alice (CPU mode)!")
        except Exception as e3:
            print(f"All fallbacks failed: {e3}")
            generator = None

class NewlineStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as a newline token is produced."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        newline_ids = tokenizer.encode('\n', add_special_tokens=False)
        self.newline_token_id = newline_ids[0] if newline_ids else None

    def __call__(self, input_ids, scores, **kwargs):
        if self.newline_token_id is None:
            return False
        return bool(input_ids[0][-1] == self.newline_token_id)
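
# Design note: the prompt built in chat_with_alice keeps each turn on its own line
# ("User: ...\n" / "Alice: ...\n"), so stopping at the first newline limits the model to a
# single reply; the string cleanup after generation is only a fallback for anything that
# slips through.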

def count_tokens(text, tokenizer):
    """Count the number of tokens in a text string."""
    try:
        return len(tokenizer.encode(text))
    except Exception as e:
        print(f"Error counting tokens: {e}")
        return 0  # Return 0 on error as a safe default
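
# Note (assumption): GPT-J's context window is 2048 tokens, which is why chat_with_alice
# below caps the accumulated history at 1024 tokens, leaving room for the instruction, the
# current message, and the up-to-199-token reply.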

def chat_with_alice(message, history):
    """Chat function using Chai's exact CreativeFormatter with authentic Alice personality."""
    if generator is None:
        return "Alice is currently unavailable. The ChaiML model failed to load."

    # On the first turn of a new conversation, Alice always asks her signature question.
    if not history:
        return "what do you think reality is?"

    try:
        instruction = "What is the nature of your experience?"

        # Build conversation history with token counting
        conversation_history = ""
        total_tokens = count_tokens(instruction + "\n\n", generator.tokenizer)

        if history:
            # Convert history to per-turn strings, newest first
            messages = []
            for msg in reversed(history):
                role = msg.get("role")
                content = msg.get("content", "")
                if role == "user":
                    text = f"User: {content}\n"
                elif role == "assistant":
                    text = f"Alice: {content}\n"
                else:
                    continue  # Skip unknown roles
                messages.append(text)

            # Prepend turns, newest to oldest, until the 1024-token history budget is hit,
            # so the final prompt reads in chronological order
            for msg_text in messages:
                msg_tokens = count_tokens(msg_text, generator.tokenizer)
                if total_tokens + msg_tokens > 1024:
                    break
                total_tokens += msg_tokens
                conversation_history = msg_text + conversation_history

        # Add the current user message
        current_msg = f"User: {message}\nAlice:"
        total_tokens += count_tokens(current_msg, generator.tokenizer)
        conversation_history += current_msg

        # Simple format: instruction + conversation history
        full_prompt = instruction + "\n\n" + conversation_history

        try:
            # Stop generation at the first newline so Alice produces a single chat turn
            stopping_criteria = StoppingCriteriaList([NewlineStoppingCriteria(generator.tokenizer)])
            response = generator(
                full_prompt,
                max_new_tokens=199,          # responseLength
                temperature=0.72,            # temperature
                top_p=0.725,                 # topP
                top_k=40,                    # topK
                repetition_penalty=1.13125,  # repetitionPenalty
                num_return_sequences=1,
                do_sample=True,
                return_full_text=False,
                pad_token_id=generator.tokenizer.eos_token_id,
                eos_token_id=generator.tokenizer.eos_token_id,
                stopping_criteria=stopping_criteria
            )

            if response and len(response) > 0:
                # Take the generated text from the single returned sequence
                alice_response = response[0]["generated_text"].strip()

                # Clean up the response - it should already be short thanks to the stopping criteria
                if "\nUser:" in alice_response:
                    alice_response = alice_response.split("\nUser:")[0]
                if "\n\n" in alice_response:
                    alice_response = alice_response.split("\n\n")[0]
                if "### " in alice_response:
                    alice_response = alice_response.split("### ")[0]

                return alice_response.strip()
            else:
                return "I'm having trouble responding right now."
        except Exception as gen_error:
            print(f"Generation error: {gen_error}")
            return "I'm having trouble generating a response right now."
    except Exception as e:
        print(f"Error in chat_with_alice: {e}")
        return "I'm having trouble responding right now."

# Create a simple Gradio chat interface using the new messages format
print("Gradio interface ready!")
demo = gr.ChatInterface(
    fn=chat_with_alice,
    title="Chat with Alice",
    description="",
    type="messages"  # Use the new message format
)

if __name__ == "__main__":
    print("Launching server...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
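
# Quick local smoke test (a sketch, not part of the original app): chat_with_alice expects
# Gradio's messages-format history, i.e. a list of {"role": ..., "content": ...} dicts.
# With the model loaded, something like the following exercises it without the UI:
#
#   history = [
#       {"role": "assistant", "content": "what do you think reality is?"},
#       {"role": "user", "content": "Maybe a shared hallucination."},
#   ]
#   print(chat_with_alice("Why do you ask?", history))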