The inference API is too slow. #6
opened by YernazarBis
Good day!
I am testing this model with transformers' pipeline, but the API response is too slow (~100 s). Could someone please take a look at my code and tell me what I am doing wrong?
I am using an NVIDIA A40 GPU server on RunPod (48 GB VRAM, 48 GB RAM).
import transformers
import torch
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
pipeline = None

def load_model():
    # Load the AWQ-quantized 70B model once at startup.
    global pipeline
    model_id = "casperhansen/llama-3-70b-instruct-awq"
    device_map = "auto"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.float16},
        device_map=device_map)

load_model()

class TextGenerationResponse(BaseModel):
    generated_text: str

class TextGenerationRequest(BaseModel):
    systemPrompt: str
    prompt: str
    temperature: float

def generate_inner(request: TextGenerationRequest):
    messages = [
        {"role": "system", "content": request.systemPrompt},
        {"role": "user", "content": request.prompt},
    ]
    # Build the Llama-3 chat prompt from the system and user messages.
    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True)
    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    outputs = pipeline(
        prompt,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=True,
        temperature=request.temperature,
        top_p=0.9,
    )
    # Strip the echoed prompt and return only the newly generated text.
    text = outputs[0]["generated_text"][len(prompt):]
    return text
@app.post("/generate")
def generate(request: TextGenerationRequest):
    # Reset out-of-range temperatures to a sane default.
    if request.temperature < 0.1 or request.temperature > 2:
        request.temperature = 0.6
    generatedText = generate_inner(request)
    return [
        TextGenerationResponse(generated_text=generatedText)
    ]
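One thing that may be worth checking with device_map="auto" on a single 48 GB card is whether accelerate offloaded part of the 70B AWQ checkpoint (on the order of 40 GB of weights) to CPU; even a few offloaded layers can push a 512-token response into this range. A minimal sketch of such a check, assuming the pipeline above has finished loading:

# Inspect where accelerate placed each module after loading with device_map="auto".
# Any entry mapped to "cpu" or "disk" runs off-GPU and will dominate response time.
device_map = pipeline.model.hf_device_map  # set by accelerate when device_map="auto"
offloaded = {name: dev for name, dev in device_map.items() if dev in ("cpu", "disk")}
print(f"{len(offloaded)} of {len(device_map)} modules are offloaded: {offloaded}")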
I had the same problem and solved it, but I couldn't reproduce the fix anywhere else.
In my case, I think the fix came from a package that was pulled in while installing vllm.
Below is the environment in which it runs fast:
CUDA: 12.2
NVIDIA driver: 535.171.04
GPU: 2× NVIDIA Titan RTX (GDDR6)
Python packages:
autoawq==0.2.4
autoawq_kernels==0.0.6
tokenizers==0.19.1
torch==2.2.1
torchsummary==1.5.1
vllm==0.4.1
vllm_nccl_cu12==2.18.1.0.3.0
transformers==4.40.1
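If the speedup really came from the AWQ kernel packages that vllm pulls in, another option is to serve the same checkpoint through vLLM directly. A minimal sketch, assuming vllm==0.4.1 as listed above; tensor_parallel_size=2 matches the two Titan RTX cards, and the prompt is only a placeholder:

from vllm import LLM, SamplingParams

# Load the AWQ checkpoint with vLLM; tensor_parallel_size=2 splits it across both GPUs.
llm = LLM(
    model="casperhansen/llama-3-70b-instruct-awq",
    quantization="awq",
    dtype="float16",
    tensor_parallel_size=2,
)

params = SamplingParams(temperature=0.6, top_p=0.9, max_tokens=512)
# Note: for the instruct model, the chat template would still need to be applied to the prompt.
outputs = llm.generate(["Write a short greeting."], params)  # placeholder prompt
print(outputs[0].outputs[0].text)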