ritvik360 commited on
Commit
25575d3
Β·
verified Β·
1 Parent(s): fcc5471

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. inference.py +3 -3
  2. server/app.py +5 -0
inference.py CHANGED
@@ -29,7 +29,7 @@ from openai import OpenAI
29
 
30
  # ── Configuration ──────────────────────────────────────────────────────────
31
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
32
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
33
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
34
  IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
35
  SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
@@ -171,7 +171,7 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
171
  # OpenEnv reset() may not accept task args via HTTP; we rely on
172
  # NL2SQL_DEFAULT_TASK env-var being set before calling, OR we
173
  # pass it as a reset parameter if the server supports it.
174
- result = await env.reset()
175
  obs = result.observation
176
 
177
  for step in range(1, MAX_STEPS + 1):
@@ -191,7 +191,7 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
191
  sql = call_llm(client, user_prompt)
192
 
193
  from client import NL2SQLAction # local to avoid circular at module level
194
- result = await env.step(NL2SQLAction(query=sql))
195
  obs = result.observation
196
 
197
  reward = obs.reward or 0.0
 
29
 
30
  # ── Configuration ──────────────────────────────────────────────────────────
31
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
32
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
33
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
34
  IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
35
  SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
 
171
  # OpenEnv reset() may not accept task args via HTTP; we rely on
172
  # NL2SQL_DEFAULT_TASK env-var being set before calling, OR we
173
  # pass it as a reset parameter if the server supports it.
174
+ result = await env.reset(options={"task": task_name}) # changed
175
  obs = result.observation
176
 
177
  for step in range(1, MAX_STEPS + 1):
 
191
  sql = call_llm(client, user_prompt)
192
 
193
  from client import NL2SQLAction # local to avoid circular at module level
194
+ result = await env.step({"query": sql}) # changed
195
  obs = result.observation
196
 
197
  reward = obs.reward or 0.0
server/app.py CHANGED
@@ -30,6 +30,11 @@ app = create_fastapi_app(
30
  observation_cls=NL2SQLObservation
31
  )
32
 
 
 
 
 
 
33
  def main():
34
  import uvicorn
35
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
30
  observation_cls=NL2SQLObservation
31
  )
32
 
33
+ @app.on_event("startup")
34
+ async def startup_event():
35
+ from db.seed import seed_database
36
+ seed_database() # This creates the 'ecommerce.db' file inside the container
37
+
38
  def main():
39
  import uvicorn
40
  uvicorn.run(app, host="0.0.0.0", port=7860)