FINGU-AI committed on
Commit
845073e
1 Parent(s): 75dfc1d

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +100 -0
inference.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from typing import Dict, List, Optional
4
+
5
+ import torch
6
+ from fastapi import FastAPI, Request
7
+ from vllm import LLM, SamplingParams
8
+ from vllm.utils import random_uuid
9
+
10
+ from chat_template import format_chat
11
+
12
# ASGI application object that the route decorators in this module attach to.
app = FastAPI()

# Root logger (no name passed to getLogger) with its threshold set to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
15
+
16
+ # Load the model function
17
def model_fn(model_dir):
    """Build the vLLM engine from a local model directory.

    Args:
        model_dir: Filesystem path of the model weights; passed straight to
            ``LLM(model=...)`` so nothing is downloaded.

    Returns:
        The constructed ``vllm.LLM`` instance.
    """
    return LLM(
        model=model_dir,  # Load from local path
        trust_remote_code=True,
        # FP8 is a quantization scheme, not a dtype: vLLM's ``dtype`` argument
        # does not accept "fp8", so it must be requested via ``quantization``.
        quantization="fp8",
        gpu_memory_utilization=0.9,
    )
26
+
27
+ # Global model variable
28
# Module-level handle to the loaded engine; stays None until startup runs.
model = None


@app.on_event("startup")
async def startup_event():
    """Load the model from /opt/ml/model and store it in the module global."""
    global model
    model_path = "/opt/ml/model"
    model = model_fn(model_path)
34
+
35
+ # Chat completion endpoint
36
def _build_sampling_params(data):
    """Translate a request payload into vLLM ``SamplingParams``.

    Args:
        data: Decoded JSON body of the request (a dict).

    Returns:
        A ``SamplingParams`` built only from arguments vLLM accepts.
    """
    # vLLM has no ``do_sample`` switch; greedy decoding is selected with a
    # temperature of 0, so ``do_sample=False`` is mapped onto that.
    temperature = data.get("temperature", 0.7)
    if not data.get("do_sample", True):
        temperature = 0.0

    # vLLM calls the generation budget ``max_tokens``. ``max_new_tokens`` is
    # still read from the payload so existing clients keep working.
    max_tokens = data.get("max_tokens", data.get("max_new_tokens", 512))

    # ``length_penalty`` is intentionally not forwarded: it is not a sampling
    # argument in current vLLM releases and would make construction fail.
    kwargs = {
        "temperature": temperature,
        "top_p": data.get("top_p", 0.9),
        "top_k": data.get("top_k", -1),  # -1 disables top-k filtering
        "max_tokens": max_tokens,
        "repetition_penalty": data.get("repetition_penalty", 1.0),
        "stop_token_ids": data.get("stop_token_ids", None),
        "skip_special_tokens": data.get("skip_special_tokens", True),
    }

    # Guided decoding has to be passed as a GuidedDecodingParams object;
    # assigning loose attributes on SamplingParams has no effect.
    guided_params = data.get("guided_params", None)
    if guided_params:
        try:
            from vllm.sampling_params import GuidedDecodingParams
        except ImportError:
            # NOTE(review): the class name differs across vLLM releases;
            # confirm against the version pinned in the container.
            logger.warning(
                "guided_params ignored: GuidedDecodingParams is not "
                "available in the installed vLLM"
            )
        else:
            kwargs["guided_decoding"] = GuidedDecodingParams(
                choice=guided_params.get("guided_choice"),
                json=guided_params.get("guided_json"),
                regex=guided_params.get("guided_regex"),
            )

    return SamplingParams(**kwargs)


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Generate one chat completion and return it in an OpenAI-like shape.

    The JSON body supplies ``messages`` plus optional sampling settings.
    Any failure is logged and reported as ``{"error": ..., "details": ...}``
    instead of propagating.

    Args:
        request: Incoming FastAPI request whose body is JSON.

    Returns:
        A dict with ``id``, ``object``, ``created``, ``model``, ``choices``
        and ``usage`` on success, or the error dict described above.
    """
    # Local import keeps this block self-contained; the original relied on
    # torch.cuda.current_timestamp(), which does not exist.
    import time

    try:
        if model is None:
            raise RuntimeError("Model is not loaded")

        data = await request.json()

        # Retrieve messages and format the prompt
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        sampling_params = _build_sampling_params(data)

        # NOTE(review): generate() is synchronous, so it holds the event loop
        # for the duration of the request. Left as-is because moving it to a
        # thread would allow overlapping generate() calls on one engine.
        outputs = model.generate(formatted_prompt, sampling_params)
        request_output = outputs[0]
        completion = request_output.outputs[0]
        generated_text = completion.text

        # Report token counts, not character counts of the strings.
        prompt_tokens = len(request_output.prompt_token_ids or [])
        completion_tokens = len(completion.token_ids)

        # Build response similar to OpenAI format
        return {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text,
                },
                # Use the engine's reason (e.g. "length") when it gives one.
                "finish_reason": completion.finish_reason or "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    except Exception as e:
        logger.exception("Exception during prediction")
        return {"error": str(e), "details": repr(e)}
95
+
96
+ # Health check endpoint
97
@app.get("/ping")
def ping():
    """Health check: log that a ping arrived and report a healthy status."""
    logger.info("Ping request received")
    health = {"status": "healthy"}
    return health