怀羽 committed
Commit a67e7e4 · 1 Parent(s): 4dca14c

change to hf decode

Files changed (2):
  1. app.py +105 -29
  2. requirements.txt +3 -2
app.py CHANGED
@@ -1,30 +1,80 @@
 import gradio as gr
-from vllm import LLM, SamplingParams
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+import sys
+import os
 
 # --------------------------------------------------------------------------
 # 1. Configure and load the model (runs once at app startup)
 # --------------------------------------------------------------------------
 
-# !! IMPORTANT !! -> replace "your-org/your-algharb-model" here with your model ID on the Hugging Face Hub
+# Make sure this is your local model path
+# model_id = "/mnt/workspace/wanghao/model_saved/Marco-MT-WMT"
 model_id = "AIDC-AI/Marco-MT-Algharb"
+# Add the model directory to the Python path (fixes the Qwen3ForCausalLM import issue)
+if os.path.isdir(model_id):
+    sys.path.insert(0, model_id)
+    print(f"Added model directory to sys.path: {model_id}")
 
-print(f"Loading model: {model_id}...")
+print(f"Loading tokenizer: {model_id}...")
+tokenizer = None
+model = None
+device = "cuda"
 try:
-    llm = LLM(model=model_id)
-    print("Model loaded successfully!")
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_id,
+        trust_remote_code=True
+    )
+    print("Tokenizer loaded successfully!")
 except Exception as e:
-    print(f"Model loading failed: {e}")
-    llm = None  # flag that the model failed to load
-
-# Define the sampling parameters
-sampling_params = SamplingParams(
-    n=1,
-    temperature=0.001,
-    top_p=0.001,
-    max_tokens=512,
-)
-
-# Map language codes to full names
+    print(f"Tokenizer loading failed: {e}")
+
+if tokenizer:
+    print(f"Loading model: {model_id}...")
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            trust_remote_code=True
+        ).to(device).eval()
+
+        print("Model loaded successfully!")
+    except Exception as e:
+        print(f"Model loading failed: {e}")
+        model = None
+else:
+    print("Skipping model loading because the tokenizer failed to load.")
+    model = None
+
+# --- ★★★ Key fix: set Qwen's stop tokens correctly ★★★ ---
+if tokenizer:
+    # 1. Get the ID of <|im_end|> (usually 151645)
+    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+    # 2. Get the ID of <|endoftext|> (usually 151643)
+    eot_id = tokenizer.eos_token_id
+
+    print(f"Stop token IDs: <|im_end|>={im_end_id}, <|endoftext|>={eot_id}")
+
+    # 3. Create the GenerationConfig
+    generation_config = GenerationConfig(
+        do_sample=False,
+        max_new_tokens=512,
+
+        # Key (1): tell generate() to stop at either of *these two* tokens
+        eos_token_id=[im_end_id, eot_id],
+
+        # Key (2): tell generate() which token to use for padding when batching
+        # (we use <|endoftext|>)
+        pad_token_id=eot_id
+    )
+else:
+    # Fallback config in case the tokenizer fails to load
+    generation_config = GenerationConfig(
+        do_sample=False,
+        max_new_tokens=512
+    )
+
+# Map language codes to full names (unchanged)
 source_lang_name_map = {
     "en": "english",
     "ja": "japanese",
@@ -46,14 +96,14 @@ target_lang_name_map = {
     "de": "german",
 }
 # --------------------------------------------------------------------------
-# 2. Define the core translation function
+# 2. Define the core translation function (revised)
 # --------------------------------------------------------------------------
 def translate(source_text, source_lang_code, target_lang_code):
     """
-    Take the user input and return the translation
+    Take the user input and return the translation (using Transformers)
     """
-    if llm is None:
-        return "Error: the model failed to load. Please check the Space logs."
+    if model is None or tokenizer is None:
+        return "Error: the model or tokenizer failed to load. Please check the Space logs."
 
     # Basic input validation
     if not source_text or not source_text.strip():
@@ -62,23 +112,50 @@ def translate(source_text, source_lang_code, target_lang_code):
     source_language_name = source_lang_name_map.get(source_lang_code, "the source language")
     target_language_name = target_lang_name_map.get(target_lang_code, "the target language")
 
+    # Build the same prompt as the vLLM version
     prompt = (
         f"Human: Please translate the following text into {target_language_name}: \n"
         f"{source_text}<|im_end|>\n"
         f"Assistant:"
     )
+    print("--- Prompt ---")
     print(prompt)
-    outputs = llm.generate([prompt], sampling_params)
-
-    generated_text = outputs[0].outputs[0].text.strip()
-
-    return generated_text
+    print("--------------")
+
+    try:
+        # 1. Tokenize
+        # the CausalLM takes the whole "Human: ... Assistant:" string as input
+        inputs = tokenizer(prompt, return_tensors="pt")
+
+        # 2. Move the input tensors to the model's device
+        # (with device_map="auto", model.device points at the first device)
+        inputs = inputs.to(model.device)
+
+        # 3. Generate
+        with torch.no_grad():  # no gradients are needed for inference
+            outputs = model.generate(
+                **inputs,
+                generation_config=generation_config
+            )
+
+        # 4. Decode
+        # outputs[0] contains "input_ids + generated_ids",
+        # so decoding must start right after the input_ids
+        input_length = inputs.input_ids.shape[1]
+        generated_ids = outputs[0][input_length:]
+        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+
+        return generated_text
+
+    except Exception as e:
+        print(f"Error during translation: {e}")
+        return f"An error occurred during translation: {e}"
 
 # --------------------------------------------------------------------------
-# 3. Create and configure the Gradio interface (revised)
+# 3. Create and configure the Gradio interface (unchanged)
 # --------------------------------------------------------------------------
 
-# <--- Change 1: define the custom CSS styles (refined background + normal font) --->
+# <--- Define the custom CSS styles --->
 css = """
 /* --- 1. Overall background (a more refined light gray-blue gradient) --- */
 .gradio-container {
@@ -188,7 +265,6 @@ with gr.Blocks(
     )
 
     # --- (new position) supported language-pair card ---
-    # <--- Change 3: this HTML automatically inherits the new global font --->
     gr.HTML(f"""
     <div style="color: #444; font-size: 16px; margin-top: 30px; padding: 20px 25px; background-color: #FFFFFF; border-radius: 15px; max-width: 900px; margin-left: auto; margin-right: auto; box-shadow: 0 4px 20px rgba(0,0,0,0.05);">
 
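Net effect of the app.py change: vLLM's near-greedy sampling (temperature=0.001, top_p=0.001, n=1) is replaced by plain greedy decoding (do_sample=False), generation now stops at <|im_end|> as well as <|endoftext|>, and the prompt tokens are sliced off before decoding so that, like vLLM's completion-only output, only the Assistant's reply is returned. Below is a minimal, self-contained sketch of that decode pattern; the example sentence and the en→de direction are made up for illustration, everything else mirrors the diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "AIDC-AI/Marco-MT-Algharb"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()

# Same prompt template as translate(); the sentence itself is arbitrary.
prompt = (
    "Human: Please translate the following text into german: \n"
    "The weather is nice today.<|im_end|>\n"
    "Assistant:"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

config = GenerationConfig(
    do_sample=False,  # greedy decoding, standing in for vLLM's temperature=0.001
    max_new_tokens=512,
    # Stop at either <|im_end|> or <|endoftext|>, as in the commit.
    eos_token_id=[tokenizer.convert_tokens_to_ids("<|im_end|>"), tokenizer.eos_token_id],
    pad_token_id=tokenizer.eos_token_id,
)

with torch.no_grad():
    outputs = model.generate(**inputs, generation_config=config)

# generate() returns prompt + completion; drop the prompt tokens before decoding.
completion_ids = outputs[0][inputs.input_ids.shape[1]:]
print(tokenizer.decode(completion_ids, skip_special_tokens=True).strip())
```

In the app this same path is exercised by a call such as translate("The weather is nice today.", "en", "de"), which returns just the Assistant's reply.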
 
requirements.txt CHANGED
@@ -1,2 +1,3 @@
-vllm==0.10.0
-gradio==5.49.1
+Transformers==4.55.0
+gradio==5.49.1
+tomli
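With vllm dropped, transformers==4.55.0 supplies AutoTokenizer and AutoModelForCausalLM (pip package names are case-insensitive, so the Transformers spelling resolves to the same package). A quick, hypothetical smoke test of the pinned environment, not part of the commit:

```python
# Hypothetical version check for the pinned environment.
import transformers
import gradio

assert transformers.__version__ == "4.55.0"
assert gradio.__version__ == "5.49.1"
print("pinned versions OK")
```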