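"""FastAPI server that streams chat completions from a locally loaded
DeepSeek-R1-Distill-Qwen-1.5B model.

Endpoints:
    POST /chat          - stream a generated reply as server-sent events
    POST /reload-model  - force the model to be reloaded
    GET  /status        - report whether the model is currently loaded
"""
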
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import threading
import time

app = FastAPI()

# Global variables to store the model and tokenizer
model = None
tokenizer = None
model_loading_lock = threading.Lock()
model_loaded = False  # Status flag to indicate if the model is loaded

def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    """Load the tokenizer and model once; the lock guards against concurrent loads."""
    global model, tokenizer, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="sequential",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                offload_folder="offload"
            )
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")

def check_model_status():
    """Check if the model is loaded and reload if necessary."""
    global model_loaded
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded

@app.post("/chat")
async def chat_endpoint(message: str, temperature: float = 0.7, max_new_tokens: int = 2048):
    global model, tokenizer
    
    # Ensure the model is loaded before proceeding
    if not check_model_status():
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")
    
    stop_tokens = ["|im_end|"]
    prompt = f"Human: {message}\n\nAssistant:"
    
    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Track timing so we can report tokens-per-second while streaming
    start_time = time.time()
    token_count = 0
    
    # TextIteratorStreamer yields decoded text chunks as the model generates them
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer  # pushes generated text into the iterator above
    )
    
    # Run generation in a background (daemon) thread so the response can stream below
    threading.Thread(target=model.generate, kwargs=generate_kwargs, daemon=True).start()
    
    def generate_response():
        nonlocal token_count
        outputs = []  # accumulates the full reply as it streams
        for new_token in streamer:
            outputs.append(new_token)
            token_count += 1
            
            # Rough throughput figure, useful for logging or monitoring
            elapsed_time = time.time() - start_time
            tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
            
            # Emit the new chunk as a server-sent event
            yield f"data: {new_token}\n\n"
            
            # Stop once a stop marker appears in the generated text
            if any(stop_token in new_token for stop_token in stop_tokens):
                break
    
    return StreamingResponse(generate_response(), media_type="text/event-stream")

@app.post("/reload-model")
async def reload_model():
    """Reload the model manually via an API endpoint."""
    global model_loaded
    model_loaded = False
    load_model()
    return {"message": "Model reloaded successfully."}

@app.get("/status")
async def get_model_status():
    """Check the status of the model."""
    status = "Model is loaded and ready." if model_loaded else "Model is not loaded."
    return {"status": status}
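
# --- Example client (illustrative sketch, not part of the server) -----------
# One possible way to consume the /chat stream from Python. It assumes the
# server is reachable at http://localhost:8000 and that the third-party
# `requests` package is installed; both are assumptions, adjust as needed.
def example_stream_client(message: str, base_url: str = "http://localhost:8000"):
    import requests  # local import so the server itself does not depend on it

    # message/temperature/max_new_tokens are plain scalar parameters on the
    # endpoint, so FastAPI reads them from the query string even for POST.
    with requests.post(
        f"{base_url}/chat",
        params={"message": message, "temperature": 0.7, "max_new_tokens": 512},
        stream=True,
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines(decode_unicode=True):
            # Each chunk arrives as an SSE line of the form "data: <text>"
            if line and line.startswith("data: "):
                print(line[len("data: "):], end="", flush=True)
    print()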

# Pre-load the model and start the server when this file is run directly
if __name__ == "__main__":
    load_model()  # Pre-load the model
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)