File size: 4,268 Bytes
9eae9de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0b5928
9eae9de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
986c8f3
9eae9de
 
 
 
 
 
 
 
 
 
 
 
 
 
e0b5928
9eae9de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
import markdown
from typing import Callable, List, Dict, Any, Generator
from functools import partial

import fastapi
import uvicorn
from fastapi import HTTPException, Depends, Request
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from anyio import create_memory_object_stream
from anyio.to_thread import run_sync
from ctransformers import AutoModelForCausalLM
from pydantic import BaseModel

# Load the WizardCoder GGML model once at import time; the handlers in this
# module all use this single `llm` instance.
llm = AutoModelForCausalLM.from_pretrained(
    "TheBloke/WizardCoder-15B-1.0-GGML",
    model_file="WizardCoder-15B-1.0.ggmlv3.q4_0.bin",
    model_type="starcoder",
    threads=8,
)

app = fastapi.FastAPI(title="🪄WizardCoder💫")

# Cross-origin access is left wide open: any origin, method and header.
# NOTE(review): the FastAPI CORS docs say wildcards should not be combined
# with allow_credentials=True -- confirm whether credentials are needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def index():
    """Serve a small static landing page.

    The page contains two links: one to the model repository on Hugging Face
    and one to the hosted FastAPI docs. It is returned as an HTML response
    with status 200.
    """
    page = """
    <html>
        <head>
        </head>
        <body style="font-family:system-ui">
            <h2><a href="https://huggingface.co/TheBloke/WizardCoder-15B-1.0-GGML">wizardcoder-ggml</a></h2>
            <h2><a href="https://nerddisco-wizardcoder-15b-1-0-ggmlv3-q4-0.hf.space/docs">FastAPI Docs</a></h2>
        </body>
    </html>
    """
    return HTMLResponse(page, status_code=200)

# Request body carrying a single raw prompt string. Used by the
# /v1/completions and /v0/chat/completions handlers in this module.
class ChatCompletionRequestV0(BaseModel):
    # Text handed to the model as-is.
    prompt: str

# One entry of a chat conversation, as sent inside ChatCompletionRequest.
class Message(BaseModel):
    # Accepted from the client; the chat handlers in this module do not read it.
    role: str
    # Message text; the chat handlers join these into the model prompt.
    content: str

# Request body for the /v1 and /v2 chat completion handlers.
class ChatCompletionRequest(BaseModel):
    # Conversation messages, in the order supplied by the client.
    messages: List[Message]
    # Requested cap on the number of generated tokens; defaults to 1024.
    max_tokens: int = 1024

@app.post("/v1/completions")
async def completion(request: ChatCompletionRequestV0, response_mode=None):
    """Run the model on ``request.prompt`` and return its output directly.

    ``response_mode`` is accepted as a parameter but is not used here.
    """
    # NOTE(review): the model call is synchronous, so this handler does not
    # yield to the event loop until the model returns.
    return llm(request.prompt)

async def generate_response(chat_chunks, llm):
    """Turn generated tokens into server-sent-event payloads.

    For each item of ``chat_chunks`` one event dict is yielded whose ``data``
    is a JSON document with a single choice: the detokenized text as a
    ``system`` message and a ``finish_reason`` of ``stop`` when the model
    reports the token as end-of-sequence, ``unknown`` otherwise. After the
    last token a final ``[DONE]`` event is yielded.

    Args:
        chat_chunks: Iterable of tokens produced by the model.
        llm: Object providing ``detokenize`` and ``is_eos_token``.
    """
    for token in chat_chunks:
        text = llm.detokenize(token)
        reason = 'stop' if llm.is_eos_token(token) else 'unknown'
        message = {'role': 'system', 'content': text}
        payload = {'choices': [{'message': message, 'finish_reason': reason}]}
        yield {'data': json.dumps(payload)}

    # Sentinel event marking the end of the stream.
    yield {'data': "[DONE]"}

@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest):
    """Stream a completion for the given chat messages as server-sent events.

    The ``content`` of every message is joined with single spaces (roles are
    not used), tokenized, and passed to the model. At most
    ``request.max_tokens`` generated tokens are streamed; a non-positive
    value yields no tokens, only the final ``[DONE]`` event.

    Raises:
        HTTPException: status 500 if tokenizing the prompt or creating the
            token generator raises.
    """
    # Local import so this handler does not depend on a file-level import.
    from itertools import islice

    combined_messages = ' '.join(message.content for message in request.messages)

    try:
        tokens = llm.tokenize(combined_messages)
        # Honor the request's token budget. islice rejects negative counts,
        # so those are clamped to zero.
        chat_chunks = islice(llm.generate(tokens), max(request.max_tokens, 0))
    except Exception as e:
        # NOTE: ctransformers documents generate() as returning a generator,
        # so failures during generation itself would surface while the
        # response is being streamed, not here.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return EventSourceResponse(generate_response(chat_chunks, llm))

async def stream_response(tokens, llm, max_tokens=None):
    """Generate from ``tokens`` and yield server-sent-event payloads.

    Each generated token becomes one event whose ``data`` is a JSON document
    with a single choice: the detokenized text as a ``system`` message and a
    ``finish_reason`` of ``stop`` when the model reports the token as
    end-of-sequence, ``unknown`` otherwise. A final ``[DONE]`` event follows
    the last token.

    Args:
        tokens: Prompt tokens passed to ``llm.generate``.
        llm: Model object providing ``generate``, ``detokenize`` and
            ``is_eos_token``.
        max_tokens: Optional cap on the number of generated tokens. ``None``
            (the default) applies no cap; non-positive values yield no tokens.

    Any exception raised while generating is printed and ends the stream
    without the ``[DONE]`` event.
    """
    # Local import so this block does not depend on a file-level import.
    from itertools import islice

    try:
        generated: Generator = llm.generate(tokens)
        if max_tokens is None:
            limited = generated
        else:
            # islice rejects negative counts, so clamp to zero.
            limited = islice(generated, max(max_tokens, 0))

        for chat_chunk in limited:
            response = {
                'choices': [
                    {
                        'message': {
                            'role': 'system',
                            'content': llm.detokenize(chat_chunk)
                        },
                        'finish_reason': 'stop' if llm.is_eos_token(chat_chunk) else 'unknown'
                    }
                ]
            }
            yield dict(data=json.dumps(response))
        yield dict(data="[DONE]")
    except Exception as e:
        # Best-effort: the response has already started, so the error is only
        # reported on the server side.
        print(f"Exception in event publisher: {str(e)}")

@app.post("/v2/chat/completions")
async def chatV2_endpoint(request: Request, body: ChatCompletionRequest):
    """Stream a completion for the chat messages in ``body`` as SSE.

    Message contents are joined with single spaces (roles are not used) and
    tokenized; generation is capped at ``body.max_tokens`` tokens.
    ``request`` is accepted but not used here.
    """
    combined_messages = ' '.join(message.content for message in body.messages)
    tokens = llm.tokenize(combined_messages)

    return EventSourceResponse(
        stream_response(tokens, llm, max_tokens=body.max_tokens)
    )

@app.post("/v0/chat/completions")
async def chat(request: ChatCompletionRequestV0, response_mode=None):
    """Stream a completion for ``request.prompt`` as server-sent events.

    Each generated token is detokenized and sent as a JSON-encoded string;
    a final ``[DONE]`` event closes the stream. ``response_mode`` is accepted
    but not used here.

    NOTE: this function shares its name with the /v1/chat/completions
    handler, so after import the module-level name ``chat`` refers to this one.
    """
    async def _event_stream(prompt_tokens, model):
        for token in model.generate(prompt_tokens):
            yield {"data": json.dumps(model.detokenize(token))}
        yield {"data": "[DONE]"}

    prompt_tokens = llm.tokenize(request.prompt)
    return EventSourceResponse(_event_stream(prompt_tokens, llm))

if __name__ == "__main__":
    # When executed as a script, serve the app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)