File size: 6,458 Bytes
1d5b072
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# --- 配置 ---
# 您上传到Hub的仓库ID (基础模型 + LoRA适配器)
hub_repo_id = "yxccai/text-style" 
# Qwen模型的基础模型名称 (与您微调时使用的基础模型一致)
# 例如: "Qwen/Qwen1.5-1.8B-Chat" 或 "Qwen/Qwen1.5-0.5B-Chat"
# 这个信息通常在您的LoRA适配器配置文件 (adapter_config.json) 中的 base_model_name_or_path 字段
# 您需要在这里明确指定它,因为我们要先加载基础模型
base_model_name = "Qwen/Qwen1.5-1.8B-Chat" # 假设您微调的是1.8B版本,请根据实际情况修改

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Gradio App: Using device: {device}")

# --- 加载模型和Tokenizer ---
print(f"Gradio App: Loading base model: {base_model_name}")
# 1. 加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto", # 或者 torch.float16, torch.bfloat16
    # device_map="auto", # 在Spaces中,直接 .to(device) 可能更稳定
    trust_remote_code=True
    # quantization_config=... # 如果基础模型加载时需要量化,这里也要配置
)
base_model.to(device)

print(f"Gradio App: Loading tokenizer from: {hub_repo_id}")
# 2. 加载Tokenizer (从您上传的仓库,它应该包含了基础模型的tokenizer配置)
tokenizer = AutoTokenizer.from_pretrained(hub_repo_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    base_model.config.pad_token_id = tokenizer.eos_token_id


print(f"Gradio App: Loading LoRA adapter from: {hub_repo_id}")
# 3. 加载并应用LoRA适配器
# hub_repo_id 指向的是包含LoRA适配器权重 (adapter_model.bin) 和配置 (adapter_config.json) 的仓库
model = PeftModel.from_pretrained(base_model, hub_repo_id)
# model.to(device) # base_model 已经 to(device) 了,PeftModel会继承

# (可选) 如果希望合并权重以简化,但会占用更多内存/磁盘
# print("Gradio App: Merging LoRA adapter...")
# model = model.merge_and_unload()
# print("Gradio App: LoRA adapter merged.")

model.eval() # 设置为评估模式
print("Gradio App: Model and tokenizer loaded successfully.")

# --- 推理函数 ---
def chat(input_text):
    print(f"Gradio App: Received input: {input_text}")
    # 构建符合Qwen Chat模板的输入
    messages = [
        {"role": "system", "content": "你是一个文本风格转换助手。请严格按照要求,仅将以下书面文本转换为自然、口语化的简洁表达方式,不要添加任何额外的解释、扩展信息或重复原文。"},
        {"role": "user", "content": input_text}
    ]

    # 使用 apply_chat_template
    # 注意:Hugging Face Spaces环境中的transformers版本可能与Colab不同
    # 确保 apply_chat_template 的用法与您测试时一致
    try:
        prompt = tokenizer.apply_chat_template(
            messages, 
            tokenize=False, 
            add_generation_prompt=True # 推理时需要模型知道何时开始生成
        )
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # 回退到一个简单的拼接方式,但这可能不是最优的
        prompt = messages[0]["content"] + "\n" + messages[1]["content"] + "\n" + tokenizer.eos_token # 或者其他适合的格式


    print(f"Gradio App: Formatted prompt for model:\n{prompt}")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device) # 调整max_length

    generated_ids = model.generate(
        **inputs,
        max_new_tokens=2048, # 控制输出长度
        num_beams=1,        # 可以尝试增加
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id
    )

    # 解码生成的token IDs
    # generated_ids[0] 包含了输入提示和模型生成的部分
    full_generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False) # 保留特殊token以帮助分割
    print(f"Gradio App: Full generated sequence:\n{full_generated_text}")

    # 从完整序列中提取assistant的回复
    assistant_marker_start = "<|im_start|>assistant" # Qwen的标记

    if assistant_marker_start in full_generated_text:
        parts = full_generated_text.split(assistant_marker_start)
        if len(parts) > 1:
            assistant_reply = parts[-1].strip()
            # 移除可能的结束标记,如 <|im_end|> 或 eos_token
            if assistant_reply.endswith(tokenizer.eos_token):
                assistant_reply = assistant_reply[:-len(tokenizer.eos_token)].strip()
            elif "<|im_end|>" in assistant_reply: # Qwen的聊天模板使用 <|im_end|>
                 assistant_reply = assistant_reply.split("<|im_end|>")[0].strip()
            result = assistant_reply
        else:
            result = "模型未能生成assistant标记后的回复。"
    else:
        # 如果找不到 assistant 标记,尝试从原始prompt之后提取
        # 这需要原始prompt的token数量
        # 另一种简单方式是直接解码去除特殊token的生成部分,但这可能包含一些模板残留
        result = tokenizer.decode(generated_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True).strip()
        if not result: # 如果这种方式结果为空,可能解码时skip_special_tokens去除了所有
             result = "模型输出格式不符合预期,未能提取有效回复。"

    print(f"Gradio App: Extracted result: {result}")
    return result

# --- 创建Gradio界面 ---
iface = gr.Interface(
    fn=chat,
    inputs=gr.Textbox(lines=5, label="输入书面文本 (Input Formal Text)"),
    outputs=gr.Textbox(lines=5, label="输出口语化文本 (Output Casual Text)"),
    title="文本风格转换器 (Text Style Converter)",
    description="输入一段书面化的中文文本,模型会尝试将其转换为更自然、口语化的表达方式。由Qwen模型微调。",
    examples=[
        ["乙醇的检测方法包括以下几项: 1. 酸碱度检查:取20ml乙醇加20ml水,加2滴酚酞指示剂应无色,再加1ml 0.01mol/L氢氧化钠应显粉红色."],
        ["本公司今日发布了最新的财务业绩报告,数据显示本季度利润实现了显著增长。"]
    ]
)

if __name__ == "__main__":
    iface.launch()