Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files
- inference.py +3 -3
- server/app.py +5 -0
inference.py
CHANGED
|
@@ -29,7 +29,7 @@ from openai import OpenAI
|
|
| 29 |
|
| 30 |
# ── Configuration ──────────────────────────────────────────────────────────
|
| 31 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 32 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-
|
| 33 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
|
| 34 |
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
|
| 35 |
SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
|
|
@@ -171,7 +171,7 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
|
|
| 171 |
# OpenEnv reset() may not accept task args via HTTP; we rely on
|
| 172 |
# NL2SQL_DEFAULT_TASK env-var being set before calling, OR we
|
| 173 |
# pass it as a reset parameter if the server supports it.
|
| 174 |
-
result = await env.reset()
|
| 175 |
obs = result.observation
|
| 176 |
|
| 177 |
for step in range(1, MAX_STEPS + 1):
|
|
@@ -191,7 +191,7 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
|
|
| 191 |
sql = call_llm(client, user_prompt)
|
| 192 |
|
| 193 |
from client import NL2SQLAction # local to avoid circular at module level
|
| 194 |
-
result = await env.step(NL2SQLAction(query=sql))
|
| 195 |
obs = result.observation
|
| 196 |
|
| 197 |
reward = obs.reward or 0.0
|
|
|
|
| 29 |
|
| 30 |
# ── Configuration ──────────────────────────────────────────────────────────
|
| 31 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 32 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
|
| 33 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
|
| 34 |
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
|
| 35 |
SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
|
|
|
|
| 171 |
# OpenEnv reset() may not accept task args via HTTP; we rely on
|
| 172 |
# NL2SQL_DEFAULT_TASK env-var being set before calling, OR we
|
| 173 |
# pass it as a reset parameter if the server supports it.
|
| 174 |
+
result = await env.reset(options={"task": task_name}) # changed
|
| 175 |
obs = result.observation
|
| 176 |
|
| 177 |
for step in range(1, MAX_STEPS + 1):
|
|
|
|
| 191 |
sql = call_llm(client, user_prompt)
|
| 192 |
|
| 193 |
from client import NL2SQLAction # local to avoid circular at module level
|
| 194 |
+
result = await env.step({"query": sql}) # changed
|
| 195 |
obs = result.observation
|
| 196 |
|
| 197 |
reward = obs.reward or 0.0
|
server/app.py
CHANGED
|
@@ -30,6 +30,11 @@ app = create_fastapi_app(
|
|
| 30 |
observation_cls=NL2SQLObservation
|
| 31 |
)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def main():
|
| 34 |
import uvicorn
|
| 35 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 30 |
observation_cls=NL2SQLObservation
|
| 31 |
)
|
| 32 |
|
| 33 |
+
@app.on_event("startup")
|
| 34 |
+
async def startup_event():
|
| 35 |
+
from db.seed import seed_database
|
| 36 |
+
seed_database() # This creates the 'ecommerce.db' file inside the container
|
| 37 |
+
|
| 38 |
def main():
|
| 39 |
import uvicorn
|
| 40 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|