# inference_app.py: Gradio inference UI for the fine-tuned Kimi Linear 48B A3B Instruct model
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Model configuration
MODEL_NAME = "optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune"
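# The fine-tune described below targeted the attention projections with QLoRA. For context, such a
# setup is typically expressed with a peft.LoraConfig along these lines (illustrative sketch only,
# not the exact training configuration of this checkpoint):
#   LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05,
#              target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
#              task_type="CAUSAL_LM")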
MODEL_DESCRIPTION = """
# 🚀 Kimi Linear 48B A3B Instruct - Fine-tuned
A professionally fine-tuned version of Moonshot AI's Kimi-Linear-48B-A3B-Instruct model using QLoRA.
**Model Details:**
- **Base Model:** moonshotai/Kimi-Linear-48B-A3B-Instruct
- **Parameters:** 48 billion total (about 3 billion active per token)
- **Fine-tuning Method:** QLoRA (Quantized Low-Rank Adaptation)
- **Training Focus:** Attention layers (q_proj, k_proj, v_proj, o_proj)
- **Architecture:** Mixture of Experts (MoE) Transformer
"""
# Check GPU availability
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
total_vram = sum(torch.cuda.get_device_properties(i).total_memory / 1024**3 for i in range(num_gpus))
    logger.info(f"🎮 {num_gpus} GPU(s) detected with {total_vram:.1f}GB total VRAM")
else:
    logger.warning("⚠️ No GPUs detected - running on CPU (will be slow)")
class ModelInference:
def __init__(self):
self.model = None
self.tokenizer = None
self.is_loaded = False
def load_model(self, progress=gr.Progress()):
"""Load the model and tokenizer"""
if self.is_loaded:
            return "✅ Model already loaded"
try:
progress(0.2, desc="Loading tokenizer...")
logger.info(f"Loading tokenizer from: {MODEL_NAME}")
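            # trust_remote_code=True allows transformers to run the custom modeling/tokenizer code
            # that the Kimi Linear checkpoints ship on the Hub.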
self.tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
trust_remote_code=True
)
progress(0.4, desc="Loading model (this may take several minutes)...")
logger.info(f"Loading model from: {MODEL_NAME}")
# Configure for multi-GPU
num_gpus = torch.cuda.device_count()
max_memory = {}
if num_gpus > 0:
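                # Reserve ~3 GB of headroom per GPU so the CUDA context, activations, and KV cache
                # fit alongside the sharded weights.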
for i in range(num_gpus):
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
max_memory[i] = f"{int(gpu_memory - 3)}GB"
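            # device_map="auto" lets accelerate shard layers across the visible GPUs (spilling to
            # CPU if necessary) while respecting the per-device max_memory caps computed above.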
self.model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
device_map="auto",
max_memory=max_memory if max_memory else None,
trust_remote_code=True,
low_cpu_mem_usage=True,
)
self.model.eval()
self.is_loaded = True
progress(1.0, desc="Model loaded!")
            logger.info("✅ Model loaded successfully")
# Get model info
total_params = sum(p.numel() for p in self.model.parameters())
model_size = (total_params * 2) / 1024**3 # bfloat16 = 2 bytes
info_msg = f"""
✅ **Model Loaded Successfully!**
**Model Information:**
- Model: `{MODEL_NAME}`
- Parameters: {total_params:,}
- Size: ~{model_size:.1f} GB (bfloat16)
- Device: {"Multi-GPU" if num_gpus > 1 else "Single GPU" if num_gpus == 1 else "CPU"}
**You can now start chatting below!** 👇
"""
return info_msg
except Exception as e:
logger.error(f"Failed to load model: {str(e)}", exc_info=True)
self.is_loaded = False
            return f"❌ **Failed to load model:**\n\n{str(e)}"
def generate_response(
self,
message,
history,
system_prompt,
max_new_tokens,
temperature,
top_p,
top_k,
repetition_penalty,
):
"""Generate a response from the model"""
if not self.is_loaded:
            return "❌ Please load the model first using the 'Load Model' button above."
try:
# Build conversation context
conversation = []
# Add system prompt if provided
if system_prompt.strip():
conversation.append(f"System: {system_prompt.strip()}")
# Add chat history
for human, assistant in history:
conversation.append(f"User: {human}")
if assistant:
conversation.append(f"Assistant: {assistant}")
# Add current message
conversation.append(f"User: {message}")
conversation.append("Assistant:")
# Format prompt
prompt = "\n".join(conversation)
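            # This plain "System:/User:/Assistant:" transcript is a simple generic prompt format; if
            # the tokenizer ships a chat template, tokenizer.apply_chat_template(messages,
            # tokenize=False, add_generation_prompt=True) on a list of role/content dicts would be
            # the more faithful option.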
# Tokenize
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
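                    # Greedy decoding when temperature is 0; otherwise sample with the settings above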
                    do_sample=temperature > 0,
pad_token_id=self.tokenizer.eos_token_id,
)
            # Decode only the newly generated tokens (skip the echoed prompt)
            generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
            # If the model starts writing the next "User:" turn, truncate it
            if "User:" in response:
                response = response.split("User:")[0].strip()
            return response
except Exception as e:
logger.error(f"Generation failed: {str(e)}", exc_info=True)
            return f"❌ **Generation failed:**\n\n{str(e)}"
# Initialize inference
inferencer = ModelInference()
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="Kimi 48B Fine-tuned - Inference") as demo:
gr.Markdown(MODEL_DESCRIPTION)
# GPU Info
if torch.cuda.is_available():
        gpu_info = f"### 🎮 Hardware: {torch.cuda.device_count()}x {torch.cuda.get_device_name(0)} ({total_vram:.1f}GB total VRAM)"
else:
        gpu_info = "### ⚠️ Running on CPU (no GPU detected)"
gr.Markdown(gpu_info)
gr.Markdown("---")
with gr.Row():
with gr.Column(scale=1):
            load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
load_status = gr.Markdown("**Status:** Model not loaded. Click 'Load Model' to start.")
            gr.Markdown("### ⚙️ Generation Settings")
system_prompt = gr.Textbox(
label="System Prompt (Optional)",
placeholder="You are a helpful AI assistant...",
lines=3,
value=""
)
max_new_tokens = gr.Slider(
minimum=50,
maximum=4096,
value=1024,
step=1,
label="Max New Tokens",
info="Maximum length of generated response"
)
temperature = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.7,
step=0.05,
label="Temperature",
info="Higher = more creative, Lower = more focused"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.9,
step=0.05,
label="Top P (Nucleus Sampling)",
info="Probability threshold for token selection"
)
top_k = gr.Slider(
minimum=0,
maximum=100,
value=50,
step=1,
label="Top K",
info="Number of top tokens to consider (0 = disabled)"
)
repetition_penalty = gr.Slider(
minimum=1.0,
maximum=2.0,
value=1.1,
step=0.05,
label="Repetition Penalty",
info="Penalty for repeating tokens"
)
with gr.Column(scale=2):
            gr.Markdown("### 💬 Chat Interface")
chatbot = gr.Chatbot(
height=500,
label="Conversation",
show_copy_button=True,
)
with gr.Row():
msg = gr.Textbox(
label="Your Message",
placeholder="Type your message here...",
lines=3,
scale=4
)
                send_btn = gr.Button("📤 Send", variant="primary", scale=1)
with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat")
                retry_btn = gr.Button("🔄 Retry Last")
gr.Markdown("""
### ๐Ÿ“ Usage Tips:
- First, click **"Load Model"** to initialize the model (takes 2-5 minutes)
- Use the **System Prompt** to set the assistant's behavior
- Adjust **Temperature** for creativity (0.7-1.0 recommended)
- Lower **Top P** for more focused responses
- Clear chat to start a new conversation
""")
# Event handlers
load_btn.click(
fn=inferencer.load_model,
outputs=load_status
)
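    # Two-step chat flow: user_message appends the user's turn to the chat history immediately,
    # then bot_response fills in the model's reply for that last turn.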
def user_message(user_msg, history):
return "", history + [[user_msg, None]]
    def bot_response(history, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
        # Guard against an empty chat (e.g. retrying before any message was sent)
        if not history:
            return history
        user_msg = history[-1][0]
bot_msg = inferencer.generate_response(
user_msg,
history[:-1],
system_prompt,
max_new_tokens,
temperature,
top_p,
top_k,
repetition_penalty
)
history[-1][1] = bot_msg
return history
# Send message
msg.submit(
user_message,
[msg, chatbot],
[msg, chatbot],
queue=False
).then(
bot_response,
[chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
chatbot
)
send_btn.click(
user_message,
[msg, chatbot],
[msg, chatbot],
queue=False
).then(
bot_response,
[chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
chatbot
)
# Clear chat
clear_btn.click(lambda: None, None, chatbot, queue=False)
# Retry last message
def retry_last(history):
if history:
history[-1][1] = None
return history
retry_btn.click(
retry_last,
chatbot,
chatbot,
queue=False
).then(
bot_response,
[chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
chatbot
)
gr.Markdown("""
---
**Model:** [optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune](https://huggingface.co/optiviseapp/kimi-linear-48b-a3b-instruct-fine-tune)
**Base Model:** [moonshotai/Kimi-Linear-48B-A3B-Instruct](https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct)
Fine-tuned with ❤️ using QLoRA
""")
# Launch
if __name__ == "__main__":
demo.queue(max_size=10)
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)
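# To run locally (assuming torch, transformers, accelerate, and gradio are installed):
#   python inference_app.py
# then open http://localhost:7860 in a browser. On Hugging Face Spaces the app starts automatically.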