Sixparticle committed on
Commit
83d6a1c
·
1 Parent(s): b51e859

Use slow Roberta tokenizer for CodeT5

Browse files
Files changed (2) hide show
  1. app.py +7 -30
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,39 +1,16 @@
1
  import gradio as gr
2
- import json
3
- import os
4
- from huggingface_hub import snapshot_download
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import torch
7
 
8
  # 加载 CodeT5+ 模型
9
  model_name = "Salesforce/codet5p-220m"
 
 
 
 
 
10
 
11
-
12
def prepare_local_model(repo_id: str, local_dir: str = "./model_cache") -> str:
    """Download a model snapshot and normalize its ``added_tokens.json``.

    Some CodeT5 checkpoints ship ``added_tokens.json`` as a ``{token: id}``
    mapping, which ``tokenizers.add_tokens`` cannot consume; this rewrites
    the file as a plain list of token strings.

    Args:
        repo_id: Hugging Face Hub repository id (e.g. "Salesforce/codet5p-220m").
        local_dir: Directory the snapshot is downloaded into.

    Returns:
        The local directory containing the prepared snapshot (``local_dir``).
    """
    snapshot_download(repo_id=repo_id, local_dir=local_dir)

    added_tokens_file = os.path.join(local_dir, "added_tokens.json")
    if os.path.exists(added_tokens_file):
        with open(added_tokens_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Ensure the file is a plain token list for compatibility with tokenizers.add_tokens.
        if isinstance(data, dict):
            normalized = list(data.keys())
        elif isinstance(data, list):
            normalized = [str(item) for item in data]
        else:
            # Unexpected shape (e.g. scalar): fall back to an empty token list.
            normalized = []

        with open(added_tokens_file, "w", encoding="utf-8") as f:
            json.dump(normalized, f, ensure_ascii=False)

    return local_dir


local_model_dir = prepare_local_model(model_name)
# NOTE(review): trust_remote_code=True executes repo-provided Python on load;
# only keep it enabled for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(local_model_dir, use_fast=True, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(local_model_dir, trust_remote_code=True)
37
 
38
  def generate_code(prompt: str, max_length: int = 128) -> str:
39
  """代码生成/补全"""
 
1
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, RobertaTokenizer
import torch

# Load the CodeT5+ model.
model_name = "Salesforce/codet5p-220m"
try:
    # use_fast=False forces the slow (pure-Python) tokenizer, sidestepping
    # fast-tokenizer loading issues with this checkpoint's added_tokens.json.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=False)
except Exception:
    # Fallback to explicit slow tokenizer class to bypass tokenizers fast-path issues.
    tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=False)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def generate_code(prompt: str, max_length: int = 128) -> str:
16
  """代码生成/补全"""
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  transformers>=4.40.0
2
- huggingface_hub>=0.23.0
3
  torch>=2.0.0
4
  sentencepiece>=0.1.96
5
  accelerate>=0.20.0
 
1
  transformers>=4.40.0
 
2
  torch>=2.0.0
3
  sentencepiece>=0.1.96
4
  accelerate>=0.20.0