Commit 19c5d4e · 1 Parent(s): f40151d
Committed by DataBob

query langchain

Files changed (3):
  1. app.py +39 -19
  2. query_data.py +55 -0
  3. requirements.txt +6 -1
app.py CHANGED
@@ -1,15 +1,41 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
+
+from query_data import query_data
+from create_database import split_text
+import os
+import shutil
+
+
 import logging
 
+logging.basicConfig(filename='myapp.log', format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
 logger = logging.getLogger(__name__)
 
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
-logger.info("### Inference client: zephyr-7b-beta")
+
+CHROMA_PATH = "chroma"
+DATA_PATH = "./data"
+
+
+accesstoken = os.environ['HF_TOKEN']
+checkpoint = "HuggingFaceH4/zephyr-7b-beta"
+client = InferenceClient(checkpoint, token=accesstoken)
+
+def upload_file(file):
+    if not os.path.exists(DATA_PATH):
+        os.mkdir(DATA_PATH)
+
+    shutil.copy(file, DATA_PATH)
+    gr.Info("File uploading")
+
+
+logger.info("### Inference client: " + checkpoint)
+
+
 
 def respond(
     message,
@@ -30,28 +56,22 @@ def respond(
     messages.append({"role": "user", "content": message})
 
     logger.info(messages)
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
+    response = query_data(message)
+    yield response
 
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
-demo = gr.ChatInterface(
+with gr.Blocks() as demo:
+
+    upload_button = gr.UploadButton("Click the button to upload")
+    upload_button.upload(upload_file, upload_button)
+
+    gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+        gr.Textbox(value="You are a friendly Chatbot that helps searching knowledge into scientific articles.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
@@ -60,8 +80,8 @@ demo = gr.ChatInterface(
             value=0.95,
            step=0.05,
             label="Top-p (nucleus sampling)",
-        ),
-    ],
+        )
+    ],
 )
 
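In the new respond, the streaming chat_completion loop is replaced by a single query_data(message) call whose full answer is yielded once; in the hunks shown here, the max_tokens, temperature, and top_p sliders no longer reach the model. app.py also imports split_text from create_database, but create_database.py is not part of this commit. For orientation, here is a minimal sketch of what that module would have to provide; the loader choice, splitter settings, and the generate_data_store helper are assumptions, while CHROMA_PATH, DATA_PATH, and the all-MiniLM-L6-v2 embedding model are fixed by the committed files above.

# create_database.py (hypothetical sketch; this file is NOT in commit 19c5d4e)
from langchain_chroma import Chroma
from langchain_community.document_loaders import DirectoryLoader  # needs `unstructured` for most file types
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

CHROMA_PATH = "chroma"   # must match query_data.py
DATA_PATH = "./data"     # where app.py's upload_file copies documents


def split_text(documents):
    # Chunk documents so each retrieved piece fits comfortably into the prompt.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return splitter.split_documents(documents)


def generate_data_store():
    # Load everything uploaded so far, chunk it, and persist the embeddings
    # where query_data.py will look for them.
    documents = DirectoryLoader(DATA_PATH).load()
    chunks = split_text(documents)
    Chroma.from_documents(
        chunks,
        HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
        persist_directory=CHROMA_PATH,
    )

The one invariant that matters is using the same embedding model at indexing time and at query time; otherwise the relevance scores in query_data.py are meaningless.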
query_data.py ADDED
@@ -0,0 +1,55 @@
+import argparse
+# from dataclasses import dataclass
+from langchain_chroma import Chroma
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_huggingface import HuggingFaceEndpoint
+
+from langchain.prompts import ChatPromptTemplate
+
+from langchain.chains import LLMChain
+from langchain_core.prompts import PromptTemplate
+import os
+
+CHROMA_PATH = "chroma"
+
+PROMPT_TEMPLATE = """
+Answer the question based only on the following context:
+
+{context}
+
+---
+
+Answer the question based on the above context: {question}
+"""
+
+
+def query_data(query_text):
+
+    # Prepare the DB.
+    embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
+
+    # Search the DB.
+    results = db.similarity_search_with_relevance_scores(query_text, k=3)
+    if len(results) == 0 or results[0][1] < 0.2:
+        print("Unable to find matching results.")
+        return "Unable to find matching results."
+
+    context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
+    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
+
+    repo_id = "HuggingFaceH4/zephyr-7b-beta"
+
+    llm = HuggingFaceEndpoint(
+        repo_id=repo_id,
+        max_length=512,
+        temperature=0.5,
+        huggingfacehub_api_token=os.environ['HF_TOKEN'],
+    )
+    llm_chain = prompt_template | llm
+
+    response_text = llm_chain.invoke({"question": query_text, "context": context_text})
+
+    sources = [doc.metadata.get("source", None) for doc, _score in results]
+    formatted_response = f"{response_text}\nSources: {sources}"
+    return formatted_response
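query_data does retrieval and generation in one call: the top three chunks come back from similarity_search_with_relevance_scores, the best hit must clear a 0.2 relevance threshold (how meaningful that number is depends on the distance function the Chroma collection was built with), and the stuffed prompt goes to zephyr-7b-beta through HuggingFaceEndpoint. The argparse, LLMChain, and PromptTemplate imports are unused leftovers; the chain is assembled with the prompt | llm runnable syntax instead. A quick smoke test, assuming HF_TOKEN is exported and the chroma/ index already exists (the question string is made up):

# smoke test for query_data.py; assumes HF_TOKEN is set and ./chroma was
# built with the same all-MiniLM-L6-v2 embeddings used here at query time
from query_data import query_data

answer = query_data("Which methods does the uploaded article compare?")
print(answer)  # model text, then a line like: Sources: ['data/paper.pdf', ...]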
requirements.txt CHANGED
@@ -1 +1,6 @@
-huggingface_hub==0.25.2
+huggingface_hub==0.25.2
+tiktoken
+langchain
+langchain-community
+langchain_chroma
+langchain_huggingface
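One likely gap, flagged as an assumption: HuggingFaceEmbeddings in query_data.py loads all-MiniLM-L6-v2 locally via the sentence-transformers package, which is not listed here, and gradio is only preinstalled when this runs as a Gradio Space. Running the app elsewhere would probably also need:

sentence-transformers
gradio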