# Page-header text captured together with the file (not part of the program):
#   josStorer's picture
#   upload
#   a686568
#   raw
#   history blame
#   3.27 kB
import json
from typing import List
import torch
from fastapi import FastAPI, Request, status, HTTPException
from pydantic import BaseModel
from torch.cuda import get_device_properties
from transformers import AutoModel, AutoTokenizer
from sse_starlette.sse import EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pyngrok import ngrok, conf
import os
# Cache directory exported for the Hugging Face libraries.
# NOTE(review): this is assigned after `transformers` has already been imported
# at the top of the file; confirm the library still honours it when set this late.
os.environ['TRANSFORMERS_CACHE'] = ".cache"
# Bit width passed to `model.quantize()` by the startup hook.
bits = 4
# Path to the quantization kernel library; not referenced by the code visible here.
kernel_path = "models/models--silver--chatglm-6b-int4-slim/quantization_kernels.so"
# Local snapshot directory handed to `from_pretrained()` for both tokenizer and model.
model_path = "./models/models--silver--chatglm-6b-int4-slim/snapshots/02e096b3805c579caf5741a6d8eddd5ba7a74e0d"
# `cache_dir` argument forwarded to `from_pretrained()`.
cache_dir = './models'
# The only `model` value the /chat/completions endpoint accepts.
model_name = 'chatglm-6b-int4'
# Threshold compared against total GPU memory divided by 1024 ** 3; the model is
# moved to CUDA only when the device reports more than this.
min_memory = 5.5
# Both are filled in by the startup hook and stay None until it has run.
tokenizer = None
model = None
app = FastAPI()
# Fully permissive CORS: any origin, method and header, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event('startup')
def init():
    """Load the tokenizer and model into the module globals at application startup.

    The model is quantized to ``bits`` bits. It is placed on the GPU in half
    precision when CUDA is available and device 0 reports more than
    ``min_memory`` (total memory divided by 1024 ** 3); otherwise it stays on
    the CPU in float precision. The chosen device is reported on stdout.
    """
    global tokenizer, model

    load_options = {"trust_remote_code": True, "cache_dir": cache_dir}
    tokenizer = AutoTokenizer.from_pretrained(model_path, **load_options)
    model = AutoModel.from_pretrained(model_path, **load_options)

    cuda_ready = torch.cuda.is_available()
    # Device properties are only queried when CUDA is actually present.
    gpu_memory = get_device_properties(0).total_memory / 1024 ** 3 if cuda_ready else None

    if gpu_memory is not None and gpu_memory > min_memory:
        model = model.half().quantize(bits=bits).cuda()
        print("Using GPU")
    else:
        model = model.float().quantize(bits=bits)
        # Explain why the CPU was chosen: too little GPU memory, or no GPU at all.
        if cuda_ready:
            print("Total Memory: ", gpu_memory)
        else:
            print("No GPU available")
        print("Using CPU")

    model = model.eval()
class Message(BaseModel):
    """A single chat message: the role of its author and its text."""

    # Author of the message; the completions handler distinguishes
    # 'user', 'system' and 'assistant'.
    role: str
    # Text of the message.
    content: str
class Body(BaseModel):
    """Request payload accepted by the ``/chat/completions`` endpoint."""

    # Conversation so far; the handler treats the last entry as the question.
    messages: List[Message]
    # Requested model name; the handler rejects anything other than `model_name`.
    model: str
    # Whether a streamed response is requested; the handler rejects False.
    stream: bool
    # Combined with a floor of 2048 to form the `max_length` given to the model.
    max_tokens: int
@app.get("/")
def read_root():
    """Return a fixed greeting payload for the root path."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/chat/completions")
async def completions(body: Body, request: Request):
    """Stream a chat completion for the last user message as server-sent events.

    Args:
        body: Parsed request payload (messages, model name, stream flag,
            max_tokens).
        request: The incoming request; polled during generation so streaming
            stops once the client has disconnected.

    Returns:
        An ``EventSourceResponse`` that emits one JSON object of the form
        ``{"response": ...}`` per item yielded by ``model.stream_chat`` and
        then a final ``[DONE]`` event.

    Raises:
        HTTPException: 400 "Not Implemented" when streaming is not requested
            or ``body.model`` differs from ``model_name``; 400 "No Question
            Found" when ``messages`` is empty or its last entry is not a
            user message.
    """
    if not body.stream or body.model != model_name:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Not Implemented")

    # Check for an empty list before indexing: messages[-1] on an empty list
    # would raise IndexError and surface as an HTTP 500 instead of a 400.
    if not body.messages or body.messages[-1].role != 'user':
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
    question = body.messages[-1].content

    # Build (user_text, reply_text) pairs: every system/assistant message is
    # paired with the most recent user message seen before it ('' if none).
    user_question = ''
    history = []
    for message in body.messages:
        if message.role == 'user':
            user_question = message.content
        elif message.role in ('system', 'assistant'):
            history.append((user_question, message.content))

    async def event_generator():
        # max_length is never below 2048, whatever max_tokens the client sent.
        for response in model.stream_chat(tokenizer, question, history, max_length=max(2048, body.max_tokens)):
            # Stop generating as soon as the client has gone away.
            if await request.is_disconnected():
                return
            yield json.dumps({"response": response[0]})
        yield "[DONE]"

    return EventSourceResponse(event_generator())
def ngrok_connect():
    """Open an ngrok HTTP tunnel to local port 8000 and print its public URL.

    Uses the ngrok binary at ``./ngrok`` and the auth token stored in the
    ``ngrok_token`` environment variable.

    Raises:
        KeyError: if ``ngrok_token`` is not set in the environment.
    """
    pyngrok_settings = conf.PyngrokConfig(ngrok_path="./ngrok")
    conf.set_default(pyngrok_settings)

    auth_token = os.environ["ngrok_token"]
    ngrok.set_auth_token(auth_token)

    tunnel = ngrok.connect(8000)
    print(tunnel.public_url)
if __name__ == "__main__":
    # Open the public tunnel first, then serve the app with auto-reload.
    # NOTE(review): no port is passed to uvicorn, so this relies on its default
    # matching the port 8000 that ngrok_connect() tunnels to - confirm.
    ngrok_connect()
    uvicorn.run("main:app", app_dir=".", reload=True)