# diffusionGPT / app.py
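#
# Flask backend for the diffusionGPT demo: it lazily loads a DiffusionLLM
# checkpoint together with the Qwen2.5-0.5B tokenizer, and exposes JSON and
# Server-Sent Events (SSE) endpoints that stream the block-wise masked-token
# visualization produced during generation.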
import os
import sys
import json
import time
import importlib.util
from pathlib import Path
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import torch
from transformers import AutoTokenizer
app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app)
# Global state
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None
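# These globals are populated lazily by load_model_internal(), which is invoked
# from the /api/load endpoint rather than at import time.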
def find_file(filename, search_dirs=None):
"""Find a file in current directory or parent directories."""
if search_dirs is None:
search_dirs = [
os.path.dirname(__file__), # Current directory
os.path.dirname(os.path.dirname(__file__)), # Parent directory
os.getcwd(), # Working directory
]
for directory in search_dirs:
filepath = os.path.join(directory, filename)
if os.path.exists(filepath):
print(f"Found {filename} at: {filepath}")
return filepath
return None
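# Usage sketch (the path shown is hypothetical): find_file("infer-base.py")
# returns the first existing match, e.g. "/app/infer-base.py", or None if no
# candidate directory contains the file.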
def try_import_module(filepath, module_name):
"""Dynamically import a Python file as a module."""
if not filepath or not os.path.exists(filepath):
return None
try:
# Add the directory to sys.path
module_dir = os.path.dirname(filepath)
if module_dir not in sys.path:
sys.path.insert(0, module_dir)
spec = importlib.util.spec_from_file_location(module_name, filepath)
if spec is None:
print(f"Could not create spec for {filepath}")
return None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
print(f"Successfully imported {module_name} from {filepath}")
return module
except Exception as e:
print(f"Error importing {filepath}: {e}")
import traceback
traceback.print_exc()
return None
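# Usage sketch (the path shown is hypothetical):
#   base_mod = try_import_module("/app/infer-base.py", "infer_base")
# registers the module under sys.modules["infer_base"] and returns it, or
# returns None if loading fails.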
def load_model_internal():
"""Load the model and tokenizer."""
global model, tokenizer, config, device, DiffusionLLM, chat_function
if model is not None:
return True
try:
print("=" * 60)
print("Starting model loading process...")
print("=" * 60)
# Find and import infer-base.py
base_path = find_file("infer-base.py")
if base_path is None:
raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")
print(f"\nImporting infer-base.py from: {base_path}")
base_mod = try_import_module(base_path, "infer_base")
if base_mod is None:
raise RuntimeError("Failed to import infer-base.py")
# Check for DiffusionLLM class
if not hasattr(base_mod, 'DiffusionLLM'):
print("Available attributes in infer_base:", dir(base_mod))
raise RuntimeError("DiffusionLLM class not found in infer-base.py")
DiffusionLLM = base_mod.DiffusionLLM
print("βœ“ Successfully loaded DiffusionLLM class")
# Find and import infer-chat.py
chat_path = find_file("infer-chat.py")
if chat_path is None:
raise RuntimeError("Could not find infer-chat.py")
print(f"\nImporting infer-chat.py from: {chat_path}")
chat_mod = try_import_module(chat_path, "infer_chat")
if chat_mod is None or not hasattr(chat_mod, 'chat'):
raise RuntimeError("Failed to import chat function from infer-chat.py")
chat_function = chat_mod.chat
print("βœ“ Successfully loaded chat function")
# Setup pickling workaround for torch.load
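        # Checkpoints saved from a training script typically pickle classes that
        # lived in that script's __main__ module; exposing ModelConfig and
        # DiffusionLLM under __main__ lets torch.load resolve those references.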
try:
if hasattr(base_mod, 'ModelConfig'):
sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
sys.modules['__main__'].DiffusionLLM = DiffusionLLM
print("βœ“ Configured pickle support for model loading")
except Exception as e:
print(f"Warning: Could not setup pickle workaround: {e}")
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nβœ“ Using device: {device}")
# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("βœ“ Tokenizer loaded")
# Find model checkpoint
        checkpoint_dirs = [
            "checkpoints",
            "../checkpoints",
            os.path.join(os.path.dirname(__file__), "checkpoints"),
            os.path.join(os.path.dirname(__file__), "../checkpoints"),
        ]
model_path = None
for checkpoint_dir in checkpoint_dirs:
best_path = os.path.join(checkpoint_dir, "best_model.pt")
fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
if os.path.exists(best_path):
model_path = best_path
break
elif os.path.exists(fp32_path):
model_path = fp32_path
break
if model_path is None:
raise RuntimeError(
"Could not find model checkpoint. Looking for:\n"
" - checkpoints/best_model.pt\n"
" - checkpoints/model_fp32.pt\n"
f"Searched directories: {checkpoint_dirs}"
)
print(f"\nβœ“ Found model checkpoint: {model_path}")
print("Loading model weights (this may take a minute)...")
# Load model
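        # The checkpoint is expected to be a dict holding at least 'config' and
        # 'model_state', plus optional 'step' and 'best_val_loss' metadata.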
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
config = checkpoint['config']
print("Creating model...")
model = DiffusionLLM(config)
print("Loading state dict...")
state_dict = checkpoint['model_state']
state_dict = {k: v.float() for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
num_params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"\n{'=' * 60}")
print(f"βœ“βœ“βœ“ MODEL LOADED SUCCESSFULLY βœ“βœ“βœ“")
print(f"{'=' * 60}")
print(f"Parameters: {num_params:.1f}M")
if 'step' in checkpoint:
print(f"Training steps: {checkpoint['step']}")
if 'best_val_loss' in checkpoint:
print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
print(f"{'=' * 60}\n")
return True
except Exception as e:
print("\n" + "=" * 60)
print("ERROR LOADING MODEL")
print("=" * 60)
print(f"Error: {e}")
import traceback
traceback.print_exc()
print("=" * 60 + "\n")
return False
def create_streaming_visualizer():
"""Create a visualizer that yields SSE events instead of printing to terminal."""
def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
# Normalize inputs to lists
if not isinstance(mask_blocks, list):
mask_blocks = [mask_blocks]
is_masked_list = [is_masked_list]
# Decode context
try:
context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
except Exception:
context_text = str(context_ids[0].tolist())
# Build blocks visualization
all_blocks = []
for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
block_tokens = mask_block[0].tolist()
block_data = []
for i, token_id in enumerate(block_tokens):
if is_masked[0, i]:
block_data.append({
'type': 'masked',
                        'text': '███'
})
else:
try:
token_text = tok.decode([token_id], skip_special_tokens=False)
except Exception:
token_text = str(int(token_id))
block_data.append({
'type': 'revealed',
'text': token_text
})
all_blocks.append({
'block_index': block_idx,
'tokens': block_data
})
# Return data structure that will be sent as SSE
return {
'context': context_text,
'blocks': all_blocks,
'num_blocks': len(mask_blocks)
}
return visualizer
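# Shape of the dict returned by the visualizer above (it is sent to the client
# as the 'data' field of an SSE 'update' event); values here are illustrative:
#   {"context": "<decoded context text>",
#    "blocks": [{"block_index": 0,
#                "tokens": [{"type": "masked", "text": "███"},
#                           {"type": "revealed", "text": " rain"}]}],
#    "num_blocks": 1}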
@app.route('/')
def index():
"""Serve the main HTML page."""
return app.send_static_file('index.html')
@app.route('/api/load', methods=['POST'])
def load_model_endpoint():
"""Load the model."""
    data = request.get_json(silent=True) or {}
check_only = data.get('check_only', False)
global model
if check_only:
return jsonify({
'loaded': model is not None,
'message': 'Model is loaded' if model is not None else 'Model not loaded'
})
if model is not None:
return jsonify({
'loaded': True,
'message': 'Model already loaded'
})
success = load_model_internal()
if success:
return jsonify({
'loaded': True,
'message': 'Model loaded successfully'
})
else:
return jsonify({
'loaded': False,
'message': 'Failed to load model. Check server logs for details.'
}), 500
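# Example client calls (a sketch, assuming the default host/port from the
# __main__ block below and the third-party "requests" package):
#   requests.post("http://localhost:7860/api/load", json={"check_only": True})  # report status only
#   requests.post("http://localhost:7860/api/load", json={})                    # trigger loading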
@app.route('/api/generate', methods=['POST'])
def generate():
"""Generate response without streaming."""
global model, tokenizer, config, device, chat_function
if model is None:
return jsonify({'error': 'Model not loaded'}), 400
if chat_function is None:
return jsonify({'error': 'Chat function not available'}), 400
    data = request.get_json(silent=True) or {}
instruction = data.get('instruction', '')
steps = data.get('steps', 64)
block_size = data.get('block_size', 128)
max_new_tokens = data.get('max_new_tokens', 128)
parallel_blocks = data.get('parallel_blocks', 1)
if not instruction:
return jsonify({'error': 'No instruction provided'}), 400
try:
# Generate response
raw_output, response = chat_function(
model,
tokenizer,
instruction,
steps=steps,
block_size=block_size,
max_new_tokens=max_new_tokens,
temperature=0.8,
top_k=50,
top_p=0.9,
repetition_penalty=1.2,
no_repeat_ngram_size=3,
verbose=False,
visualize_fn=None,
parallel_blocks=parallel_blocks,
)
return jsonify({
'response': response,
'raw_output': raw_output
})
except Exception as e:
import traceback
traceback.print_exc()
return jsonify({'error': str(e)}), 500
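# Example request body for /api/generate (illustrative values; the defaults
# mirror the handler above):
#   {"instruction": "Write a short poem about rain.",
#    "steps": 64, "block_size": 128, "max_new_tokens": 128, "parallel_blocks": 1}
# A successful response has the form {"response": "...", "raw_output": "..."}.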
@app.route('/api/generate-stream', methods=['POST'])
def generate_stream():
"""Generate response with streaming visualization."""
global model, tokenizer, config, device, chat_function
if model is None:
return jsonify({'error': 'Model not loaded'}), 400
if chat_function is None:
return jsonify({'error': 'Chat function not available'}), 400
    data = request.get_json(silent=True) or {}
instruction = data.get('instruction', '')
steps = data.get('steps', 64)
block_size = data.get('block_size', 128)
max_new_tokens = data.get('max_new_tokens', 128)
parallel_blocks = data.get('parallel_blocks', 1)
if not instruction:
return jsonify({'error': 'No instruction provided'}), 400
def generate_events():
try:
            # Use a queue to hand visualization updates from the generation callback back to this generator for streaming
import queue
event_queue = queue.Queue()
generation_complete = {'done': False, 'result': None}
def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
"""This gets called during generation - we need to send events immediately"""
visualizer = create_streaming_visualizer()
data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
# Put the update in the queue so it can be yielded immediately
event_queue.put({'type': 'update', 'data': data})
# Start generation in a separate thread so we can yield events as they come
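            # Producer/consumer pattern: the worker thread runs chat_function and
            # pushes visualization snapshots into event_queue, while this generator
            # drains the queue and streams each snapshot to the client as SSE.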
import threading
def run_generation():
try:
raw_output, response = chat_function(
model,
tokenizer,
instruction,
steps=steps,
block_size=block_size,
max_new_tokens=max_new_tokens,
temperature=0.8,
top_k=50,
top_p=0.9,
repetition_penalty=1.2,
no_repeat_ngram_size=3,
verbose=False,
visualize_fn=streaming_visualizer,
parallel_blocks=parallel_blocks,
)
generation_complete['result'] = (raw_output, response)
except Exception as e:
generation_complete['result'] = ('error', str(e))
finally:
generation_complete['done'] = True
event_queue.put(None) # Signal completion
# Start generation thread
gen_thread = threading.Thread(target=run_generation)
gen_thread.daemon = True
gen_thread.start()
# Yield start event
yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"
# Yield events as they come from the queue
while not generation_complete['done'] or not event_queue.empty():
try:
event = event_queue.get(timeout=0.1)
if event is None: # Completion signal
break
yield f"data: {json.dumps(event)}\n\n"
except queue.Empty:
continue
# Wait for thread to finish
gen_thread.join(timeout=1.0)
# Send final response
if generation_complete['result']:
raw_output, response = generation_complete['result']
if raw_output == 'error':
yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
else:
yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"
except Exception as e:
import traceback
traceback.print_exc()
yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"
return Response(
stream_with_context(generate_events()),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'X-Accel-Buffering': 'no'
}
)
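# Event sequence emitted by /api/generate-stream, one JSON object per SSE
# "data:" line: {"type": "start", "message": ...} first, then zero or more
# {"type": "update", "data": {...}} snapshots, ending with either
# {"type": "complete", "response": ..., "raw_output": ...} or
# {"type": "error", "error": "..."}.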
@app.route('/api/test-stream', methods=['GET'])
def test_stream():
"""Test streaming endpoint."""
def generate():
for i in range(10):
yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
time.sleep(0.5)
yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"
return Response(
stream_with_context(generate()),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'X-Accel-Buffering': 'no'
}
)
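# Quick smoke test (a sketch, assuming the default host/port below and the
# third-party "requests" package):
#   requests.get("http://localhost:7860/api/test-stream", stream=True)
# streams ten numbered test messages followed by a completion message.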
if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)