Yapp99 committed
Commit b37221e
1 Parent(s): 0d3b8dc

Some minor tweaks

Files changed (3)
  1. api.py +3 -2
  2. llm_backend.py +1 -1
  3. schema.py +1 -2
api.py CHANGED
@@ -18,7 +18,7 @@ logger = logging.getLogger("uvicorn.error")
 @app.get("/")
 def index():
     logger.info("this is a debug message")
-    return {"hello": "world"}
+    return {"Hello": "world"}
 
 
 @app.post("/chat_stream")
@@ -68,6 +68,7 @@ def chat(request: ChatRequest):
     }
     try:
         output = chat_with_model(request.chat_history, request.model, kwargs)
-        return HTMLResponse(output, media_type="text/plain")
+        return {"response": output}
+        # return HTMLResponse(output, media_type="text/plain")
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
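With this change the chat() handler returns a JSON body of the form {"response": "<model output>"} instead of a plain-text HTMLResponse, so callers now read a JSON field rather than the raw body. A minimal client sketch follows; it is not part of the commit, and the /chat path, the localhost:8000 address, and the role/content shape of Message are all assumptions not visible in this diff.

import requests

# Assumed request shape: chat_history entries use a role/content layout,
# and the API is served at http://localhost:8000 on the /chat route
# (all assumptions; only the field names from schema.py are shown in the diff).
payload = {
    "chat_history": [{"role": "user", "content": "Hello!"}],
    "model": "llama3.2",
    "temperature": 0.8,
}

resp = requests.post("http://localhost:8000/chat", json=payload)
resp.raise_for_status()

# After this commit the endpoint returns JSON, so the model output is read
# from the "response" key instead of the plain-text body.
print(resp.json()["response"])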
llm_backend.py CHANGED
@@ -66,7 +66,7 @@ def chat_with_model(chat_history, model, kwargs: dict):
     input_kwargs = {**default_kwargs, **kwargs, **forced_kwargs}
     response = llm.__call__(prompt, **input_kwargs)
 
-    return response["choices"][0]["text"]
+    return response["choices"][0]["text"].strip()
 
 
 # %% example input
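The only change here is trimming whitespace from the completion text before returning it, since completions often start or end with stray newlines or padding. A tiny illustration with a faked response dict that mirrors the response["choices"][0]["text"] access above (not actual model output):

# Fake completion dict, for illustration only.
response = {"choices": [{"text": "\n  Sure, here is the answer.\n"}]}

raw = response["choices"][0]["text"]
clean = raw.strip()  # what chat_with_model now returns

print(repr(raw))    # '\n  Sure, here is the answer.\n'
print(repr(clean))  # 'Sure, here is the answer.'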
schema.py CHANGED
@@ -25,7 +25,7 @@ MODEL_ARGS = {
 
 logger = logging.getLogger("uvicorn.error")
 for model_arg in MODEL_ARGS.values():
-    logger.info("this is a debug message")
+    logger.info(f"Checking for {model_arg['repo_id']}")
     hf_hub_download(**model_arg)
 
 
@@ -37,7 +37,6 @@ class Message(BaseModel):
 class ChatRequest(BaseModel):
     chat_history: List[Message]
     model: Literal["llama3.2", "falcon-mamba", "mistral-nemo"] = "llama3.2"
-    stream: bool = False
     max_tokens: Optional[int] = 65536
     temperature: float = 0.8
     top_p: float = 0.95
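With stream dropped, ChatRequest validates against only the remaining fields; under pydantic's default config an unexpected "stream" key in the request body is simply ignored rather than rejected. A small usage sketch of the updated model follows; Message's fields are not shown in this diff, so the role/content shape is an assumption.

from typing import List, Literal, Optional
from pydantic import BaseModel

class Message(BaseModel):  # assumed shape, not defined in this diff
    role: str
    content: str

class ChatRequest(BaseModel):
    chat_history: List[Message]
    model: Literal["llama3.2", "falcon-mamba", "mistral-nemo"] = "llama3.2"
    max_tokens: Optional[int] = 65536
    temperature: float = 0.8
    top_p: float = 0.95

# Only chat_history is required; everything else falls back to its default.
req = ChatRequest(chat_history=[{"role": "user", "content": "Hi"}])
print(req.model)       # "llama3.2" (default)
print(req.max_tokens)  # 65536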