import os
import time
import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, Request
from transformers import AutoTokenizer

from my_model import StoryPointIncrementModel
# Temporary cache directory for all Hugging Face operations
CACHE_DIR = "/tmp/hf"
MAX_RETRIES = 3

# ----------------------------
# Hugging Face writable paths & download settings
# ----------------------------
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR  # legacy alias, kept for older transformers releases
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

# Raise the Hub download timeout (10 seconds by default) to 300 seconds
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300"

# Use the Rust-based hf_transfer download backend to work around download stalls
# (requires the hf_transfer package to be installed)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Base encoder with a hidden dimension of 384, matching the checkpoint
MODEL_BASE_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Optional token for private repos; set HF_TOKEN as a Space secret
HF_AUTH_TOKEN = os.environ.get("HF_TOKEN")

# ----------------------------
# Load safetensors model and tokenizer
# ----------------------------
model_path = None
try:
    if HF_AUTH_TOKEN:
        print(f"DEBUG: HF_TOKEN successfully retrieved. Starts with: {HF_AUTH_TOKEN[:5]}...")
    else:
        print("DEBUG: HF_TOKEN is NOT set in environment variables.")
    # 1. Download the checkpoint file with retry logic
    for attempt in range(MAX_RETRIES):
        try:
            print(f"Attempting to download model (attempt {attempt + 1}/{MAX_RETRIES})...")
            # Target the smaller fp16-quantized checkpoint (~45.5 MB)
            model_path = hf_hub_download(
                repo_id="AgileGenAI/JIRA-story-point-increment-predictor",
                filename="model_fp16.safetensors",
                cache_dir=CACHE_DIR,
                token=HF_AUTH_TOKEN,
                force_download=True,
            )
            print("Model download succeeded.")
            break
        except Exception as download_error:
            if attempt < MAX_RETRIES - 1:
                print(f"Download attempt {attempt + 1} failed: {download_error}. Retrying in 5 seconds...")
                time.sleep(5)
            else:
                raise

    if model_path is None:
        raise RuntimeError("Failed to download model after all retries.")
    # 2. Load the state dictionary
    state_dict = load_file(model_path)

    # 3. Initialize the model
    model = StoryPointIncrementModel(model_name=MODEL_BASE_NAME, cache_dir=CACHE_DIR)
    print("Model initialized from quantized file.")
    # Remap checkpoint keys onto the wrapper model's `encoder.` / `regressor.` namespace,
    # e.g. 'bert.pooler.dense.weight' -> 'encoder.bert.pooler.dense.weight'
    new_state_dict = {}
    for k, v in state_dict.items():
        # The file is already stored in fp16, so no in-code dtype conversion is
        # needed; load_state_dict casts each tensor to the target parameter's dtype.
        if k.startswith('bert.'):
            # Checkpoint saved from a BertFor* model: nest under the encoder attribute
            new_state_dict[f'encoder.{k}'] = v
        elif 'embeddings.' in k or 'encoder.' in k or 'pooler.' in k:
            # Bare BertModel-style keys: prefix with the wrapper's encoder attribute
            new_state_dict[f'encoder.{k}'] = v
        elif k.startswith('regressor.'):
            # Already in the wrapper's namespace; keep unchanged
            new_state_dict[k] = v
        elif k.startswith('classifier.') or k.startswith('linear.') or k.startswith('output.'):
            # Alternative head names: rename onto the regressor head
            new_state_dict[f'regressor.{k.split(".", 1)[1]}'] = v
        elif k == 'weight' or k == 'bias':
            # Bare head parameters saved without a module prefix
            new_state_dict[f'regressor.{k}'] = v
        else:
            new_state_dict[k] = v

    # strict=False tolerates missing/unexpected keys, so a partial match still loads
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()

    # 4. Load the tokenizer (unaffected by checkpoint precision)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE_NAME, cache_dir=CACHE_DIR)
except Exception as e:
    print(f"An error occurred during model loading: {e}")
    model = None
    tokenizer = None

# ----------------------------
# FastAPI app
# ----------------------------
app = FastAPI()


@app.get("/")
def read_root():
    """Return a simple status message for the root path."""
    return {
        "status": "Story Point Prediction API is running",
        "message": "Use the /predict endpoint with a POST request and JSON body for predictions.",
    }

@app.post("/predict")
async def predict(request: Request):
    if model is None or tokenizer is None:
        return {"error": "Model initialization failed. Cannot predict. Check startup logs for errors."}
    data = await request.json()
    description = data.get("description", "")
    summary = data.get("summary", "")

    # Combine summary and description into a single input text
    text_input = f"{summary} [SEP] {description}"

    # Tokenize the input text
    encoded_input = tokenizer(
        text_input,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=512,
    )
    input_ids = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']

    # Predict the increment
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)

    # The output is a tensor of shape [1, 1]; extract and round the value
    story_point_increment = int(round(output.item()))
    return {"story_point_increment": story_point_increment}