import os
import time
import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from my_model import StoryPointIncrementModel
from transformers import AutoTokenizer
# Define the temporary cache directory for all Hugging Face operations
CACHE_DIR = "/tmp/hf"
MAX_RETRIES = 3
# ----------------------------
# Hugging Face writable paths & Timeout Fix
# ----------------------------
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)
# FIX: Set a high download timeout (300 seconds). The variable huggingface_hub
# actually reads is HF_HUB_DOWNLOAD_TIMEOUT.
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300"
# FIX: Opt into the Rust-based hf_transfer download backend to work around the
# download stall (requires `pip install hf_transfer`; hf_hub_download raises if
# this flag is set but the package is missing).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# CRITICAL FIX: Base model with hidden dimension of 384.
MODEL_BASE_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# --- AUTHENTICATION FIX START ---
HF_AUTH_TOKEN = os.environ.get("HF_TOKEN")
# --- AUTHENTICATION FIX END ---
# ----------------------------
# Load safetensors model and tokenizer
# ----------------------------
model_path = None
try:
    if HF_AUTH_TOKEN:
        print(f"DEBUG: HF_TOKEN successfully retrieved. Starts with: {HF_AUTH_TOKEN[:5]}...")
    else:
        print("DEBUG: HF_TOKEN is NOT set in environment variables.")
    # 1. Download the checkpoint file with retry logic
    for attempt in range(MAX_RETRIES):
        try:
            print(f"Attempting to download model (Attempt {attempt + 1}/{MAX_RETRIES})...")
            # CRITICAL FIX: target the new, smaller quantized file
            model_path = hf_hub_download(
                repo_id="AgileGenAI/JIRA-story-point-increment-predictor",
                filename="model_fp16.safetensors",  # <-- targeting the 45.5 MB file
                cache_dir=CACHE_DIR,
                token=HF_AUTH_TOKEN,
                force_download=True
            )
            print("Model download succeeded.")
            break
        except Exception as download_error:
            if attempt < MAX_RETRIES - 1:
                print(f"Download attempt {attempt + 1} failed: {download_error}. Retrying in 5 seconds...")
                time.sleep(5)
            else:
                raise
    if model_path is None:
        raise RuntimeError("Failed to download model after all retries.")
    # 2. Load the state dictionary
    state_dict = load_file(model_path)
    # 3. Initialize the model architecture (the fine-tuned weights are applied below)
    model = StoryPointIncrementModel(model_name=MODEL_BASE_NAME, cache_dir=CACHE_DIR)
    print("Model architecture initialized; applying weights from the quantized file.")
    # --- Critical Fix: Simplified key mapping to find the missing regressor ---
    new_state_dict = {}
    for k, v in state_dict.items():
        # NOTE: Since the file is pre-quantized, we don't need the messy in-code conversion logic.
        if k.startswith('bert.'):
            new_state_dict[f'encoder.{k}'] = v
        elif 'embeddings.' in k or 'encoder.' in k or 'pooler.' in k:
            new_state_dict[f'encoder.{k}'] = v
        elif k.startswith('regressor.'):
            new_state_dict[k] = v
        elif k.startswith('classifier.') or k.startswith('linear.') or k.startswith('output.'):
            new_state_dict[f'regressor.{k.split(".", 1)[1]}'] = v
        elif k == 'weight' or k == 'bias':
            new_state_dict[f'regressor.{k}'] = v
        else:
            new_state_dict[k] = v
    # Load the state dictionary. strict=False is critical: it tolerates keys the
    # remap above could not align instead of raising.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    model.eval()
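    # Hedged sanity check (an addition, not in the original): any leftover
    # mismatch suggests the key remap above missed the regressor head.
    # load_state_dict returns an _IncompatibleKeys named tuple either way.
    if load_result.missing_keys:
        print(f"WARNING: keys missing after remap: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"WARNING: unexpected keys after remap: {load_result.unexpected_keys}")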
    # 4. Load the tokenizer (the tokenizer is not affected by precision)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE_NAME, cache_dir=CACHE_DIR)
except Exception as e:
print(f"An error occurred during model loading: {e}")
model = None
tokenizer = None
# ----------------------------
# FastAPI app
# ----------------------------
app = FastAPI()
@app.get("/")
def read_root():
"""Returns a simple status message for the root path."""
return {
"status": "Story Point Prediction API is running",
"message": "Use the /predict endpoint with a POST request and JSON body for predictions."
}
@app.post("/predict")
async def predict(request: Request):
    if model is None or tokenizer is None:
        # Report startup failure with a proper HTTP status rather than a 200
        return JSONResponse(
            status_code=503,
            content={"error": "Model initialization failed. Cannot predict. Check startup logs for errors."}
        )
    data = await request.json()
    description = data.get("description", "")
    summary = data.get("summary", "")
    # Combine summary and description into a single input text
    text_input = f"{summary} [SEP] {description}"
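    # For example: "Add login page [SEP] Support OAuth2 with refresh tokens"
    # (illustrative values; "[SEP]" is a literal here, which for this BERT-style
    # tokenizer also happens to match its sep token)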
    # Tokenize the input text
    encoded_input = tokenizer(
        text_input,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=512
    )
    input_ids = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']
    # Predict increment
    with torch.no_grad():
        # Pass input_ids and attention_mask to the model
        output = model(input_ids=input_ids, attention_mask=attention_mask)
    # The output is a tensor of shape [1, 1]; extract and round the value
    story_point_increment = int(round(output.item()))
    return {"story_point_increment": story_point_increment}