ethnmcl committed
Commit 93a3159 · verified · 1 Parent(s): 6182c59

Update main.py

Files changed (1)
  1. main.py +65 -42
main.py CHANGED
@@ -1,17 +1,19 @@
 import os
-from typing import Dict, Any
+from typing import Dict, Any, Optional, Tuple
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from huggingface_hub.utils import RepositoryNotFoundError
 import torch
 
-# ---- Config -----------------------------------------------------------------
-MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2") # NEW default
-BASE_TOKENIZER = os.getenv("BASE_TOKENIZER", "gpt2") # fallback if LoRA repo has no tokenizer
-HF_TOKEN = os.getenv("HF_TOKEN") # set if private
+# ---- Config --------------------------------------------------------------
+PREFERRED_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-lora-gpt2")
+FALLBACK_IDS = ["ethnmcl/checkin-lora-gpt2", "distilgpt2"] # last-resort keeps API alive
+BASE_TOKENIZER = os.getenv("BASE_TOKENIZER", "gpt2")
+HF_TOKEN = os.getenv("HF_TOKEN")
 
-app = FastAPI(title="Check-in GPT-2 API", version="1.2.0")
+app = FastAPI(title="Check-in GPT-2 API", version="1.3.0")
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
@@ -20,51 +22,69 @@ app.add_middleware(
 device = 0 if torch.cuda.is_available() else -1
 DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
 
-# ---- Tokenizer (with fallback for adapter-only repos) ------------------------
-def load_tokenizer(repo_id: str, token: str | None):
+# ---- Helpers -------------------------------------------------------------
+def _load_tokenizer(repo_id: str) -> Tuple:
+    """Try repo tokenizer, then fallback to base tokenizer."""
     try:
-        tk = AutoTokenizer.from_pretrained(repo_id, token=token)
+        tk = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
         if tk.pad_token is None:
             tk.pad_token = tk.eos_token
         return tk, repo_id, False
-    except Exception as e_model_tok:
-        # Adapter repos often don't include tokenizer files: fallback to base tokenizer
-        tk = AutoTokenizer.from_pretrained(BASE_TOKENIZER, token=token)
+    except Exception:
+        tk = AutoTokenizer.from_pretrained(BASE_TOKENIZER, token=HF_TOKEN)
         if tk.pad_token is None:
             tk.pad_token = tk.eos_token
         return tk, BASE_TOKENIZER, True
 
-tokenizer, tokenizer_source, tokenizer_fallback = load_tokenizer(MODEL_ID, HF_TOKEN)
+def _try_plain(repo_id: str):
+    return AutoModelForCausalLM.from_pretrained(
+        repo_id, token=HF_TOKEN, dtype=DTYPE,
+        device_map="auto" if torch.cuda.is_available() else None,
+    )
 
-# ---- Model (plain or PEFT LoRA) ---------------------------------------------
-_merged = False
-try:
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        token=HF_TOKEN,
-        dtype=DTYPE,
+def _try_peft(repo_id: str):
+    from peft import AutoPeftModelForCausalLM
+    m = AutoPeftModelForCausalLM.from_pretrained(
+        repo_id, token=HF_TOKEN, dtype=DTYPE,
         device_map="auto" if torch.cuda.is_available() else None,
     )
-except Exception as e_plain:
-    # Try PEFT (adapter) path
+    # Merge if available; ok if not
     try:
-        from peft import AutoPeftModelForCausalLM
-        model = AutoPeftModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            token=HF_TOKEN,
-            dtype=DTYPE,
-            device_map="auto" if torch.cuda.is_available() else None,
-        )
+        m = m.merge_and_unload()
+        merged = True
+    except Exception:
+        merged = False
+    return m, merged
+
+def load_model_any(repo_id: str):
+    """Try plain, then PEFT; raise if both fail."""
+    try:
+        m = _try_plain(repo_id)
+        return m, False
+    except Exception as e_plain:
         try:
-            model = model.merge_and_unload()
-            _merged = True
-        except Exception:
-            _merged = False
-    except Exception as e_peft:
-        raise RuntimeError(
-            f"Failed to load model '{MODEL_ID}'. "
-            f"Plain load error: {e_plain}\nPEFT load error: {e_peft}"
-        )
+            m, merged = _try_peft(repo_id)
+            return m, merged
+        except Exception as e_peft:
+            raise RuntimeError(f"load failed for {repo_id} | plain: {e_plain} | peft: {e_peft}")
+
+# ---- Boot: try MODEL_ID first, then fallbacks ----------------------------
+errors = {}
+chosen_id: Optional[str] = None
+merged_lora = False
+
+trial_ids = [PREFERRED_ID] + [i for i in FALLBACK_IDS if i != PREFERRED_ID]
+for rid in trial_ids:
+    try:
+        tokenizer, tokenizer_source, tokenizer_fallback_used = _load_tokenizer(rid)
+        model, merged_lora = load_model_any(rid)
+        chosen_id = rid
+        break
+    except Exception as e:
+        errors[rid] = str(e)
+
+if chosen_id is None:
+    raise RuntimeError(f"All model loads failed. Errors: {errors}")
 
 pipe = pipeline(
     "text-generation",
@@ -73,7 +93,7 @@ pipe = pipeline(
     device=device,
 )
 
-# ---- Prompting ---------------------------------------------------------------
+# ---- Prompting -----------------------------------------------------------
 PREFIX = "INPUT: "
 SUFFIX = "\nOUTPUT:"
 def make_prompt(user_input: str) -> str:
@@ -98,11 +118,13 @@ class GenerateResponse(BaseModel):
 def root():
     return {
         "message": "Check-in GPT-2 API. POST /generate",
-        "model": MODEL_ID,
+        "model_chosen": chosen_id,
        "device": "cuda" if device == 0 else "cpu",
-        "merged_lora": _merged,
+        "merged_lora": merged_lora,
         "tokenizer_source": tokenizer_source,
-        "tokenizer_fallback_used": tokenizer_fallback,
+        "tokenizer_fallback_used": tokenizer_fallback_used,
+        "attempt_errors": errors,
+        "env_MODEL_ID": PREFERRED_ID,
     }
 
 @app.get("/health")
@@ -131,3 +153,4 @@ def generate(req: GenerateRequest):
         return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
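
For quick manual verification of this change, a minimal client sketch follows. It relies only on what the diff shows (the GET / diagnostics keys and the "output" field of GenerateResponse); the base URL and the request field names "text" and "max_new_tokens" are assumptions for illustration, since the GenerateRequest schema is outside these hunks.

    # Hypothetical smoke test against a running instance of this app.
    # Base URL and POST body field names are assumptions; adjust to match main.py.
    import requests

    BASE_URL = "http://localhost:8000"

    # GET / now reports which repo the fallback loop actually loaded.
    info = requests.get(f"{BASE_URL}/", timeout=30).json()
    print(info["model_chosen"], info["tokenizer_source"], info["merged_lora"])

    # POST /generate; GenerateResponse exposes "output", "prompt", "parameters".
    resp = requests.post(
        f"{BASE_URL}/generate",
        json={"text": "Evening check-in for two guests", "max_new_tokens": 64},
        timeout=120,
    )
    resp.raise_for_status()
    print(resp.json()["output"])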