"""Gradio demo for interactive multimodal (text + image) inference."""

import os

# Must be set before importing `transformers`, otherwise the Hugging Face
# hub endpoint has already been read and the mirror setting is ignored.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import argparse
import json
from pathlib import Path
from typing import Optional

import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer

from model import MultiModalDenseTransformer
from continual_learning import UnifiedMultiModalPreprocessor

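# Standard ImageNet preprocessing: the 224x224 resize and the mean/std
# statistics below match what common torchvision vision backbones expect.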
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


class ModelInference:
    def __init__(self, checkpoint_path: str, tokenizer_name: str,
                 config_path: Optional[str] = None,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)
        print(f"Using device: {self.device}")
        print(f"Loading tokenizer: {tokenizer_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, trust_remote_code=True)
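        # Some causal LM tokenizers (Qwen included) ship without a pad token;
        # fall back to EOS so padding and generation do not fail.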
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

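        # Fallback hyperparameters used when no config file is supplied;
        # they must match the architecture of the checkpoint being loaded.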
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,
                'max_seq_len': 512,
                'dropout': 0.0,
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn",
                'use_multimodal_fusion': False,
                'use_contrastive': False
            }

        print("Initializing model architecture...")
        self.model = MultiModalDenseTransformer(**self.config)
        self.preprocessor = UnifiedMultiModalPreprocessor(model_dim=self.config['model_dim'])

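        # Checkpoints may be a raw state dict or a training dict that wraps
        # the weights under 'model_state_dict'; accept both layouts.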
| | print(f"Loading checkpoint from {checkpoint_path}...") |
| | checkpoint = torch.load(checkpoint_path, map_location=self.device) |
| | state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint |
| |
|
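        # DistributedDataParallel prefixes every parameter name with
        # 'module.'; strip it so the keys match a single-process model.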
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[7:]] = v
            else:
                new_state_dict[k] = v

        missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
        if missing:
            print(f"Warning: Missing keys: {len(missing)}")
        if unexpected:
            print(f"Warning: Unexpected keys: {len(unexpected)}")

        self.model.to(self.device)
        self.preprocessor.to(self.device)
        self.model.eval()
        print("Model loaded successfully!")
        print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")

    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int = 128, temperature: float = 0.7,
                      top_k: int = 10, top_p: float = 0.9, repetition_penalty: float = 1.2,
                      image: Optional[Image.Image] = None) -> str:
        formatted_prompt = f"Instruction: {prompt}\nResponse:"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)

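        # The model consumes a list of modality "segments": image embeddings
        # (when an image is supplied) come first, followed by the tokenized
        # text segment with modality_id 0.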
        input_data = {'segments': []}
        if image is not None:
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = image_transform(image).unsqueeze(0).to(self.device)
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                for seg in mod_segments:
                    input_data['segments'].append(seg)
            except Exception as e:
                print(f"Warning: Image processing skipped due to error: {e}")

        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })

        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            if "Response:" in full_output:
                answer = full_output.split("Response:")[-1].strip()
            else:
                answer = full_output

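            # Heuristic cleanup: truncate at common prompt-template markers so
            # the model does not run on into a fabricated next turn.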
            stop_words = ["Instruction", "Input", "###", "Response", "User:", "Assistant:", "\n\n"]
            for sw in stop_words:
                if sw in answer:
                    answer = answer.split(sw)[0].strip()

            # Drop the first line if the model merely echoed the prompt.
            lines = answer.split('\n')
            if len(lines) > 0 and prompt.lower() in lines[0].lower():
                answer = "\n".join(lines[1:]).strip()
            return answer
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"Error: {e}"


def build_ui(model_instance):
    with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
    .gradio-container { max-width: 900px; margin: auto; }
    """) as demo:
        gr.Markdown("## Multimodal Online Inference (Text + Image)")
        with gr.Row():
            with gr.Column(scale=3):
                txt = gr.Textbox(label="Prompt (Instruction)", placeholder="Enter an instruction or question...", lines=5)
                img = gr.Image(type="pil", label="(Optional) Upload an image for multimodal input")
                btn = gr.Button("Generate")
            with gr.Column(scale=2):
                max_tokens = gr.Slider(label="Max New Tokens", minimum=16, maximum=1024, step=1, value=128)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.01, value=0.7)
                top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, step=1, value=40)
                top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
                rep_pen = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.01, value=1.1)
        status = gr.Textbox(label="Status", value="Ready", interactive=False)
        output = gr.Textbox(label="Output", lines=12, interactive=False)

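        # Click handler: Gradio delivers slider values as floats, so cast to
        # the types generate_text expects before forwarding.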
        def gr_generate(prompt, image, max_tokens_v, temp_v, topk_v, topp_v, rep_v):
            if not prompt or str(prompt).strip() == "":
                return "", "Please enter a prompt"
            out = model_instance.generate_text(prompt=prompt,
                                               max_new_tokens=int(max_tokens_v),
                                               temperature=float(temp_v),
                                               top_k=int(topk_v),
                                               top_p=float(topp_v),
                                               repetition_penalty=float(rep_v),
                                               image=image)
            return out, "Done"

        btn.click(fn=gr_generate,
                  inputs=[txt, img, max_tokens, temperature, top_k, top_p, rep_pen],
                  outputs=[output, status])

    # Launching is left to the caller so the port/share settings apply;
    # calling demo.launch() here as well would start the app twice.
    return demo


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", type=lambda x: x.lower() in ("true", "1", "yes"), default=True)
    args = parser.parse_args()

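    # If the default post-training checkpoint is absent, fall back to the
    # most recently written pretraining checkpoint.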
    if not Path(args.checkpoint).exists():
        # glob() returns paths in arbitrary order, so sort by modification
        # time instead of taking an unsorted last element.
        possible = sorted(Path("checkpoints/pretrain").glob("step_*.pt"),
                          key=lambda p: p.stat().st_mtime)
        if possible:
            args.checkpoint = str(possible[-1])
            print(f"final_model.pt not found; using latest checkpoint: {args.checkpoint}")
        else:
            raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")

    model_instance = ModelInference(args.checkpoint, args.tokenizer, args.config)

    demo = build_ui(model_instance)
    demo.launch(server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()
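
# Example invocation (the script name and paths shown are illustrative):
#   python gradio_infer.py --checkpoint checkpoints/posttrain/final_model.pt \
#       --port 7860 --share false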