import spaces  # 必须在最顶部导入
import gradio as gr
import os

# 获取 Hugging Face 访问令牌
hf_token = os.getenv("HF_API_TOKEN")

# 定义基础模型名称
base_model_name = "WooWoof/WooWoof_AI_Vision16Bit"

# 定义 adapter 模型名称
adapter_model_name = "WooWoof/WooWoof_AI_Vision16Bit"

# 定义全局变量用于缓存模型和分词器
model = None
tokenizer = None

# 定义提示生成函数
def generate_prompt(instruction, input_text=""):
    if input_text:
        prompt = f"""### Instruction:
{instruction}
### Input:
{input_text}
### Response:
"""
    else:
        prompt = f"""### Instruction:
{instruction}
### Response:
"""
    return prompt

# 定义生成响应的函数,并使用 @spaces.GPU 装饰
@spaces.GPU(duration=40)  # 建议将 duration 增加到 120
def generate_response(instruction, input_text):
    global model, tokenizer

    if model is None:
        print("开始加载模型...")
        # 检查 bitsandbytes 是否已安装
        import importlib.util
        if importlib.util.find_spec("bitsandbytes") is None:
            import subprocess
            subprocess.call(["pip", "install", "--upgrade", "bitsandbytes"])

        try:
            # 在函数内部导入需要 GPU 的库
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

            from peft import PeftModel

            # 创建量化配置
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )

            # 加载分词器
            tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_auth_token=hf_token)
            print("分词器加载成功。")

            # 加载基础模型
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                quantization_config=bnb_config,
                device_map="auto",
                use_auth_token=hf_token,
                trust_remote_code=True
            )
            print("基础模型加载成功。")

            # 加载适配器模型
            model = PeftModel.from_pretrained(
                base_model,
                 adapter_model_name,
                torch_dtype=torch.float16,
                use_auth_token=hf_token
            )
            print("适配器模型加载成功。")

            # 设置 pad_token
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = tokenizer.pad_token_id

            # 切换到评估模式
            model.eval()
            print("模型已切换到评估模式。")
        except Exception as e:
            print("加载模型时出错:", e)
            raise e
    else:
        # 在函数内部导入需要的库
        import torch

    # 检查 model 和 tokenizer 是否已正确加载
    if model is None or tokenizer is None:
        print("模型或分词器未正确加载。")
        raise ValueError("模型或分词器未正确加载。")

    # 生成提示
    prompt = generate_prompt(instruction, input_text)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("### Response:")[-1].strip()
    return response

# 创建 Gradio 接口
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(lines=2, placeholder="Instruction", label="Instruction"),
    ],
    outputs="text",
    title="WooWoof AI Vision",
    allow_flagging="never"
)

# 启动 Gradio 接口
iface.launch(share=True)