# Author: sdafd · commit 96e31d2 "Create app.py" (verified)
import threading
import time

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
app = FastAPI()

# Shared model state, filled in by load_model().
model = tokenizer = None  # causal-LM model and its tokenizer; None until loaded
model_loading_lock = threading.Lock()  # held by load_model() while it checks/loads
model_loaded = False  # set True by load_model() after a successful load
def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    """Load the tokenizer and model into the module globals, at most once.

    Runs under ``model_loading_lock`` so concurrent callers cannot load twice.
    While ``model_loaded`` is True, further calls are no-ops and ``model_name``
    is ignored. Callers that want a fresh load clear ``model_loaded`` first,
    as the /reload-model endpoint does.

    Args:
        model_name: Repo id or local path passed to both ``from_pretrained`` calls.

    Raises:
        Any exception from ``from_pretrained``. The globals are assigned only
        after both loads succeed. A failure therefore leaves ``model`` and
        ``tokenizer`` unchanged and ``model_loaded`` False, so a later call retries.
    """
    global model, tokenizer, model_loaded
    with model_loading_lock:
        if model_loaded:
            print("Model already loaded.")
            return

        print("Loading model...")
        # trust_remote_code runs Python code shipped with the checkpoint; only
        # point model_name at repositories you trust.
        new_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        new_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="sequential",
            # NOTE(review): float16 is requested even when no GPU is present.
            # Half-precision CPU kernels are slow or missing on some PyTorch
            # builds; confirm this is intended for CPU-only hosts.
            torch_dtype=torch.float16,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            offload_folder="offload",
        )

        # Publish only after both loads succeeded, so a failed model load cannot
        # leave a freshly loaded tokenizer paired with the previous model.
        tokenizer = new_tokenizer
        model = new_model
        model_loaded = True
        print("Model loaded successfully.")
def check_model_status():
    """Return True if the model is ready, loading it first if necessary.

    A failed load is reported and turned into a ``False`` return. The caller
    can then answer 503 ("try again later") instead of surfacing an unhandled
    error. ``model_loaded`` stays False after a failure, so the next call
    retries the load.

    Returns:
        The value of ``model_loaded`` after any load attempt, or False if the
        load raised.
    """
    if model_loaded:
        return True

    print("Model not loaded. Reloading...")
    try:
        load_model()
    except Exception as exc:  # request boundary: report, let the caller reply 503
        print(f"Model load failed: {exc!r}")
        return False
    return model_loaded
class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that makes ``generate()`` halt once *event* is set."""

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        return self.event.is_set()


def _format_sse(text):
    """Encode *text* as a single Server-Sent Events message.

    In SSE, a line break ends a field and a blank line ends the event. Each
    line of *text* therefore gets its own ``data:`` field, and clients rejoin
    them with "\\n". Without this, a chunk such as "\\n\\n" would terminate the
    event early.
    """
    lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"


@app.post("/chat")
async def chat_endpoint(message: str, temperature: float = 0.7, max_new_tokens: int = 2048):
    """Stream the model's reply to *message* as Server-Sent Events.

    Args:
        message: User text, wrapped in a plain ``Human: ... Assistant:`` prompt.
        temperature: Sampling temperature. Values <= 0 select greedy decoding,
            because transformers rejects a non-positive temperature when sampling.
        max_new_tokens: Upper bound on generated tokens; must be at least 1.

    Returns:
        A ``text/event-stream`` response with one event per non-empty text chunk.

    Raises:
        HTTPException: 400 if ``max_new_tokens`` < 1; 503 if the model is not ready.
    """
    if max_new_tokens < 1:
        raise HTTPException(status_code=400, detail="max_new_tokens must be at least 1.")
    # check_model_status() may load the model synchronously, which is slow and
    # blocking, so run it off the event loop.
    if not await run_in_threadpool(check_model_status):
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")

    # Snapshot the globals so a concurrent reload cannot swap them mid-request.
    current_model, current_tokenizer = model, tokenizer

    stop_tokens = ["|im_end|"]
    prompt = f"Human: {message}\n\nAssistant:"
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)

    start_time = time.time()
    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
    stop_event = threading.Event()
    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        # pad_token_id equals eos_token_id below, so generate() cannot infer the
        # mask from the input; pass it explicitly.
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=current_tokenizer.eos_token_id,
        streamer=streamer,
        # Lets the response generator stop generation early (see finally below).
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        generate_kwargs.update(do_sample=False)

    def run_generation():
        try:
            current_model.generate(**generate_kwargs)
        except Exception as exc:
            print(f"Generation failed: {exc!r}")
            # generate() only closes the stream when it finishes normally; close
            # it here so the consumer loop below does not block forever.
            streamer.end()

    threading.Thread(target=run_generation, daemon=True).start()

    def generate_response():
        # Counts streamed text chunks (the streamer yields decoded text pieces,
        # not individual tokens).
        chunk_count = 0
        try:
            for chunk in streamer:
                if not chunk:
                    continue  # the streamer emits "" while buffering a partial word
                chunk_count += 1
                yield _format_sse(chunk)
                if any(stop_token in chunk for stop_token in stop_tokens):
                    break
        finally:
            # Runs on normal completion, on a stop-token break, and when this
            # generator is closed early (e.g. the client goes away). Stop the
            # background generate() instead of letting it run to max_new_tokens.
            stop_event.set()
            elapsed = time.time() - start_time
            rate = chunk_count / elapsed if elapsed > 0 else 0.0
            print(f"Streamed {chunk_count} chunks in {elapsed:.2f}s ({rate:.2f} chunks/s)")

    return StreamingResponse(generate_response(), media_type="text/event-stream")
@app.post("/reload-model")
async def reload_model():
    """Force a fresh load of the model via an API endpoint.

    Loading is slow and blocking. It runs in FastAPI's threadpool rather than
    on the event loop, so other endpoints (e.g. /status) stay responsive while
    it runs. Any exception raised by the load propagates to FastAPI.
    """

    def _reload():
        global model_loaded
        # Clear the flag while holding the loader lock, so it is not flipped in
        # the middle of a load that another thread is performing.
        with model_loading_lock:
            model_loaded = False
        load_model()

    await run_in_threadpool(_reload)
    return {"message": "Model reloaded successfully."}
@app.get("/status")
async def get_model_status():
    """Report whether the model is loaded, as a human-readable status string."""
    if model_loaded:
        message = "Model is loaded and ready."
    else:
        message = "Model is not loaded."
    return {"status": message}
# Pre-load the model only when this file is executed directly. When the app is
# served some other way, the model is loaded lazily by check_model_status() on
# the first /chat request.
if __name__ == "__main__":
    load_model()  # warm start: weights are ready before the server accepts requests

    import uvicorn

    server_host, server_port = "0.0.0.0", 8000
    uvicorn.run(app, host=server_host, port=server_port)