Yapp99 committed on
Commit 0d3b8dc
1 Parent(s): efea2bf

support for both streaming and non-streaming

Files changed (4):
  1. Dockerfile +1 -1
  2. api.py +30 -3
  3. llm_backend.py +26 -9
  4. schema.py +1 -0
Dockerfile CHANGED
@@ -8,4 +8,4 @@ RUN pip install --no-cache-dir --upgrade -r /requirements.txt
 
 RUN useradd -m -u 1000 user
 
-CMD ["fastapi", "run", "api.py", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
api.py CHANGED
@@ -1,8 +1,8 @@
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, HTMLResponse
 from fastapi import FastAPI, HTTPException
 import logging
 
-from llm_backend import chat_with_model
+from llm_backend import chat_with_model, stream_with_model
 from schema import ChatRequest
 
 """
@@ -26,6 +26,7 @@ def chat_stream(request: ChatRequest):
     kwargs = {
         "max_tokens": request.max_tokens,
         "temperature": request.temperature,
+        "stream": True,
         "top_p": request.top_p,
         "min_p": request.min_p,
         "typical_p": request.typical_p,
@@ -40,7 +41,33 @@ def chat_stream(request: ChatRequest):
         "mirostat_eta": request.mirostat_eta,
     }
     try:
-        token_generator = chat_with_model(request.chat_history, request.model, kwargs)
+        token_generator = stream_with_model(request.chat_history, request.model, kwargs)
         return StreamingResponse(token_generator, media_type="text/plain")
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/chat")
+def chat(request: ChatRequest):
+    kwargs = {
+        "max_tokens": request.max_tokens,
+        "temperature": request.temperature,
+        "stream": False,
+        "top_p": request.top_p,
+        "min_p": request.min_p,
+        "typical_p": request.typical_p,
+        "frequency_penalty": request.frequency_penalty,
+        "presence_penalty": request.presence_penalty,
+        "repeat_penalty": request.repeat_penalty,
+        "top_k": request.top_k,
+        "seed": request.seed,
+        "tfs_z": request.tfs_z,
+        "mirostat_mode": request.mirostat_mode,
+        "mirostat_tau": request.mirostat_tau,
+        "mirostat_eta": request.mirostat_eta,
+    }
+    try:
+        output = chat_with_model(request.chat_history, request.model, kwargs)
+        return HTMLResponse(output, media_type="text/plain")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
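For reference, a minimal client sketch (not part of this commit) showing how the two endpoints might be called. The base URL, the streaming route path (/chat_stream), and the Message field names ("role", "content") are assumptions, since they are not shown in this diff:

import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment

payload = {
    "chat_history": [{"role": "user", "content": "Hello!"}],  # Message fields assumed
    "model": "llama3.2",
    "max_tokens": 256,
    "temperature": 0.8,
}

# Non-streaming: /chat returns the whole completion as plain text.
resp = requests.post(f"{BASE_URL}/chat", json=payload)
print(resp.text)

# Streaming: the chat_stream endpoint returns tokens as they are generated.
with requests.post(f"{BASE_URL}/chat_stream", json=payload, stream=True) as stream_resp:
    for chunk in stream_resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)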
llm_backend.py CHANGED
@@ -19,8 +19,7 @@ def get_llm(model_name):
 
 def format_chat(chat_history: list[Message]):
     """
-    Formats chat history and user input into a single string
-    suitable for the model.
+    Formats chat history and user input into a single string suitable for the model.
     """
     messages = []
     for msg in chat_history:
@@ -29,13 +28,16 @@ def format_chat(chat_history: list[Message]):
     return "\n".join(messages) + "\nAssistant:"
 
 
-def chat_with_model(chat_history, model, kwargs: dict):
+default_kwargs = dict(
+    max_tokens=2048,
+    top_k=1,
+)
+
+
+def stream_with_model(chat_history, model, kwargs: dict):
     prompt = format_chat(chat_history)
 
-    default_kwargs = dict(
-        max_tokens=2048,
-        top_k=1,
-    )
+    llm = get_llm(model)
 
     forced_kwargs = dict(
         stop=["\nUser:", "\nAssistant:", "</s>"],
@@ -43,8 +45,6 @@ def chat_with_model(chat_history, model, kwargs: dict):
         stream=True,
     )
 
-    llm = get_llm(model)
-
     input_kwargs = {**default_kwargs, **kwargs, **forced_kwargs}
     response = llm.__call__(prompt, **input_kwargs)
 
@@ -52,6 +52,23 @@ def chat_with_model(chat_history, model, kwargs: dict):
         yield token["choices"][0]["text"]
 
 
+def chat_with_model(chat_history, model, kwargs: dict):
+    prompt = format_chat(chat_history)
+
+    llm = get_llm(model)
+
+    forced_kwargs = dict(
+        stop=["\nUser:", "\nAssistant:", "</s>"],
+        echo=False,
+        stream=False,
+    )
+
+    input_kwargs = {**default_kwargs, **kwargs, **forced_kwargs}
+    response = llm.__call__(prompt, **input_kwargs)
+
+    return response["choices"][0]["text"]
+
+
 # %% example input
 # kwargs = dict(
 #     temperature=1,
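A minimal usage sketch (not part of this commit) contrasting the two helpers; the shape of the chat-history entries is an assumption, as the Message model lives in schema.py and is not shown here:

from llm_backend import chat_with_model, stream_with_model

history = []  # populate with Message entries as defined in schema.py

# chat_with_model returns the completed text as a single string.
full_text = chat_with_model(history, "llama3.2", {"max_tokens": 256, "temperature": 0.8})
print(full_text)

# stream_with_model returns a generator that yields token strings as they are produced.
for token in stream_with_model(history, "llama3.2", {"max_tokens": 256, "temperature": 0.8}):
    print(token, end="", flush=True)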
schema.py CHANGED
@@ -37,6 +37,7 @@ class Message(BaseModel):
 class ChatRequest(BaseModel):
     chat_history: List[Message]
     model: Literal["llama3.2", "falcon-mamba", "mistral-nemo"] = "llama3.2"
+    stream: bool = False
     max_tokens: Optional[int] = 65536
     temperature: float = 0.8
     top_p: float = 0.95
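An example ChatRequest payload including the new stream flag (a sketch; the Message field names are assumptions, since the Message model is not shown in this diff):

example_request = {
    "chat_history": [{"role": "user", "content": "Hi"}],  # Message fields assumed
    "model": "llama3.2",
    "stream": False,      # new field added in this commit; defaults to False
    "max_tokens": 65536,
    "temperature": 0.8,
    "top_p": 0.95,
}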