| import gradio as gr |
| import os |
| import json |
| import logging |
| import transformers |
| import huggingface_hub |
| from huggingface_hub import snapshot_download |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, RobertaTokenizer |
| import torch |
|
|
| try: |
| import tokenizers |
| except Exception: |
| tokenizers = None |
|
|
| |
# Hugging Face Hub repo to serve: the CodeT5+ 220M checkpoint.
model_name = "Salesforce/codet5p-220m"

# Configure root logging first, then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
def log_runtime_versions() -> None:
    """Log versions of key runtime packages to ease Space startup debugging."""
    # `tokenizers` is an optional import at module top; report a placeholder
    # instead of raising when it is absent.
    tok_version = getattr(tokenizers, "__version__", "not-installed")
    logger.info("transformers version: %s", transformers.__version__)
    logger.info("huggingface_hub version: %s", huggingface_hub.__version__)
    logger.info("tokenizers version: %s", tok_version)
|
|
|
|
| def sanitize_added_tokens_file(added_tokens_file: str) -> None: |
| """Normalize added_tokens.json to dict format expected by slow tokenizers.""" |
| if not os.path.exists(added_tokens_file): |
| return |
|
|
| try: |
| with open(added_tokens_file, "r", encoding="utf-8") as fp: |
| data = json.load(fp) |
| except Exception: |
| data = {} |
|
|
| if isinstance(data, dict): |
| sanitized = {k: v for k, v in data.items() if isinstance(k, str) and isinstance(v, int)} |
| elif isinstance(data, list): |
| |
| sanitized = {} |
| else: |
| sanitized = {} |
|
|
| with open(added_tokens_file, "w", encoding="utf-8") as fp: |
| json.dump(sanitized, fp, ensure_ascii=True) |
|
|
|
|
def prepare_local_model(repo_id: str, local_dir: str = "./model_cache") -> str:
    """Download *repo_id* into *local_dir* and sanitize its added_tokens.json.

    Args:
        repo_id: Hugging Face Hub repository id to snapshot.
        local_dir: Destination directory for the snapshot.

    Returns:
        The local directory containing the downloaded model files.
    """
    snapshot_download(repo_id=repo_id, local_dir=local_dir)

    # Repair added_tokens.json in place so the slow tokenizer can load it.
    sanitize_added_tokens_file(os.path.join(local_dir, "added_tokens.json"))

    return local_dir
|
|
|
|
# --- Module-level startup: runs at import time (Space boot). ---
log_runtime_versions()
local_model_dir = prepare_local_model(model_name)
auto_error = None  # remembers the AutoTokenizer failure for the combined error message
try:
    # use_fast=False: the sanitized added_tokens.json targets the slow-tokenizer path.
    tokenizer = AutoTokenizer.from_pretrained(local_model_dir, use_fast=False, trust_remote_code=False)
    logger.info("Tokenizer loaded with AutoTokenizer (slow mode).")
except Exception as exc:
    auto_error = exc
    logger.warning("AutoTokenizer load failed, trying RobertaTokenizer fallback: %s", exc)

    # Fallback: load with RobertaTokenizer directly — presumably this
    # checkpoint ships a RoBERTa-style vocab (TODO confirm against the repo).
    try:
        tokenizer = RobertaTokenizer.from_pretrained(local_model_dir, trust_remote_code=False)
        logger.info("Tokenizer loaded with RobertaTokenizer fallback.")
    except Exception as fallback_exc:
        # Surface both failures in one error so Space logs show the full story.
        raise RuntimeError(
            "Tokenizer initialization failed for both AutoTokenizer and RobertaTokenizer. "
            f"AutoTokenizer error: {auto_error}; RobertaTokenizer error: {fallback_exc}"
        ) from fallback_exc


model = AutoModelForSeq2SeqLM.from_pretrained(local_model_dir, trust_remote_code=False)
|
|
def generate_code(prompt: str, max_length: int = 128) -> str:
    """Generate or complete code from *prompt* with the CodeT5+ model.

    Args:
        prompt: Natural-language description or partial code snippet. May be
            None/empty when the Gradio textbox is cleared.
        max_length: Upper bound (in tokens) on the generated sequence. Gradio
            sliders can deliver floats, so the value is coerced to int.

    Returns:
        The decoded model output, or "" for a missing/blank prompt.
    """
    # Guard both None (cleared textbox) and whitespace-only input.
    if not prompt or not prompt.strip():
        return ""

    # Truncate to the encoder's supported context window.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model.generate(
            **inputs,
            max_length=int(max_length),  # generate() expects an integer bound
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
| |
# Assemble the Gradio UI: prompt box + length slider in, generated code out.
prompt_box = gr.Textbox(
    label="Prompt",
    placeholder="输入代码描述或代码片段,例如:def fibonacci(n):",
    lines=5,
)
length_slider = gr.Slider(32, 512, value=128, step=32, label="Max Length")
output_box = gr.Textbox(label="Generated Code", lines=10)

demo = gr.Interface(
    fn=generate_code,
    inputs=[prompt_box, length_slider],
    outputs=output_box,
    title="CodeT5+ Code Generation",
    description="基于 Salesforce CodeT5+ (220M) 的代码生成模型。支持代码补全、代码生成等任务。",
    examples=[
        ["def fibonacci(n):", 128],
        ["# Python function to calculate factorial", 128],
        ["def quick_sort(arr):", 128],
    ],
)

# Launch the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|