import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Model configuration
MODEL_NAME = "optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune"

MODEL_DESCRIPTION = """
# 🚀 Kimi Linear 48B A3B Instruct - Fine-tuned

A professionally fine-tuned version of Moonshot AI's Kimi-Linear-48B-A3B-Instruct model using QLoRA.

**Model Details:**
- **Base Model:** moonshotai/Kimi-Linear-48B-A3B-Instruct
- **Parameters:** 48 Billion
- **Fine-tuning Method:** QLoRA (Quantized Low-Rank Adaptation)
- **Training Focus:** Attention layers (q_proj, k_proj, v_proj, o_proj)
- **Architecture:** Mixture of Experts (MoE) Transformer
"""
# Check GPU availability
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    total_vram = sum(torch.cuda.get_device_properties(i).total_memory / 1024**3 for i in range(num_gpus))
    logger.info(f"🎮 {num_gpus} GPU(s) detected with {total_vram:.1f}GB total VRAM")
else:
    logger.warning("⚠️ No GPUs detected - running on CPU (will be slow)")
class ModelInference:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.is_loaded = False

    def load_model(self, progress=gr.Progress()):
        """Load the model and tokenizer"""
        if self.is_loaded:
            return "✅ Model already loaded"

        try:
            progress(0.2, desc="Loading tokenizer...")
            logger.info(f"Loading tokenizer from: {MODEL_NAME}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                trust_remote_code=True
            )

            progress(0.4, desc="Loading model (this may take several minutes)...")
            logger.info(f"Loading model from: {MODEL_NAME}")

            # Configure for multi-GPU
            num_gpus = torch.cuda.device_count()
            max_memory = {}
            if num_gpus > 0:
                for i in range(num_gpus):
                    gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
                    max_memory[i] = f"{int(gpu_memory - 3)}GB"
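            # NOTE: the ~3 GB held back per GPU above presumably leaves headroom for
            # activations and the KV cache; adjust the margin if loading still runs out of memory.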
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                max_memory=max_memory if max_memory else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
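            # NOTE (untested alternative): on smaller GPUs the checkpoint could be loaded
            # in 4-bit instead by adding, e.g.,
            #   quantization_config=BitsAndBytesConfig(load_in_4bit=True,
            #                                          bnb_4bit_compute_dtype=torch.bfloat16)
            # to the call above (with `from transformers import BitsAndBytesConfig`).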
            self.model.eval()

            self.is_loaded = True
            progress(1.0, desc="Model loaded!")
            logger.info("✅ Model loaded successfully")

            # Get model info
            total_params = sum(p.numel() for p in self.model.parameters())
            model_size = (total_params * 2) / 1024**3  # bfloat16 = 2 bytes per parameter

            info_msg = f"""
✅ **Model Loaded Successfully!**

**Model Information:**
- Model: `{MODEL_NAME}`
- Parameters: {total_params:,}
- Size: ~{model_size:.1f} GB (bfloat16)
- Device: {"Multi-GPU" if num_gpus > 1 else "Single GPU" if num_gpus == 1 else "CPU"}

**You can now start chatting below!** 🚀
"""
            return info_msg

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            self.is_loaded = False
            return f"❌ **Failed to load model:**\n\n{str(e)}"
    def generate_response(
        self,
        message,
        history,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    ):
        """Generate a response from the model"""
        if not self.is_loaded:
            return "❌ Please load the model first using the 'Load Model' button above."

        try:
            # Build conversation context
            conversation = []

            # Add system prompt if provided
            if system_prompt.strip():
                conversation.append(f"System: {system_prompt.strip()}")

            # Add chat history
            for human, assistant in history:
                conversation.append(f"User: {human}")
                if assistant:
                    conversation.append(f"Assistant: {assistant}")

            # Add current message
            conversation.append(f"User: {message}")
            conversation.append("Assistant:")

            # Format prompt
            prompt = "\n".join(conversation)
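            # NOTE (untested alternative): if the tokenizer ships a chat template, the plain
            # "User:/Assistant:" prompt above could instead be built with, e.g.:
            #   messages = ([{"role": "system", "content": system_prompt}] if system_prompt.strip() else [])
            #   messages += [{"role": "user", "content": message}]  # history turns omitted for brevity
            #   prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)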
            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    do_sample=temperature > 0,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
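            # NOTE (untested alternative): instead of splitting on "Assistant:" below, the prompt
            # can be stripped off by decoding only the newly generated tokens:
            #   new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            #   response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()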
            # Decode response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract the assistant's reply (everything after the last "Assistant:")
            if "Assistant:" in response:
                response = response.split("Assistant:")[-1].strip()

            return response

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}", exc_info=True)
            return f"❌ **Generation failed:**\n\n{str(e)}"
# Initialize inference
inferencer = ModelInference()

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="Kimi 48B Fine-tuned - Inference") as demo:
    gr.Markdown(MODEL_DESCRIPTION)

    # GPU info
    if torch.cuda.is_available():
        gpu_info = f"### 🎮 Hardware: {torch.cuda.device_count()}x {torch.cuda.get_device_name(0)} ({total_vram:.1f}GB total VRAM)"
    else:
        gpu_info = "### ⚠️ Running on CPU (no GPU detected)"
    gr.Markdown(gpu_info)
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
            load_status = gr.Markdown("**Status:** Model not loaded. Click 'Load Model' to start.")

            gr.Markdown("### ⚙️ Generation Settings")
            system_prompt = gr.Textbox(
                label="System Prompt (Optional)",
                placeholder="You are a helpful AI assistant...",
                lines=3,
                value=""
            )
            max_new_tokens = gr.Slider(
                minimum=50,
                maximum=4096,
                value=1024,
                step=1,
                label="Max New Tokens",
                info="Maximum length of the generated response"
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.05,
                label="Temperature",
                info="Higher = more creative, lower = more focused"
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top P (Nucleus Sampling)",
                info="Probability threshold for token selection"
            )
            top_k = gr.Slider(
                minimum=0,
                maximum=100,
                value=50,
                step=1,
                label="Top K",
                info="Number of top tokens to consider (0 = disabled)"
            )
            repetition_penalty = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,
                label="Repetition Penalty",
                info="Penalty for repeating tokens"
            )
        with gr.Column(scale=2):
            gr.Markdown("### 💬 Chat Interface")
            chatbot = gr.Chatbot(
                height=500,
                label="Conversation",
                show_copy_button=True,
                avatar_images=["👤", "🤖"]
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here...",
                    lines=3,
                    scale=4
                )
                send_btn = gr.Button("📤 Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat")
                retry_btn = gr.Button("🔄 Retry Last")

            gr.Markdown("""
### 📝 Usage Tips:
- First, click **"Load Model"** to initialize the model (takes 2-5 minutes)
- Use the **System Prompt** to set the assistant's behavior
- Adjust **Temperature** for creativity (0.7-1.0 recommended)
- Lower **Top P** for more focused responses
- Clear the chat to start a new conversation
""")
    # Event handlers
    load_btn.click(
        fn=inferencer.load_model,
        outputs=load_status
    )

    def user_message(user_msg, history):
        return "", history + [[user_msg, None]]

    def bot_response(history, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
        # Guard against an empty chat (e.g. "Retry Last" right after clearing)
        if not history:
            return history
        user_msg = history[-1][0]
        bot_msg = inferencer.generate_response(
            user_msg,
            history[:-1],
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty
        )
        history[-1][1] = bot_msg
        return history
    # Send message
    msg.submit(
        user_message,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        chatbot
    )

    send_btn.click(
        user_message,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        chatbot
    )

    # Clear chat
    clear_btn.click(lambda: None, None, chatbot, queue=False)

    # Retry last message
    def retry_last(history):
        if history:
            history[-1][1] = None
        return history

    retry_btn.click(
        retry_last,
        chatbot,
        chatbot,
        queue=False
    ).then(
        bot_response,
        [chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        chatbot
    )
    gr.Markdown("""
---
**Model:** [optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune](https://huggingface.co/optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune)
**Base Model:** [moonshotai/Kimi-Linear-48B-A3B-Instruct](https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct)

Fine-tuned with ❤️ using QLoRA
""")
# Launch
if __name__ == "__main__":
    demo.queue(max_size=10)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )