import os
import sys
from pathlib import Path

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from fastapi import HTTPException
from pydantic import BaseModel


class DownloadRequest(BaseModel):
    model: str
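    # Example request body (the model name "gpt2" is illustrative):
    #   {"model": "gpt2"}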


def check_model(model_name):
    """
    Check whether the model already exists in the local cache.

    Args:
        model_name: model name passed in from the request

    Returns:
        (model_name, cache_dir, success)

    Raises:
        HTTPException: 404 if the model is not in the cache.
    """
    cache_dir = "./my_model_cache"

    # Hugging Face caches models under "models--{org}--{name}" directories.
    model_path = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
    snapshot_path = model_path / "snapshots"

    if snapshot_path.exists() and any(snapshot_path.iterdir()):
        print(f"✓ Model {model_name} already exists in the cache")
        try:
            # Load the tokenizer from local files only, to verify the cached
            # files are actually usable without touching the network.
            AutoTokenizer.from_pretrained(
                model_name, cache_dir=cache_dir, local_files_only=True
            )
            return model_name, cache_dir, True
        except Exception as e:
            print(f"⚠ Failed to load the cached model: {e}")
            return model_name, cache_dir, False
    else:
        raise HTTPException(
            status_code=404,
            detail=f"Model `{model_name}` does not exist; please download it first",
        )
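

# Note: check_model raises on a cache miss rather than returning, so callers
# outside a FastAPI request handler should be prepared to catch HTTPException.
# A minimal sketch (the model name "gpt2" is illustrative):
#
#   try:
#       name, cache_dir, ok = check_model("gpt2")
#   except HTTPException as exc:
#       print(exc.detail)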


def download_model(model_name):
    """
    Download the specified model.

    Args:
        model_name: name of the model to download

    Returns:
        (success, message)
    """
    cache_dir = "./my_model_cache"

    print(f"Starting download of model: {model_name}")
    print(f"Cache directory: {cache_dir}")

    # Log in if a token is available, so gated/private models can be fetched.
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            print("Logging in to Hugging Face...")
            login(token=token)
            print("✓ Hugging Face login succeeded!")
        except Exception as e:
            print(f"⚠ Login failed: {e}")
            print("Continuing with public models only")
    else:
        print("ℹ HUGGINGFACE_TOKEN not set - public models only")

    try:
        # The return values are discarded; calling from_pretrained is enough
        # to populate the local cache.
        print("Downloading tokenizer...")
        AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Tokenizer downloaded successfully!")

        print("Downloading model...")
        AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Model downloaded successfully!")

        print(f"✓ Model and tokenizer downloaded to {cache_dir}")
        return True, f"Model {model_name} downloaded successfully"

    except Exception as e:
        print(f"✗ Error while downloading the model: {e}")
        return False, f"Download failed: {str(e)}"
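

# A minimal sketch of how these helpers could be exposed over HTTP
# (assumptions: the FastAPI `app` object and the route path below are
# illustrative, not part of this module):
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#
#   @app.post("/models/download")
#   def download_endpoint(request: DownloadRequest):
#       success, message = download_model(request.model)
#       if not success:
#           raise HTTPException(status_code=500, detail=message)
#       return {"message": message}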


def initialize_pipeline(model_name):
    """
    Initialize a text-generation pipeline with the model.

    Args:
        model_name: model name passed in from the request

    Returns:
        (pipe, tokenizer, success)
    """
    model_name, cache_dir, success = check_model(model_name)

    if not success:
        return None, None, False

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

        print(f"Initializing pipeline with {model_name}...")
        # Pass cache_dir through model_kwargs so the pipeline loads the model
        # from the local cache instead of re-downloading to the default one.
        pipe = pipeline(
            "text-generation",
            model=model_name,
            tokenizer=tokenizer,
            model_kwargs={"cache_dir": cache_dir},
        )
        print("✓ Pipeline initialized successfully!")

        return pipe, tokenizer, True

    except Exception as e:
        print(f"✗ Pipeline initialization failed: {e}")
        return None, None, False
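

# Minimal usage sketch (the model name "gpt2" is a hypothetical example;
# any causal LM available on the Hugging Face Hub should work):
if __name__ == "__main__":
    ok, msg = download_model("gpt2")
    print(msg)
    if ok:
        pipe, tok, ready = initialize_pipeline("gpt2")
        if ready:
            print(pipe("Hello, world", max_new_tokens=20)[0]["generated_text"])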