text-style-api / app.py
yxccai's picture
Update app.py
7460c5d verified
# ────────────────────────────────────────────────────────────────────────────────
# app.py (CPU-only 版:先加载 float32 基座 LLaMA-8B,再叠入 LoRA Adapter)
# ────────────────────────────────────────────────────────────────────────────────
import gradio as gr
import torch
import gc
import os
from transformers import AutoTokenizer, LlamaForCausalLM
from peft import PeftModel
# ─────────────────────── 1. 释放可能的显存/内存 ───────────────────────
# 对于 CPU-only,可以留着,也不会报错
torch.cuda.empty_cache()
gc.collect()
# ─────────────────────── 2. 配置区域 ───────────────────────
# (A)Adapter 仓库 ID:LoRA 权重所在的 Hugging Face Repo
# 这个仓库里只有 adapter_model.safetensors + adapter_config.json + tokenizer 文件
ADAPTER_REPO = "yxccai/text-style-converter"
# (B)基座模型 ID(去掉了 -bnb-4bit 后缀,改用 float32 版)
# 原 adapter_config.json 里提到的 "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
# 在 CPU-only 环境下不能加载 4bit bitsandbytes,所以我们要改为:
# "unsloth/deepseek-r1-distill-llama-8b"
# 如果您本地没有这个仓库,可以换成“decapoda-research/llama-7b-hf”或其他您能在 CPU 上跑通的模型。
BASE_MODEL_ID = "unsloth/deepseek-r1-distill-llama-8b"
# 全局变量:Tokenizer + Model
tokenizer = None
model = None
# ─────────────────────── 3. 加载模型的函数 ───────────────────────
def load_model():
"""
CPU-only 逻辑:
1. 先从 Adapter 仓库加载 Tokenizer(里面有 tokenizer.json 等文件)。
2. 再用 LlamaForCausalLM 从 float32 版基座模型加载到 CPU。
3. 然后用 PeftModel.from_pretrained(...) 将 LoRA Adapter 权重叠加到基座上。
"""
global tokenizer, model
# 如果 tokenizer/model 还未加载,则执行加载逻辑
if tokenizer is None or model is None:
try:
# ── 3.1 加载 Tokenizer ──
print("正在加载 Tokenizer(来自 LoRA 仓库)…")
tokenizer = AutoTokenizer.from_pretrained(
ADAPTER_REPO,
trust_remote_code=True,
use_fast=False,
)
# 如果 pad_token 不存在,就用 eos_token 代替
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# ── 3.2 加载基座模型(LLaMA float32 → CPU) ──
print(f"正在加载基座模型:{BASE_MODEL_ID} (float32 → CPU)…")
# 注意:这里用 torch_dtype=torch.float32, device_map="cpu"。如果 Model 太大、内存不足,会 OOM。
base_model = LlamaForCausalLM.from_pretrained(
BASE_MODEL_ID,
torch_dtype=torch.float32,
device_map="cpu",
low_cpu_mem_usage=True, # 尽量启用低内存占用模式
trust_remote_code=True,
)
print("→ 基座模型加载完成。(注意检查是否被系统 OOM)")
# ── 3.3 用 PeftModel 叠加 LoRA Adapter ──
print(f"正在叠加 LoRA Adapter:{ADAPTER_REPO}…")
model = PeftModel.from_pretrained(
base_model,
ADAPTER_REPO,
device_map="cpu", # CPU-only 环境
torch_dtype=torch.float32, # 同样使用 float32
)
print("→ LoRA Adapter 已叠加成功。")
# (可选)不想更新基座所有参数时,把 base_model 的参数都冻结:
# model.eval()
# for param in model.base_model.parameters():
# param.requires_grad = False
except Exception as e:
import traceback
traceback.print_exc()
print(f"模型加载失败: {str(e)}")
return False
return True
# ─────────────────────── 4. 文本生成函数 ───────────────────────
def convert_text_style(input_text: str) -> str:
"""
输入一句书面化/技术性的中文,让模型把它转换成自然、口语化的表达方式。
"""
if not input_text or input_text.strip() == "":
return "请输入要转换的文本。"
# 确保模型已加载
if not load_model():
return "模型加载失败,请稍后重试。"
try:
# 拼一个简单的 Prompt
prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。
### 输入文本:
{input_text}
### 输出文本:
"""
# 分词 & 转 torch.Tensor
inputs = tokenizer(
prompt,
return_tensors="pt",
max_length=1024,
truncation=True,
padding=True,
)
# 全部放到 CPU 上
inputs = {k: v.to("cpu") for k, v in inputs.items()}
# 生成
with torch.no_grad():
outputs = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=256,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
no_repeat_ngram_size=2,
num_return_sequences=1,
)
# 解码并抽取结果
full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
if "### 输出文本:" in full_text:
return full_text.split("### 输出文本:")[-1].strip()
return full_text[len(prompt) :].strip()
except Exception as e:
import traceback
traceback.print_exc()
return f"生成过程中出现错误: {str(e)}"
# ─────────────────────── 5. Gradio 界面配置 ───────────────────────
iface = gr.Interface(
fn=convert_text_style,
inputs=gr.Textbox(
label="输入文本", placeholder="请输入需要转换为口语化的书面文本...", lines=3
),
outputs=gr.Textbox(label="输出文本", lines=4),
title="中文文本风格转换API",
description="将书面化、技术性文本转换为自然、口语化表达",
examples=[
["乙醇的检测方法包括酸碱度检查。"],
["本品为薄膜衣片,除去包衣后显橙红色。"],
],
cache_examples=False,
flagging_mode="never",
)
if __name__ == "__main__":
print("启动 Gradio 应用…")
# 纯 CPU 环境下,server_name 可以保持默认 "0.0.0.0",port 也是 7860
iface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=False)