#!/usr/bin/env python3
"""
AR-Diffusion Chat Interface for Hugging Face Spaces
Experimental model with Quality vs Speed modes
Optimized for Zero GPU deployment with @spaces.GPU
"""

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import numpy as np
import re
import time
import traceback
from typing import List, Tuple
import os
import gc
import spaces

# Global model variables for memory efficiency (populated lazily by load_model).
tokenizer = None
model = None
device = None

# Hub id of the experimental AR-Diffusion checkpoint (described in the UI as a LoRA adapter).
ADAPTER_PATH = "rootxhacker/llama-3B-diffusion-exp-fixed"

# Models tried, in order, when ADAPTER_PATH cannot be loaded.
FALLBACK_MODELS = [
    "microsoft/DialoGPT-medium",
    "gpt2-large",
    "gpt2-medium",
    "distilgpt2",
]

# UI labels of the two generation modes; chat_function dispatches on these exact strings.
QUALITY_MODE = "Quality (Slower, Better)"
SPEED_MODE = "Speed (Faster)"

# Hard-coded token ids the generators treat specially. They are tokenizer-specific
# (presumably Llama-3 special tokens) -- NOTE(review): confirm against the checkpoint.
UNWANTED_TOKEN_IDS = (128001, 128000)  # never written into a masked position
STOP_TOKEN_IDS = (128001, 13)          # end prefix sampling (besides eos_token_id)
SEQUENCE_END_TOKEN_ID = 13             # appended after the masked span


class ARDiffusionGenerator:
    """Base AR-Diffusion generator with shared functionality.

    Holds the tokenizer/model/device triple and implements the pieces shared
    by the Quality and Speed generators: mask-token discovery, prompt
    formatting, logit filtering, prefix sampling, masked-position sampling and
    the end-to-end generation pipeline.
    """

    def __init__(self, tokenizer, model, device):
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.mask_token_id = self._find_mask_token()

    def _find_mask_token(self) -> int:
        """Return the id of the first mask candidate that encodes to exactly one token.

        Falls back to ``tokenizer.unk_token_id`` and finally to 50257.
        """
        for candidate in ['MASK', '<mask>', '[MASK]', '<|mask|>']:
            try:
                tokens = self.tokenizer.encode(candidate, add_special_tokens=False)
            except Exception:
                # Some tokenizers reject unusual strings; just try the next candidate.
                continue
            if len(tokens) == 1:
                return tokens[0]
        return getattr(self.tokenizer, 'unk_token_id', 50257) or 50257

    def create_prompt(self, instruction: str) -> str:
        """Wrap *instruction* in an Alpaca-style instruction/response template."""
        return f"### Instruction:\n{instruction}\n\n### Response:\n"

    def filter_logits(self, logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0,
                      temperature: float = 1.0) -> torch.Tensor:
        """Apply temperature, top-k and top-p (nucleus) filtering to *logits*.

        Accepts a tensor of any rank whose last dimension is the vocabulary and
        returns a new tensor of the same shape; filtered entries are ``-inf``.

        Args:
            logits: raw scores, last dim = vocabulary.
            top_k: keep only the k highest scores per row (0 disables).
            top_p: keep the smallest highest-probability set whose cumulative
                probability reaches ``top_p`` (values outside (0, 1) disable).
            temperature: divisor applied before filtering.
        """
        original_shape = logits.shape
        # Work on a (rows, vocab) copy so 1-D, 2-D and 3-D inputs share one path.
        flat = logits.reshape(-1, original_shape[-1]).clone()

        # Temperature scaling first so top-p sees the tempered distribution.
        if temperature != 1.0:
            flat = flat / temperature

        if 0 < top_k < flat.size(-1):
            kth_best = torch.topk(flat, top_k, dim=-1).values[:, -1:]
            flat = flat.masked_fill(flat < kth_best, float("-inf"))

        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(flat, descending=True, dim=-1)
            sorted_probs = torch.softmax(sorted_logits.float(), dim=-1)
            # Probability mass strictly before each token: a token is dropped only
            # once the tokens ahead of it already reach top_p, so the token that
            # crosses the threshold is kept and the best token is never dropped.
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            remove_sorted = mass_before >= top_p
            remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_indices, remove_sorted)
            flat = flat.masked_fill(remove, float("-inf"))

        return flat.reshape(original_shape)

    def _sample_prefix(self, prompt: str, length: int, top_k: int, top_p: float,
                       temperature: float) -> List[int]:
        """Autoregressively sample up to *length* token ids continuing *prompt*.

        Stops early at ``eos_token_id`` or any id in STOP_TOKEN_IDS; the stop
        token itself is not returned.
        """
        tokens = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        current = tokens['input_ids'][0].clone()
        stop_ids = {self.tokenizer.eos_token_id, *STOP_TOKEN_IDS}
        generated = []

        with torch.no_grad():
            for _ in range(length):
                logits = self.model(input_ids=current.unsqueeze(0)).logits[0, -1]
                filtered = self.filter_logits(
                    logits, top_k=top_k, top_p=top_p, temperature=temperature
                )
                probs = F.softmax(filtered.float(), dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                if next_token in stop_ids:
                    break
                generated.append(next_token)
                current = torch.cat([current, torch.tensor([next_token], device=self.device)])
        return generated

    def _assemble_sequence(self, prompt_tokens: torch.Tensor, natural_start: List[int],
                           num_masks: int) -> Tuple[str, torch.Tensor]:
        """Build ``prompt + natural_start + MASK * num_masks + end token``.

        Returns the decoded text (special tokens kept) and the id tensor.
        """
        sequence = (
            prompt_tokens.tolist()
            + natural_start
            + [self.mask_token_id] * num_masks
            + [SEQUENCE_END_TOKEN_ID]
        )
        tensor = torch.tensor(sequence)
        text = self.tokenizer.decode(tensor, skip_special_tokens=False)
        return text, tensor

    def _sample_masked_position(self, position_logits: torch.Tensor, current_ids: torch.Tensor,
                                pos: int, window: int, penalty: float, top_k: int,
                                top_p: float, temperature: float) -> int:
        """Sample a replacement token id for the masked position *pos*.

        Tokens among the *window* ids preceding *pos* are penalised by
        *penalty*; the mask token and UNWANTED_TOKEN_IDS are excluded outright,
        so the result is never one of them.
        """
        token_logits = position_logits.clone()
        vocab_size = token_logits.size(-1)

        # Anti-repetition against the immediately preceding context.
        for recent_token in set(current_ids[max(0, pos - window):pos].tolist()):
            if recent_token < vocab_size:
                token_logits[recent_token] -= penalty

        # Exclude unwanted ids before sampling (guarding ids outside this vocab).
        banned = [t for t in (self.mask_token_id, *UNWANTED_TOKEN_IDS) if 0 <= t < vocab_size]
        if banned:
            token_logits[banned] = float("-inf")

        filtered = self.filter_logits(
            token_logits, top_k=top_k, top_p=top_p, temperature=temperature
        )
        probs = F.softmax(filtered.float(), dim=-1)
        return torch.multinomial(probs, 1).item()

    @staticmethod
    def _replacement_rate(count: int, start_time: float) -> float:
        """Return *count* per second elapsed since *start_time* (0 if no time elapsed)."""
        elapsed = time.time() - start_time
        return count / elapsed if elapsed > 0 else 0

    @staticmethod
    def _denoise_stats(mode: str, steps: int, replaced: int, elapsed: float) -> dict:
        """Build the per-run statistics dict shared by both denoisers."""
        return {
            'mode': mode,
            'steps': steps,
            'tokens_replaced': replaced,
            'generation_time': elapsed,
            'tokens_per_second': replaced / elapsed if elapsed > 0 else 0,
        }

    def _run_generation(self, prompt: str, denoise, steps: int, temperature: float,
                        start_message: str, progress_callback=None) -> Tuple[str, dict]:
        """Shared generate pipeline: build corrupted sequence, denoise, clean, add stats.

        ``denoise(corrupted_ids, steps, temperature, progress_callback)`` must
        return ``(decoded_text, stats_dict)``; total time and word counts are
        added to the stats here.
        """
        start_time = time.time()

        if progress_callback:
            progress_callback(0.1, "Creating sequence...")
        _, corrupted_ids = self.create_sequence(self.create_prompt(prompt))

        if progress_callback:
            progress_callback(0.2, start_message)
        result, stats = denoise(corrupted_ids, steps, temperature, progress_callback)

        total_time = time.time() - start_time
        response = self._clean_response(result)
        word_count = len(response.split())
        stats.update({
            'total_time': total_time,
            'word_count': word_count,
            'words_per_second': word_count / total_time if total_time > 0 else 0,
        })
        return response, stats


class QualityGenerator(ARDiffusionGenerator):
    """Quality-focused AR-Diffusion generator (40 conservative denoising steps)."""

    def generate_start(self, prompt: str, length: int = 8) -> List[int]:
        """Sample a natural opening of up to *length* tokens (top-k 50, top-p 0.9, T=0.8)."""
        return self._sample_prefix(prompt, length, top_k=50, top_p=0.9, temperature=0.8)

    def create_sequence(self, prompt: str) -> Tuple[str, torch.Tensor]:
        """Create the corrupted sequence for quality mode.

        The masked span is longer for longer prompts (20-50 masks, chosen at random).
        """
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")['input_ids'][0]
        natural_start = self.generate_start(prompt, length=random.randint(8, 12))

        # Longer sequences for better quality
        prompt_length = len(prompt_tokens)
        if prompt_length > 25:
            num_masks = random.randint(35, 50)
        elif prompt_length > 15:
            num_masks = random.randint(25, 40)
        else:
            num_masks = random.randint(20, 35)

        return self._assemble_sequence(prompt_tokens, natural_start, num_masks)

    def generate(self, prompt: str, progress_callback=None) -> Tuple[str, dict]:
        """Generate a cleaned response and stats dict (40 steps, base temperature 0.7).

        *progress_callback*, if given, is called as ``callback(fraction, message)``.
        """
        return self._run_generation(
            prompt, self._denoise_quality, steps=40, temperature=0.7,
            start_message="Starting quality denoising...",
            progress_callback=progress_callback,
        )

    def _denoise_quality(self, corrupted_ids: torch.Tensor, steps: int, temperature: float,
                         progress_callback=None) -> Tuple[str, dict]:
        """Fill masked positions left-to-right, 1/2/3 per step as denoising progresses.

        Temperature decays linearly with the step index, floored at 0.4.
        Returns the decoded text (special tokens skipped) and a stats dict.
        """
        current_ids = corrupted_ids.clone()
        total_replacements = 0
        start_time = time.time()

        with torch.no_grad():
            for step in range(steps):
                if progress_callback:
                    rate = self._replacement_rate(total_replacements, start_time)
                    progress_callback(0.2 + (step / steps) * 0.7,
                                      f"Quality step {step+1}/{steps} | {rate:.1f} tok/s")

                mask_positions = (current_ids == self.mask_token_id).nonzero(as_tuple=True)[0]
                if len(mask_positions) == 0:
                    break

                logits = self.model(input_ids=current_ids.unsqueeze(0).to(self.device)).logits[0]
                current_temp = max(0.4, temperature * (1 - step / steps))

                # Conservative replacement schedule for quality.
                if step < steps // 4:
                    budget = 1
                elif step < steps // 2:
                    budget = 2
                else:
                    budget = 3

                for pos in sorted(mask_positions.tolist())[:budget]:
                    # NOTE(review): for a standard causal LM, logits[pos] scores the
                    # token at pos + 1, not pos. This presumably relies on the
                    # AR-Diffusion fine-tune's alignment -- confirm; it will not hold
                    # for the fallback models.
                    current_ids[pos] = self._sample_masked_position(
                        logits[pos], current_ids, pos, window=5, penalty=8.0,
                        top_k=30, top_p=0.75, temperature=current_temp,
                    )
                    total_replacements += 1

        if progress_callback:
            final_speed = self._replacement_rate(total_replacements, start_time)
            progress_callback(0.95, f"Finalizing... | Final speed: {final_speed:.1f} tok/s")

        stats = self._denoise_stats('Quality', steps, total_replacements, time.time() - start_time)
        result = self.tokenizer.decode(current_ids, skip_special_tokens=True)
        return result, stats

    def _clean_response(self, text: str) -> str:
        """Extract the text after '### Response:' and strip repeated punctuation and symbols.

        Returns *text* unchanged if nothing follows the marker; otherwise
        ensures the result ends with '.', '!' or '?'.
        """
        if "### Response:" in text:
            response = text.split("### Response:")[-1].strip()
        else:
            response = text.strip()

        if not response:
            return text

        # Quality cleaning
        response = re.sub(r"'{2,}", "", response)
        response = re.sub(r'"{2,}', "", response)
        response = re.sub(r"\.{2,}", ".", response)
        response = re.sub(r",{2,}", ",", response)
        response = re.sub(r"\s+", " ", response)

        # Remove artifacts
        response = re.sub(r"\$+", "", response)
        response = re.sub(r"#+", "", response)
        response = re.sub(r"@+", "", response)

        response = response.strip()
        if response and not response.endswith(('.', '!', '?')):
            response += "."
        return response


class SpeedGenerator(ARDiffusionGenerator):
    """Speed-focused AR-Diffusion generator (10 aggressive denoising steps)."""

    def filter_logits(self, logits: torch.Tensor, top_k: int = 15, top_p: float = 0.8,
                      temperature: float = 1.0) -> torch.Tensor:
        """Fast logits filtering: the shared filter with speed-mode defaults (k=15, p=0.8)."""
        return super().filter_logits(logits, top_k=top_k, top_p=top_p, temperature=temperature)

    def generate_start(self, prompt: str, length: int = 6) -> List[int]:
        """Sample a natural opening of up to *length* tokens (top-k 20, top-p 0.9, T=0.8)."""
        return self._sample_prefix(prompt, length, top_k=20, top_p=0.9, temperature=0.8)

    def create_sequence(self, prompt: str) -> Tuple[str, torch.Tensor]:
        """Create a shorter corrupted sequence (12-25 masks) optimized for speed."""
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")['input_ids'][0]
        natural_start = self.generate_start(prompt, length=6)

        # Shorter sequences for speed
        prompt_words = len(prompt.split())
        if prompt_words > 8:
            num_masks = random.randint(15, 25)
        else:
            num_masks = random.randint(12, 20)

        return self._assemble_sequence(prompt_tokens, natural_start, num_masks)

    def generate(self, prompt: str, progress_callback=None) -> Tuple[str, dict]:
        """Generate a cleaned response and stats dict (10 steps, base temperature 0.8).

        *progress_callback*, if given, is called as ``callback(fraction, message)``.
        """
        return self._run_generation(
            prompt, self._denoise_speed, steps=10, temperature=0.8,
            start_message="Starting speed denoising...",
            progress_callback=progress_callback,
        )

    def _denoise_speed(self, corrupted_ids: torch.Tensor, steps: int, temperature: float,
                       progress_callback=None) -> Tuple[str, dict]:
        """Fill up to 8 masked positions per step, left-to-right.

        Uses fp16 autocast when running on CUDA. Temperature rises slightly
        over the run. Returns the decoded text (special tokens skipped) and a
        stats dict.
        """
        current_ids = corrupted_ids.clone()
        total_replacements = 0
        start_time = time.time()

        # Use mixed precision for speed on GPU
        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                             enabled=self.device.type == 'cuda'):
            for step in range(steps):
                if progress_callback:
                    rate = self._replacement_rate(total_replacements, start_time)
                    progress_callback(0.2 + (step / steps) * 0.7,
                                      f"Speed step {step+1}/{steps} | {rate:.1f} tok/s")

                mask_pos = (current_ids == self.mask_token_id).nonzero(as_tuple=True)[0]
                if len(mask_pos) == 0:
                    break

                logits = self.model(input_ids=current_ids.unsqueeze(0).to(self.device)).logits[0]
                current_temp = temperature * (0.9 + 0.2 * (step / steps))

                # Aggressive replacement for speed
                for pos in sorted(mask_pos.tolist())[:8]:
                    # NOTE(review): same logits[pos] alignment assumption as quality mode.
                    current_ids[pos] = self._sample_masked_position(
                        logits[pos], current_ids, pos, window=3, penalty=3.0,
                        top_k=12, top_p=0.85, temperature=current_temp,
                    )
                    total_replacements += 1

        if progress_callback:
            final_speed = self._replacement_rate(total_replacements, start_time)
            progress_callback(0.95, f"Finalizing... | Final speed: {final_speed:.1f} tok/s")

        stats = self._denoise_stats('Speed', steps, total_replacements, time.time() - start_time)
        result = self.tokenizer.decode(current_ids, skip_special_tokens=True)
        return result, stats

    def _clean_response(self, text: str) -> str:
        """Extract the text after '### Response:' with minimal (3+ repeat) cleanup.

        Returns *text* unchanged if nothing follows the marker; otherwise
        ensures the result ends with '.', '!' or '?'.
        """
        if "### Response:" in text:
            response = text.split("### Response:")[-1].strip()
        else:
            response = text.strip()

        if not response:
            return text

        # Minimal cleaning for speed
        response = re.sub(r"'{3,}", "", response)
        response = re.sub(r'"{3,}', "", response)
        response = re.sub(r"\.{3,}", ".", response)
        response = re.sub(r",{3,}", ",", response)
        response = re.sub(r"\s+", " ", response)

        response = response.strip()
        if response and not response.endswith(('.', '!', '?')):
            response += "."
        return response


def _load_pretrained(name, target_device, hf_token=None, trust_remote_code=False):
    """Load the tokenizer and causal LM for *name*; return ``(tokenizer, model)``.

    On CUDA the model is loaded in fp16 with ``device_map="auto"``; otherwise
    fp32 with no device map. A missing pad token falls back to eos. Any
    exception from ``from_pretrained`` propagates.
    """
    use_cuda = target_device.type == "cuda"
    tok = AutoTokenizer.from_pretrained(name, trust_remote_code=trust_remote_code, token=hf_token)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        trust_remote_code=trust_remote_code,
        low_cpu_mem_usage=True,
        token=hf_token,
    )
    return tok, mdl


@spaces.GPU
def load_model():
    """Load (once) and return ``(tokenizer, model, device)``.

    Tries ADAPTER_PATH first (using HF_TOKEN if set), then each of
    FALLBACK_MODELS, then plain distilgpt2. The module globals are only
    assigned after a tokenizer/model pair has loaded successfully.

    NOTE(review): on ZeroGPU the @spaces.GPU call may execute outside the main
    process, in which case these globals would not persist between requests --
    confirm; loading at import time is the commonly documented pattern.
    """
    global tokenizer, model, device

    if tokenizer is not None and model is not None:
        return tokenizer, model, device

    # Get HF token from environment
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("🔑 HF_TOKEN found - using authenticated access")
    else:
        print("⚠️ No HF_TOKEN found - using public access only")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        print(f"Loading AR-Diffusion model on {device}...")
        print("Loading adapter model...")
        tokenizer, model = _load_pretrained(
            ADAPTER_PATH, device, hf_token=hf_token, trust_remote_code=True
        )
        print("✅ AR-Diffusion model loaded successfully!")
        return tokenizer, model, device
    except Exception as e:
        print(f"❌ Error loading {ADAPTER_PATH}: {e}")

    # Try alternative working models for AR-Diffusion demo
    print("🔄 Trying alternative models...")
    for alt_model in FALLBACK_MODELS:
        try:
            print(f"Trying {alt_model}...")
            tokenizer, model = _load_pretrained(alt_model, device, hf_token=hf_token)
            print(f"✅ Alternative model {alt_model} loaded successfully!")
            print("⚠️ Note: Using alternative model - AR-Diffusion features adapted for demo")
            return tokenizer, model, device
        except Exception as alt_e:
            print(f"❌ {alt_model} failed: {alt_e}")

    # Final fallback (unauthenticated); errors here propagate to the caller.
    print("🔄 Using final fallback model...")
    tokenizer, model = _load_pretrained("distilgpt2", device)
    print("✅ Final fallback model loaded successfully!")
    print("⚠️ Note: Using basic model - AR-Diffusion features adapted for demo")
    return tokenizer, model, device


def cleanup_memory():
    """Release cached CUDA memory (if CUDA is available) and run the garbage collector."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


@spaces.GPU
def chat_function(message, history, mode, progress=gr.Progress()):
    """Handle one chat turn: generate a reply and return ``(history, "", perf_markdown)``.

    Args:
        message: user text; blank input returns the history unchanged.
        history: list of ``{"role", "content"}`` dicts (None treated as empty).
        mode: QUALITY_MODE selects QualityGenerator; anything else SpeedGenerator.
        progress: Gradio progress tracker (injected by Gradio).

    Errors are reported as an assistant message rather than raised.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history, "", ""

    try:
        # Load model (this will run on GPU when GPU is allocated)
        progress(0.05)
        tok, mod, dev = load_model()

        # Create appropriate generator
        generator_cls = QualityGenerator if mode == QUALITY_MODE else SpeedGenerator
        generator = generator_cls(tok, mod, dev)
        progress(0.1)

        def progress_callback(pct, status_msg):
            progress(pct, desc=status_msg)

        response, stats = generator.generate(message, progress_callback)
        progress(1.0)

        perf_info = (
            "**⚡ Performance Stats:**\n"
            f"- **Mode:** {stats['mode']}\n"
            f"- **Generation Time:** {stats['generation_time']:.2f}s\n"
            f"- **Tokens Replaced:** {stats['tokens_replaced']}\n"
            f"- **Speed:** {stats['tokens_per_second']:.1f} tokens/sec\n"
            f"- **Words Generated:** {stats['word_count']} words\n"
            f"- **Words/Second:** {stats['words_per_second']:.1f}\n"
            f"- **Steps:** {stats['steps']}"
        )

        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        return history, "", perf_info

    except Exception as e:
        traceback.print_exc()
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": f"Error: {str(e)}"})
        return history, "", "**❌ Error occurred during generation**"

    finally:
        # Cleanup memory for Zero GPU efficiency
        cleanup_memory()


def clear_chat():
    """Clear chat history and the performance panel, and clean up memory."""
    cleanup_memory()
    return [], ""


def create_interface():
    """Build and return the Gradio Blocks UI (chat column + mode/stats sidebar)."""
    with gr.Blocks(
        title="AR-Diffusion Chat - Experimental Model",
        theme=gr.themes.Soft(),
        css="""
        .warning-box {
            background-color: #fff3cd;
            border: 1px solid #ffeaa7;
            border-radius: 5px;
            padding: 10px;
            margin: 10px 0;
        }
        """,
    ) as interface:
        gr.HTML("""
        <div style="text-align: center;">
            <h1>🧪 AR-Diffusion Chat Interface</h1>
            <div class="warning-box">
                <strong>⚠️ EXPERIMENTAL MODEL ⚠️</strong><br>
                This is an experimental AR-Diffusion model. Results may vary and the model is still under development.
            </div>
            <p><strong>🔥 Powered by Zero GPU with @spaces.GPU</strong></p>
            <p><strong>Model:</strong> rootxhacker/llama-3B-diffusion-exp-fixed (LoRA Adapter)</p>
            <p><strong>🔑 Requires HF_TOKEN for gated model access</strong></p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    height=500,
                    show_label=False,
                    type="messages",
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Type your message here...",
                        show_label=False,
                        scale=9,
                    )
                    send_btn = gr.Button("Send", scale=1, variant="primary")
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")

            with gr.Column(scale=1):
                gr.HTML("""
                <div>
                    <h3>⚙️ Mode Selection</h3>
                    <p><strong>Quality Mode:</strong> Slower but more coherent responses (~40 steps)</p>
                    <p><strong>Speed Mode:</strong> Faster responses with decent quality (~10 steps)</p>
                    <p><strong>🔥 GPU acceleration via @spaces.GPU</strong></p>
                </div>
                """)

                mode = gr.Radio(
                    choices=[QUALITY_MODE, SPEED_MODE],
                    value=QUALITY_MODE,
                    label="Generation Mode",
                )

                # Performance display
                perf_display = gr.Markdown(
                    "**⚡ Performance Stats:** *Generate a message to see stats*",
                    elem_id="performance",
                )

                gr.HTML("""
                <div>
                    <h3>ℹ️ About AR-Diffusion</h3>
                    <p>This experimental model uses autoregressive diffusion for text generation, creating responses by iteratively denoising masked tokens.</p>
                    <p><strong>Model:</strong> LoRA adapter trained for AR-Diffusion</p>
                    <p><strong>Authentication:</strong> Requires HF_TOKEN for gated Llama model access</p>
                    <p><strong>Note:</strong> This model is experimental and may produce unexpected results. If the specific model fails to load, alternative models will be used for demonstration.</p>
                </div>
                """)

        # Event handlers: chat_function is wired directly so Gradio sees its
        # gr.Progress parameter and actually tracks progress.
        send_btn.click(
            chat_function,
            inputs=[msg, chatbot, mode],
            outputs=[chatbot, msg, perf_display],
        )
        msg.submit(
            chat_function,
            inputs=[msg, chatbot, mode],
            outputs=[chatbot, msg, perf_display],
        )
        clear_btn.click(
            clear_chat,
            outputs=[chatbot, perf_display],
        )

    return interface


# Launch interface
if __name__ == "__main__":
    demo = create_interface()
    demo.queue(max_size=20)  # Important for Zero GPU
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )

# Requirements.txt should include:
# torch>=2.0.0
# transformers>=4.30.0
# gradio
# numpy
# accelerate
# spaces
# peft