Sixparticle committed on
Commit
e20ba09
·
1 Parent(s): 83d6a1c

Remove added_tokens file before tokenizer init

Browse files
Files changed (2) hide show
  1. app.py +20 -3
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,16 +1,33 @@
1
  import gradio as gr
 
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, RobertaTokenizer
3
  import torch
4
 
5
  # 加载 CodeT5+ 模型
6
  model_name = "Salesforce/codet5p-220m"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  try:
8
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=False)
9
  except Exception:
10
  # Fallback to explicit slow tokenizer class to bypass tokenizers fast-path issues.
11
- tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=False)
12
 
13
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=False)
14
 
15
  def generate_code(prompt: str, max_length: int = 128) -> str:
16
  """代码生成/补全"""
 
1
import json
import os

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, RobertaTokenizer
6
 
7
  # 加载 CodeT5+ 模型
8
  model_name = "Salesforce/codet5p-220m"
9
+
10
+
11
+ def prepare_local_model(repo_id: str, local_dir: str = "./model_cache") -> str:
12
+ snapshot_download(repo_id=repo_id, local_dir=local_dir)
13
+
14
+ # Work around a transformers/tokenizers incompatibility for this repo.
15
+ # Its added_tokens.json is an empty dict, which can crash tokenizer init in some versions.
16
+ added_tokens_file = os.path.join(local_dir, "added_tokens.json")
17
+ if os.path.exists(added_tokens_file):
18
+ os.remove(added_tokens_file)
19
+
20
+ return local_dir
21
+
22
+
23
+ local_model_dir = prepare_local_model(model_name)
24
  try:
25
+ tokenizer = AutoTokenizer.from_pretrained(local_model_dir, use_fast=False, trust_remote_code=False)
26
  except Exception:
27
  # Fallback to explicit slow tokenizer class to bypass tokenizers fast-path issues.
28
+ tokenizer = RobertaTokenizer.from_pretrained(local_model_dir, trust_remote_code=False)
29
 
30
+ model = AutoModelForSeq2SeqLM.from_pretrained(local_model_dir, trust_remote_code=False)
31
 
32
  def generate_code(prompt: str, max_length: int = 128) -> str:
33
  """代码生成/补全"""
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  transformers>=4.40.0
 
2
  torch>=2.0.0
3
  sentencepiece>=0.1.96
4
  accelerate>=0.20.0
 
1
  transformers>=4.40.0
2
+ huggingface_hub>=0.23.0
3
  torch>=2.0.0
4
  sentencepiece>=0.1.96
5
  accelerate>=0.20.0