Luke Stanley commited on
Commit
469f650
1 Parent(s): ce5ad5f

Avoid unneeded imports, make serverless output more sensible, removing some debugging and comments

Browse files
Files changed (3) hide show
  1. app.py +3 -2
  2. runpod_handler.py +1 -1
  3. utils.py +12 -10
app.py CHANGED
@@ -36,8 +36,9 @@ if LLM_WORKER == "http" or LLM_WORKER == "in_memory":
36
  from chill import improvement_loop
37
 
38
 
39
- def greet(text):
 
40
  return str(improvement_loop(text))
41
 
42
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
43
  demo.launch(max_threads=1, share=True)
 
36
  from chill import improvement_loop
37
 
38
 
39
+ def chill_out(text):
40
+ print("Got this input:", text)
41
  return str(improvement_loop(text))
42
 
43
+ demo = gr.Interface(fn=chill_out, inputs="text", outputs="text")
44
  demo.launch(max_threads=1, share=True)
runpod_handler.py CHANGED
@@ -33,7 +33,7 @@ def handler(job):
33
  print("schema", schema )
34
  output = llm_stream_sans_network_simple(prompt, schema)
35
  #print("got this output", str(output))
36
- return f"model:{filename}\n{output}"
37
 
38
  runpod.serverless.start({
39
  "handler": handler,
 
33
  print("schema", schema )
34
  output = llm_stream_sans_network_simple(prompt, schema)
35
  #print("got this output", str(output))
36
+ return output
37
 
38
  runpod.serverless.start({
39
  "handler": handler,
utils.py CHANGED
@@ -1,11 +1,9 @@
1
  import json
2
  from os import environ as env
3
  from typing import Any, Dict, Union
4
- # TODO: Make imports conditional on type of worker being used:
5
- import requests
6
 
 
7
  from huggingface_hub import hf_hub_download
8
- from llama_cpp import Llama, LlamaGrammar, json_schema_to_gbnf
9
 
10
 
11
  # There are 3 ways to use the LLM model currently used:
@@ -18,7 +16,7 @@ from llama_cpp import Llama, LlamaGrammar, json_schema_to_gbnf
18
  # The real OpenAI API has other ways to set the output format.
19
  # It's possible to switch to another LLM API by changing the llm_streaming function.
20
  # 3. Use the RunPod API, which is a paid service with severless GPU functions.
21
- # TODO: Update README with instructions on how to use the RunPod API and options.
22
 
23
  URL = "http://localhost:5834/v1/chat/completions"
24
  in_memory_llm = None
@@ -34,15 +32,19 @@ LLM_MODEL_PATH = env.get("LLM_MODEL_PATH", None)
34
  MAX_TOKENS = int(env.get("MAX_TOKENS", 1000))
35
  TEMPERATURE = float(env.get("TEMPERATURE", 0.3))
36
 
37
- if LLM_MODEL_PATH and len(LLM_MODEL_PATH) > 0 and (LLM_WORKER == "in_memory" or LLM_WORKER == "http"):
 
 
38
  print(f"Using local model from {LLM_MODEL_PATH}")
39
- else:
40
  print("No local LLM_MODEL_PATH environment variable set. We need a model, downloading model from HuggingFace Hub")
41
  LLM_MODEL_PATH =hf_hub_download(
42
  repo_id=env.get("REPO_ID", "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"),
43
  filename=env.get("MODEL_FILE", "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"),
44
  )
45
  print(f"Model downloaded to {LLM_MODEL_PATH}")
 
 
46
 
47
  if in_memory_llm is None and LLM_WORKER == "in_memory":
48
  print("Loading model into memory. If you didn't want this, set the USE_HTTP_SERVER environment variable to 'true'.")
@@ -152,8 +154,8 @@ def llm_stream_sans_network(
152
 
153
  # Function to call the RunPod API with a Pydantic model and movie name
154
  def llm_stream_serverless(prompt,model):
155
- RUNPOD_ENDPOINT_ID = env("RUNPOD_ENDPOINT_ID")
156
- RUNPOD_API_KEY = env("RUNPOD_API_KEY")
157
  url = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/runsync"
158
 
159
  headers = {
@@ -171,8 +173,8 @@ def llm_stream_serverless(prompt,model):
171
 
172
  response = requests.post(url, json=data, headers=headers)
173
  result = response.json()
174
- output = result.get('output', '').replace("model:mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf\n", "")
175
- print(output)
176
  return json.loads(output)
177
 
178
  def query_ai_prompt(prompt, replacements, model_class):
 
1
  import json
2
  from os import environ as env
3
  from typing import Any, Dict, Union
 
 
4
 
5
+ import requests
6
  from huggingface_hub import hf_hub_download
 
7
 
8
 
9
  # There are 3 ways to use the LLM model currently used:
 
16
  # The real OpenAI API has other ways to set the output format.
17
  # It's possible to switch to another LLM API by changing the llm_streaming function.
18
  # 3. Use the RunPod API, which is a paid service with severless GPU functions.
19
+ # See serverless.md for more information.
20
 
21
  URL = "http://localhost:5834/v1/chat/completions"
22
  in_memory_llm = None
 
32
  MAX_TOKENS = int(env.get("MAX_TOKENS", 1000))
33
  TEMPERATURE = float(env.get("TEMPERATURE", 0.3))
34
 
35
+ performing_local_inference = (LLM_WORKER == "in_memory" or LLM_WORKER == "http")
36
+
37
+ if LLM_MODEL_PATH and len(LLM_MODEL_PATH) > 0:
38
  print(f"Using local model from {LLM_MODEL_PATH}")
39
+ if performing_local_inference and not LLM_MODEL_PATH:
40
  print("No local LLM_MODEL_PATH environment variable set. We need a model, downloading model from HuggingFace Hub")
41
  LLM_MODEL_PATH =hf_hub_download(
42
  repo_id=env.get("REPO_ID", "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"),
43
  filename=env.get("MODEL_FILE", "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"),
44
  )
45
  print(f"Model downloaded to {LLM_MODEL_PATH}")
46
+ if LLM_WORKER == "http" or LLM_WORKER == "in_memory":
47
+ from llama_cpp import Llama, LlamaGrammar, json_schema_to_gbnf
48
 
49
  if in_memory_llm is None and LLM_WORKER == "in_memory":
50
  print("Loading model into memory. If you didn't want this, set the USE_HTTP_SERVER environment variable to 'true'.")
 
154
 
155
  # Function to call the RunPod API with a Pydantic model and movie name
156
  def llm_stream_serverless(prompt,model):
157
+ RUNPOD_ENDPOINT_ID = env.get("RUNPOD_ENDPOINT_ID")
158
+ RUNPOD_API_KEY = env.get("RUNPOD_API_KEY")
159
  url = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/runsync"
160
 
161
  headers = {
 
173
 
174
  response = requests.post(url, json=data, headers=headers)
175
  result = response.json()
176
+ print(result)
177
+ output = result['output']
178
  return json.loads(output)
179
 
180
  def query_ai_prompt(prompt, replacements, model_class):