# app.py
"""Gradio demo for DeepSeek-VL-7B-Chat loaded in 4-bit (float16 compute)."""
import logging

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM

logger = logging.getLogger(__name__)

# ==== Model configuration ====
model_path = "deepseek-ai/deepseek-vl-7b-chat"

# Upper bound on the number of tokens generated for a single answer.
MAX_NEW_TOKENS = 128

# BitsAndBytes 4-bit quantization settings.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load the processor and its tokenizer.
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Tag that marks where an image goes in the prompt. The processor expands each
# tag into image-token slots; without it the image embeddings have nowhere to go.
IMAGE_TAG = getattr(vl_chat_processor, "image_tag", "<image_placeholder>")

# Load the model.
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
).eval()

# ==== Conversation state ====
# NOTE(review): module-level state is shared by every browser session that is
# connected to this app; switch to gr.State if per-user history is required.
# Display history for gr.Chatbot: (user_message, answer) pairs.
chat_history = []
# Structured turns in the format VLChatProcessor consumes
# ({"role": ..., "content": ..., "images": [...]} dicts), kept in step with
# chat_history but never shown directly.
conversation_history = []


def _make_user_turn(image, user_message):
    """Build a "User" turn, prefixing the image tag when an image is attached."""
    if image is None:
        return {"role": "User", "content": user_message, "images": []}
    return {"role": "User", "content": IMAGE_TAG + user_message, "images": [image]}


def _prepare_model_inputs(conversation):
    """Tokenize *conversation* and return its tensors as a kwargs dict on the model device.

    Every image referenced by any turn is passed to the processor in turn order,
    so the count matches the image tags in the prompt. Only floating-point
    tensors (pixel values) are cast to float16. Integer ids and boolean masks
    keep their dtype because prepare_inputs_embeds uses the masks for boolean
    indexing.
    """
    images = [img for turn in conversation for img in turn.get("images", [])]
    batch = vl_chat_processor(
        conversations=conversation,
        images=images,
        force_batchify=True,
    ).to(vl_gpt.device)

    inputs = {}
    for name in batch.__dataclass_fields__:
        value = getattr(batch, name)
        if torch.is_tensor(value) and value.is_floating_point():
            value = value.to(torch.float16)
        inputs[name] = value
    return inputs


# ==== Text + image inference ====
def chat_with_image(image, user_message):
    """Answer *user_message* about *image*, continuing the current conversation.

    Args:
        image: Image from the upload widget, or None when no image is attached.
        user_message: The user's question; None is treated as an empty string.

    Returns:
        ``(answer, chat_history)`` on success. If any step fails it returns
        ``("Error: <message>", chat_history)``, and neither history is modified.
    """
    user_message = user_message or ""
    try:
        user_turn = _make_user_turn(image, user_message)
        # Earlier turns give the model context. The empty Assistant turn marks
        # where generation should start.
        conversation = conversation_history + [
            user_turn,
            {"role": "Assistant", "content": ""},
        ]

        inputs = _prepare_model_inputs(conversation)
        # Pure inference: avoid building an autograd graph through the vision tower.
        with torch.no_grad():
            inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)
            outputs = vl_gpt.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=inputs["attention_mask"],
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                use_cache=True,
            )

        # Decode the generated tokens.
        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure in the answer box instead of
        # crashing the app, but keep the traceback in the server log.
        logger.exception("chat_with_image failed")
        return f"Error: {str(e)}", chat_history

    # Update both histories only after a successful generation.
    conversation_history.extend([user_turn, {"role": "Assistant", "content": answer}])
    chat_history.append((user_message, answer))
    return answer, chat_history


def reset_chat():
    """Clear both histories, then blank the answer box and the chat display."""
    global chat_history, conversation_history
    chat_history = []
    conversation_history = []
    return "", []


# ==== Gradio Web UI ====
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-VL-7B-Chat Demo (4-bit, float16)")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(lines=2, placeholder="Ask about the image...")

    with gr.Row():
        submit_btn = gr.Button("Submit")
        reset_btn = gr.Button("Reset Chat")

    output_text = gr.Textbox(label="Answer")
    chat_display = gr.Chatbot(label="Chat History")

    submit_btn.click(
        chat_with_image,
        inputs=[image_input, text_input],
        outputs=[output_text, chat_display],
    )
    reset_btn.click(reset_chat, inputs=[], outputs=[output_text, chat_display])

if __name__ == "__main__":
    demo.launch()