from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import asyncio
import os
import logging
# Set up logging
logging.basicConfig(level=logging.DEBUG)
# Set cache directory (Change this to a writable directory if necessary)
os.environ["HF_HOME"] = "/tmp/cache" # You can modify this to any directory with write access
# FastAPI app
app = FastAPI()
# CORS Middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for model and tokenizer
model = None
tokenizer = None
model_loaded = False
# Load model and tokenizer in the background (runs on a worker thread at startup)
def load_model():
    global model, tokenizer, model_loaded
    model_name = "microsoft/phi-2"  # Use a different model if necessary (e.g., "gpt2" for testing)
    try:
        logging.info("Starting model and tokenizer loading...")
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/cache", use_fast=True)
        # Load model with 4-bit quantization when a GPU is available (bitsandbytes requires CUDA)
        quantization_config = BitsAndBytesConfig(load_in_4bit=True) if torch.cuda.is_available() else None
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            quantization_config=quantization_config,
            cache_dir="/tmp/cache",
        )
        model_loaded = True
        logging.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logging.error(f"Failed to load model or tokenizer: {e}")
        raise
# Startup event to trigger model loading
@app.on_event("startup")
async def startup_event():
    logging.info("Application starting up...")
    # BackgroundTasks only runs when attached to a response, so instead schedule the
    # blocking loader on a worker thread and keep a reference to the task
    app.state.model_load_task = asyncio.create_task(asyncio.to_thread(load_model))
@app.on_event("shutdown")
async def shutdown_event():
logging.info("Application shutting down...")
# Health check endpoint
@app.get("/health")
async def health():
    logging.info("Health check requested")
    status = {"status": "Server is running", "model_loaded": model_loaded}
    return status
# Request body model
class Question(BaseModel):
    question: str
# Async generator for streaming the response in small chunks
async def generate_response_chunks(prompt: str):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)
    # Run the blocking generate call on a worker thread so the event loop stays responsive
    output_ids = await asyncio.to_thread(
        model.generate,
        input_ids,
        max_new_tokens=300,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Drop the echoed prompt and stream only the generated answer
    answer = output_text[len(prompt):]
    chunk_size = 10
    for i in range(0, len(answer), chunk_size):
        yield answer[i:i + chunk_size]
        await asyncio.sleep(0.01)
# POST endpoint for asking questions
@app.post("/ask")
async def ask(question: Question):
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model is still loading, try again shortly")
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain",
    )