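"""Flask server for a block-diffusion LLM demo.

Dynamically imports the model code from infer-base.py and infer-chat.py,
loads a checkpoint from a local checkpoints/ directory, and exposes JSON
endpoints for one-shot generation plus a Server-Sent Events (SSE) stream
that visualizes the denoising process block by block.
"""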
import os
import sys
import json
import time
import importlib.util
from pathlib import Path
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import torch
from transformers import AutoTokenizer

app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app)

# Global state
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None

def find_file(filename, search_dirs=None):
    """Find a file in current directory or parent directories."""
    if search_dirs is None:
        search_dirs = [
            os.path.dirname(__file__),                   # Current directory
            os.path.dirname(os.path.dirname(__file__)),  # Parent directory
            os.getcwd(),                                 # Working directory
        ]
    for directory in search_dirs:
        filepath = os.path.join(directory, filename)
        if os.path.exists(filepath):
            print(f"Found {filename} at: {filepath}")
            return filepath
    return None

def try_import_module(filepath, module_name):
    """Dynamically import a Python file as a module."""
    if not filepath or not os.path.exists(filepath):
        return None
    try:
        # Add the directory to sys.path
        module_dir = os.path.dirname(filepath)
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)
        spec = importlib.util.spec_from_file_location(module_name, filepath)
        if spec is None:
            print(f"Could not create spec for {filepath}")
            return None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        print(f"Successfully imported {module_name} from {filepath}")
        return module
    except Exception as e:
        print(f"Error importing {filepath}: {e}")
        import traceback
        traceback.print_exc()
        return None

def load_model_internal():
    """Load the model and tokenizer."""
    global model, tokenizer, config, device, DiffusionLLM, chat_function
    if model is not None:
        return True
    try:
        print("=" * 60)
        print("Starting model loading process...")
        print("=" * 60)

        # Find and import infer-base.py
        base_path = find_file("infer-base.py")
        if base_path is None:
            raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")
        print(f"\nImporting infer-base.py from: {base_path}")
        base_mod = try_import_module(base_path, "infer_base")
        if base_mod is None:
            raise RuntimeError("Failed to import infer-base.py")

        # Check for DiffusionLLM class
        if not hasattr(base_mod, 'DiffusionLLM'):
            print("Available attributes in infer_base:", dir(base_mod))
            raise RuntimeError("DiffusionLLM class not found in infer-base.py")
        DiffusionLLM = base_mod.DiffusionLLM
        print("✓ Successfully loaded DiffusionLLM class")

        # Find and import infer-chat.py
        chat_path = find_file("infer-chat.py")
        if chat_path is None:
            raise RuntimeError("Could not find infer-chat.py")
        print(f"\nImporting infer-chat.py from: {chat_path}")
        chat_mod = try_import_module(chat_path, "infer_chat")
        if chat_mod is None or not hasattr(chat_mod, 'chat'):
            raise RuntimeError("Failed to import chat function from infer-chat.py")
        chat_function = chat_mod.chat
        print("✓ Successfully loaded chat function")

        # Setup pickling workaround for torch.load
        try:
            if hasattr(base_mod, 'ModelConfig'):
                sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
            sys.modules['__main__'].DiffusionLLM = DiffusionLLM
            print("✓ Configured pickle support for model loading")
        except Exception as e:
            print(f"Warning: Could not setup pickle workaround: {e}")

        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n✓ Using device: {device}")

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("✓ Tokenizer loaded")

        # Find model checkpoint
        checkpoint_dirs = [
            "checkpoints",
            "../checkpoints",
            "./checkpoints",
            os.path.join(os.path.dirname(__file__), "checkpoints"),
            os.path.join(os.path.dirname(__file__), "../checkpoints"),
        ]
        model_path = None
        for checkpoint_dir in checkpoint_dirs:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
            if os.path.exists(best_path):
                model_path = best_path
                break
            elif os.path.exists(fp32_path):
                model_path = fp32_path
                break
        if model_path is None:
            raise RuntimeError(
                "Could not find model checkpoint. Looking for:\n"
                " - checkpoints/best_model.pt\n"
                " - checkpoints/model_fp32.pt\n"
                f"Searched directories: {checkpoint_dirs}"
            )
        print(f"\n✓ Found model checkpoint: {model_path}")
        print("Loading model weights (this may take a minute)...")

        # Load model
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        config = checkpoint['config']
        print("Creating model...")
        model = DiffusionLLM(config)
        print("Loading state dict...")
        state_dict = checkpoint['model_state']
        state_dict = {k: v.float() for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()

        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"\n{'=' * 60}")
        print("✓✓✓ MODEL LOADED SUCCESSFULLY ✓✓✓")
        print(f"{'=' * 60}")
        print(f"Parameters: {num_params:.1f}M")
        if 'step' in checkpoint:
            print(f"Training steps: {checkpoint['step']}")
        if 'best_val_loss' in checkpoint:
            print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
        print(f"{'=' * 60}\n")
        return True
    except Exception as e:
        print("\n" + "=" * 60)
        print("ERROR LOADING MODEL")
        print("=" * 60)
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        print("=" * 60 + "\n")
        return False
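
# Note: the checkpoint loaded above is expected to be a dict with a 'config'
# entry (the model configuration object pickled at training time), a
# 'model_state' entry (the state dict), and optionally 'step' and
# 'best_val_loss' metadata.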

def create_streaming_visualizer():
    """Create a visualizer that yields SSE events instead of printing to terminal."""
    def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
        # Normalize inputs to lists
        if not isinstance(mask_blocks, list):
            mask_blocks = [mask_blocks]
            is_masked_list = [is_masked_list]

        # Decode context
        try:
            context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
        except Exception:
            context_text = str(context_ids[0].tolist())

        # Build blocks visualization
        all_blocks = []
        for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
            block_tokens = mask_block[0].tolist()
            block_data = []
            for i, token_id in enumerate(block_tokens):
                if is_masked[0, i]:
                    block_data.append({
                        'type': 'masked',
                        'text': '███'
                    })
                else:
                    try:
                        token_text = tok.decode([token_id], skip_special_tokens=False)
                    except Exception:
                        token_text = str(int(token_id))
                    block_data.append({
                        'type': 'revealed',
                        'text': token_text
                    })
            all_blocks.append({
                'block_index': block_idx,
                'tokens': block_data
            })

        # Return data structure that will be sent as SSE
        return {
            'context': context_text,
            'blocks': all_blocks,
            'num_blocks': len(mask_blocks)
        }
    return visualizer
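
# Each 'update' event sent over the SSE stream wraps the dict returned above,
# e.g. (illustrative values):
# {
#   "type": "update",
#   "data": {
#     "context": "...decoded prompt so far...",
#     "blocks": [{"block_index": 0,
#                 "tokens": [{"type": "masked", "text": "███"},
#                            {"type": "revealed", "text": " the"}]}],
#     "num_blocks": 1
#   }
# }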

@app.route('/')
def index():
    """Serve the main HTML page."""
    return app.send_static_file('index.html')

# Route paths for the API endpoints below are assumed; they must match the
# fetch() calls made by static/index.html.
@app.route('/load_model', methods=['POST'])
def load_model_endpoint():
    """Load the model."""
    global model
    data = request.get_json(silent=True) or {}
    check_only = data.get('check_only', False)
    if check_only:
        return jsonify({
            'loaded': model is not None,
            'message': 'Model is loaded' if model is not None else 'Model not loaded'
        })
    if model is not None:
        return jsonify({
            'loaded': True,
            'message': 'Model already loaded'
        })
    success = load_model_internal()
    if success:
        return jsonify({
            'loaded': True,
            'message': 'Model loaded successfully'
        })
    else:
        return jsonify({
            'loaded': False,
            'message': 'Failed to load model. Check server logs for details.'
        }), 500
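
# Example request (assuming the '/load_model' route above and the port set in
# app.run at the bottom of this file):
#   curl -X POST http://localhost:7860/load_model \
#        -H "Content-Type: application/json" -d '{"check_only": true}'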

@app.route('/generate', methods=['POST'])
def generate():
    """Generate a response without streaming."""
    global model, tokenizer, config, device, chat_function
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400
    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.get_json(silent=True) or {}
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)
    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    try:
        # Generate response
        raw_output, response = chat_function(
            model,
            tokenizer,
            instruction,
            steps=steps,
            block_size=block_size,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            verbose=False,
            visualize_fn=None,
            parallel_blocks=parallel_blocks,
        )
        return jsonify({
            'response': response,
            'raw_output': raw_output
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
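
# Example request body for the '/generate' route above (route path assumed):
#   {"instruction": "Explain diffusion language models in one sentence.",
#    "steps": 64, "block_size": 128, "max_new_tokens": 128, "parallel_blocks": 1}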

@app.route('/generate_stream', methods=['POST'])
def generate_stream():
    """Generate response with streaming visualization."""
    global model, tokenizer, config, device, chat_function
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400
    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.get_json(silent=True) or {}
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)
    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    def generate_events():
        try:
            # Use a queue so the visualizer callback (invoked inside the model's
            # generation loop) can hand updates to this generator as they happen
            import queue
            import threading

            event_queue = queue.Queue()
            generation_complete = {'done': False, 'result': None}

            def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
                """Called during generation; pushes an 'update' event for each step."""
                visualizer = create_streaming_visualizer()
                data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
                # Put the update in the queue so it can be yielded immediately
                event_queue.put({'type': 'update', 'data': data})

            # Run generation in a separate thread so we can yield events as they come
            def run_generation():
                try:
                    raw_output, response = chat_function(
                        model,
                        tokenizer,
                        instruction,
                        steps=steps,
                        block_size=block_size,
                        max_new_tokens=max_new_tokens,
                        temperature=0.8,
                        top_k=50,
                        top_p=0.9,
                        repetition_penalty=1.2,
                        no_repeat_ngram_size=3,
                        verbose=False,
                        visualize_fn=streaming_visualizer,
                        parallel_blocks=parallel_blocks,
                    )
                    generation_complete['result'] = (raw_output, response)
                except Exception as e:
                    generation_complete['result'] = ('error', str(e))
                finally:
                    generation_complete['done'] = True
                    event_queue.put(None)  # Signal completion

            # Start generation thread
            gen_thread = threading.Thread(target=run_generation)
            gen_thread.daemon = True
            gen_thread.start()

            # Yield start event
            yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"

            # Yield events as they come from the queue
            while not generation_complete['done'] or not event_queue.empty():
                try:
                    event = event_queue.get(timeout=0.1)
                    if event is None:  # Completion signal
                        break
                    yield f"data: {json.dumps(event)}\n\n"
                except queue.Empty:
                    continue

            # Wait for thread to finish
            gen_thread.join(timeout=1.0)

            # Send final response
            if generation_complete['result']:
                raw_output, response = generation_complete['result']
                if raw_output == 'error':
                    yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"
        except Exception as e:
            import traceback
            traceback.print_exc()
            yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"

    return Response(
        stream_with_context(generate_events()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )
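
# Illustrative client sketch for the SSE endpoint above (not used by the app;
# assumes the 'requests' package is installed and the assumed '/generate_stream'
# route path):
#
#   import json
#   import requests
#
#   with requests.post("http://localhost:7860/generate_stream",
#                      json={"instruction": "Hello!", "max_new_tokens": 64},
#                      stream=True) as resp:
#       for line in resp.iter_lines(decode_unicode=True):
#           if line and line.startswith("data: "):
#               event = json.loads(line[len("data: "):])
#               if event["type"] == "update":
#                   pass  # redraw the block visualization from event["data"]
#               elif event["type"] in ("complete", "error"):
#                   print(event)
#                   break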

@app.route('/test_stream')
def test_stream():
    """Test streaming endpoint."""
    def generate():
        for i in range(10):
            yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
            time.sleep(0.5)
        yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"
    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)