Some minor tweaks
- api.py +3 -2
- llm_backend.py +1 -1
- schema.py +1 -2
api.py
CHANGED
@@ -18,7 +18,7 @@ logger = logging.getLogger("uvicorn.error")
 @app.get("/")
 def index():
     logger.info("this is a debug message")
-    return {"
+    return {"Hello": "world"}
 
 
 @app.post("/chat_stream")
@@ -68,6 +68,7 @@ def chat(request: ChatRequest):
     }
     try:
         output = chat_with_model(request.chat_history, request.model, kwargs)
-        return
+        return {"response": output}
+        # return HTMLResponse(output, media_type="text/plain")
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
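With this change the endpoint returns JSON ({"response": ...}) instead of the now commented-out plain-text HTMLResponse, so clients read the response key from the body. A minimal sketch of a call against a local dev server; the host and port are assumptions, and the Message fields are assumed to be role/content, while chat_history and model come from ChatRequest in schema.py:

import requests

# Assumes the FastAPI app is running locally, e.g. `uvicorn api:app --port 8000`.
payload = {
    "chat_history": [{"role": "user", "content": "Hello!"}],
    "model": "llama3.2",
}
resp = requests.post("http://localhost:8000/chat_stream", json=payload)
resp.raise_for_status()
print(resp.json()["response"])  # the stripped model output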
llm_backend.py
CHANGED
@@ -66,7 +66,7 @@ def chat_with_model(chat_history, model, kwargs: dict):
     input_kwargs = {**default_kwargs, **kwargs, **forced_kwargs}
     response = llm.__call__(prompt, **input_kwargs)
 
-    return response["choices"][0]["text"]
+    return response["choices"][0]["text"].strip()
 
 
 # %% example input
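The added .strip() trims the leading and trailing whitespace that completion models often wrap around generated text. A toy illustration of the dict being indexed; the shape matches llama-cpp-python style completion output, which is an assumption about what llm.__call__ returns here:

# Hypothetical response shaped like a llama-cpp-python completion result.
response = {
    "choices": [
        {"text": "\n  Hello! How can I help you today?  \n"}
    ]
}
print(repr(response["choices"][0]["text"].strip()))
# 'Hello! How can I help you today?'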
schema.py
CHANGED
@@ -25,7 +25,7 @@ MODEL_ARGS = {
 
 logger = logging.getLogger("uvicorn.error")
 for model_arg in MODEL_ARGS.values():
-    logger.info("
+    logger.info(f"Checking for {model_arg['repo_id']}")
     hf_hub_download(**model_arg)
 
 
@@ -37,7 +37,6 @@ class Message(BaseModel):
 class ChatRequest(BaseModel):
     chat_history: List[Message]
     model: Literal["llama3.2", "falcon-mamba", "mistral-nemo"] = "llama3.2"
-    stream: bool = False
     max_tokens: Optional[int] = 65536
     temperature: float = 0.8
     top_p: float = 0.95
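Dropping stream leaves chat_history as the only field without a default, so the smallest valid payload is just a chat history. A sketch of the trimmed model in use (Pydantic v2 syntax; the Message fields are assumed, and note that Pydantic ignores an extra stream key by default rather than rejecting it):

from typing import List, Literal, Optional
from pydantic import BaseModel

class Message(BaseModel):
    # Assumed fields; the real definition lives in schema.py.
    role: str
    content: str

class ChatRequest(BaseModel):
    chat_history: List[Message]
    model: Literal["llama3.2", "falcon-mamba", "mistral-nemo"] = "llama3.2"
    max_tokens: Optional[int] = 65536
    temperature: float = 0.8
    top_p: float = 0.95

req = ChatRequest(chat_history=[{"role": "user", "content": "Hi"}], stream=True)
print(req.model_dump())  # no "stream" key: extra fields are ignored by default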