Completed Gradio app for QA
- app.py +23 -54
- app_modules/init.py +78 -0
- app_modules/llm_inference.py +2 -2
- app_modules/llm_qa_chain.py +1 -1
- app_modules/presets.py +0 -91
- app_modules/utils.py +0 -8
- test.py +43 -28
app.py
CHANGED
@@ -6,67 +6,36 @@ from timeit import default_timer as timer
 
 import gradio as gr
 from anyio.from_thread import start_blocking_portal
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.vectorstores.chroma import Chroma
-from langchain.vectorstores.faiss import FAISS
+from app_modules.init import app_init
+from app_modules.utils import print_llm_response
 
-
-from app_modules.qa_chain import QAChain
-from app_modules.utils import *
+qa_chain = app_init()
 
-# Constants
-init_settings()
-
-# https://github.com/huggingface/transformers/issues/17611
-os.environ["CURL_CA_BUNDLE"] = ""
-
-hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
-print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
-print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
-
-hf_embeddings_model_name = (
-    os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
-)
-n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
-index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get("CHROMADB_INDEX_PATH")
-using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
-llm_model_type = os.environ.get("LLM_MODEL_TYPE")
 chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
 show_param_settings = os.environ.get("SHOW_PARAM_SETTINGS") == "true"
 share_gradio_app = os.environ.get("SHARE_GRADIO_APP") == "true"
 
-
-
-
-start = timer()
-embeddings = HuggingFaceInstructEmbeddings(
-    model_name=hf_embeddings_model_name,
-    model_kwargs={"device": hf_embeddings_device_type},
-)
-end = timer()
-
-print(f"Completed in {end - start:.3f}s")
-
-start = timer()
-
-print(f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}")
+using_openai = os.environ.get("LLM_MODEL_TYPE") == "openai"
+model = (
+    "OpenAI GPT-4" if using_openai else os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH")
+)
+href = "https://openai.com/gpt-4" if using_openai else f"https://huggingface.co/{model}"
 
-if not os.path.isdir(index_path):
-    raise ValueError(f"{index_path} does not exist!")
-elif using_faiss:
-    vectorstore = FAISS.load_local(index_path, embeddings)
-else:
-    vectorstore = Chroma(embedding_function=embeddings, persist_directory=index_path)
+title = """<h1 align="left" style="min-width:200px; margin-top:0;"> Chat with AI Books </h1>"""
 
-end = timer()
+description_top = f"""\
+<div align="left">
+<p> Currently Running: <a href="{href}">{model}</a></p>
+</div>
+"""
 
-print(f"Completed in {end - start:.3f}s")
+description = """\
+<div align="center" style="margin:16px 0">
+The demo is built on <a href="https://github.com/hwchase17/langchain">LangChain</a>.
+</div>
+"""
 
-
-qa_chain = QAChain(vectorstore, llm_model_type)
-qa_chain.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
-end = timer()
-print(f"Completed in {end - start:.3f}s")
+CONCURRENT_COUNT = 100
 
 
 def qa(chatbot):
@@ -77,7 +46,7 @@ def qa(chatbot):
 
     def task(question, chat_history):
         start = timer()
-        ret = qa_chain.
+        ret = qa_chain.call_chain(
            {"question": question, "chat_history": chat_history}, None, q
         )
         end = timer()
@@ -135,7 +104,7 @@ def qa(chatbot):
 with open("assets/custom.css", "r", encoding="utf-8") as f:
     customCSS = f.read()
 
-with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
+with gr.Blocks(css=customCSS) as demo:
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
@@ -219,5 +188,5 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
        api_name="reset",
    )
 
-demo.title = "Chat with
-demo.queue(concurrency_count=
+demo.title = "Chat with AI Books"
+demo.queue(concurrency_count=CONCURRENT_COUNT).launch(share=share_gradio_app)
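Note on the `task` helper above: `call_chain` receives the queue `q` so generated tokens can be streamed from a worker back to the Gradio UI. A minimal sketch of that producer/consumer pattern, using a plain thread and `queue.Queue` instead of the app's `start_blocking_portal`; every name except `qa_chain.call_chain` is illustrative, and `call_chain` is assumed to push tokens into the queue as they arrive:

import queue
import threading

def stream_answer(qa_chain, question, chat_history):
    q = queue.Queue()

    def task():
        # Mirrors the call site in app.py's qa() handler: the chain is
        # assumed to stream tokens into q while it runs.
        qa_chain.call_chain(
            {"question": question, "chat_history": chat_history}, None, q
        )
        q.put(None)  # sentinel: generation finished

    threading.Thread(target=task, daemon=True).start()

    partial = ""
    while (token := q.get()) is not None:
        partial += token
        yield partial  # each partial answer re-renders the chatbot message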
app_modules/init.py
ADDED
@@ -0,0 +1,78 @@
+"""Main entrypoint for the app."""
+import os
+from timeit import default_timer as timer
+from typing import List, Optional
+
+from dotenv import find_dotenv, load_dotenv
+from langchain.embeddings import HuggingFaceInstructEmbeddings
+from langchain.vectorstores.chroma import Chroma
+from langchain.vectorstores.faiss import FAISS
+
+from app_modules.llm_loader import LLMLoader
+from app_modules.llm_qa_chain import QAChain
+from app_modules.utils import get_device_types, init_settings
+
+found_dotenv = find_dotenv(".env")
+
+if len(found_dotenv) == 0:
+    found_dotenv = find_dotenv(".env.example")
+print(f"loading env vars from: {found_dotenv}")
+load_dotenv(found_dotenv, override=False)
+
+# Constants
+init_settings()
+
+
+def app_init():
+    # https://github.com/huggingface/transformers/issues/17611
+    os.environ["CURL_CA_BUNDLE"] = ""
+
+    hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
+    print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
+    print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
+
+    hf_embeddings_model_name = (
+        os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
+    )
+
+    n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
+    index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get(
+        "CHROMADB_INDEX_PATH"
+    )
+    using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
+    llm_model_type = os.environ.get("LLM_MODEL_TYPE")
+
+    start = timer()
+    embeddings = HuggingFaceInstructEmbeddings(
+        model_name=hf_embeddings_model_name,
+        model_kwargs={"device": hf_embeddings_device_type},
+    )
+    end = timer()
+
+    print(f"Completed in {end - start:.3f}s")
+
+    start = timer()
+
+    print(f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}")
+
+    if not os.path.isdir(index_path):
+        raise ValueError(f"{index_path} does not exist!")
+    elif using_faiss:
+        vectorstore = FAISS.load_local(index_path, embeddings)
+    else:
+        vectorstore = Chroma(
+            embedding_function=embeddings, persist_directory=index_path
+        )
+
+    end = timer()
+
+    print(f"Completed in {end - start:.3f}s")
+
+    start = timer()
+    llm_loader = LLMLoader(llm_model_type)
+    llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
+    qa_chain = QAChain(vectorstore, llm_loader)
+    end = timer()
+    print(f"Completed in {end - start:.3f}s")
+
+    return qa_chain
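Since `app_init()` now owns all start-up side effects (dotenv loading, settings, embeddings, vector store, LLM), it can be reused outside Gradio. A minimal sketch, assuming a `.env` (or `.env.example`) supplies `LLM_MODEL_TYPE` and `FAISS_INDEX_PATH`/`CHROMADB_INDEX_PATH`, and following the `call_chain(inputs, None)` signature used in test.py below:

from app_modules.init import app_init
from app_modules.utils import print_llm_response

qa_chain = app_init()  # one-time load of embeddings, index, and LLM

result = qa_chain.call_chain(
    {"question": "What's deep learning?", "chat_history": []}, None
)
print_llm_response(result)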
app_modules/llm_inference.py
CHANGED
@@ -8,8 +8,8 @@ from threading import Thread
 from langchain.callbacks.tracers import LangChainTracer
 from langchain.chains.base import Chain
 
-from app_modules.llm_loader import
-from app_modules.utils import
+from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
+from app_modules.utils import remove_extra_spaces
 
 
 class LLMInference(metaclass=abc.ABCMeta):
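For context, `LLMInference` is the abstract base that chains build on: a subclass supplies `create_chain()` and reuses the shared `LLMLoader`. A hypothetical subclass sketch; `QAChain` in llm_qa_chain.py is the real example, and the `llm_loader.llm` attribute plus the `LLMChain` wiring here are assumptions, not the repo's API:

from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain.prompts import PromptTemplate

from app_modules.llm_inference import LLMInference

class SummarizeChain(LLMInference):
    # Hypothetical subclass: the contract is just to return a Chain.
    def create_chain(self) -> Chain:
        prompt = PromptTemplate.from_template("Summarize this text:\n{text}")
        # assumes LLMLoader exposes the loaded model as `llm`
        return LLMChain(llm=self.llm_loader.llm, prompt=prompt)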
app_modules/llm_qa_chain.py
CHANGED
@@ -9,7 +9,7 @@ class QAChain(LLMInference):
     vectorstore: VectorStore
 
     def __init__(self, vectorstore, llm_loader: int = 2048):
-        super.__init__(llm_loader)
+        super().__init__(llm_loader)
         self.vectorstore = vectorstore
 
     def create_chain(self) -> Chain:
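The `super().__init__` fix is worth a note: bare `super` is the built-in type itself, not a proxy bound to `self`, so the old spelling raises a `TypeError` as soon as a `QAChain` is constructed. A self-contained repro:

class Base:
    def __init__(self, value):
        self.value = value

class Broken(Base):
    def __init__(self, value):
        # TypeError: descriptor '__init__' requires a 'super' object
        super.__init__(value)

class Fixed(Base):
    def __init__(self, value):
        super().__init__(value)  # bound proxy -> calls Base.__init__

Fixed(42)     # ok: .value == 42
# Broken(42)  # raises TypeError at construction time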
app_modules/presets.py
DELETED
@@ -1,91 +0,0 @@
-# -*- coding:utf-8 -*-
-import os
-
-import gradio as gr
-
-from app_modules.utils import *
-
-using_openai = os.environ.get("LLM_MODEL_TYPE") == "openai"
-model = (
-    "OpenAI GPT-4" if using_openai else os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH")
-)
-href = "https://openai.com/gpt-4" if using_openai else f"https://huggingface.co/{model}"
-
-title = """<h1 align="left" style="min-width:200px; margin-top:0;"> Chat with AI Books </h1>"""
-
-description_top = f"""\
-<div align="left">
-<p> Currently Running: <a href="{href}">{model}</a></p>
-</div>
-"""
-
-description = """\
-<div align="center" style="margin:16px 0">
-The demo is built on <a href="https://github.com/hwchase17/langchain">LangChain</a>.
-</div>
-"""
-CONCURRENT_COUNT = 100
-
-
-ALREADY_CONVERTED_MARK = "<!-- ALREADY CONVERTED BY PARSER. -->"
-
-small_and_beautiful_theme = gr.themes.Soft(
-    primary_hue=gr.themes.Color(
-        c50="#02C160",
-        c100="rgba(2, 193, 96, 0.2)",
-        c200="#02C160",
-        c300="rgba(2, 193, 96, 0.32)",
-        c400="rgba(2, 193, 96, 0.32)",
-        c500="rgba(2, 193, 96, 1.0)",
-        c600="rgba(2, 193, 96, 1.0)",
-        c700="rgba(2, 193, 96, 0.32)",
-        c800="rgba(2, 193, 96, 0.32)",
-        c900="#02C160",
-        c950="#02C160",
-    ),
-    secondary_hue=gr.themes.Color(
-        c50="#576b95",
-        c100="#576b95",
-        c200="#576b95",
-        c300="#576b95",
-        c400="#576b95",
-        c500="#576b95",
-        c600="#576b95",
-        c700="#576b95",
-        c800="#576b95",
-        c900="#576b95",
-        c950="#576b95",
-    ),
-    neutral_hue=gr.themes.Color(
-        name="gray",
-        c50="#f9fafb",
-        c100="#f3f4f6",
-        c200="#e5e7eb",
-        c300="#d1d5db",
-        c400="#B2B2B2",
-        c500="#808080",
-        c600="#636363",
-        c700="#515151",
-        c800="#393939",
-        c900="#272727",
-        c950="#171717",
-    ),
-    radius_size=gr.themes.sizes.radius_sm,
-).set(
-    button_primary_background_fill="#06AE56",
-    button_primary_background_fill_dark="#06AE56",
-    button_primary_background_fill_hover="#07C863",
-    button_primary_border_color="#06AE56",
-    button_primary_border_color_dark="#06AE56",
-    button_primary_text_color="#FFFFFF",
-    button_primary_text_color_dark="#FFFFFF",
-    button_secondary_background_fill="#F2F2F2",
-    button_secondary_background_fill_dark="#2B2B2B",
-    button_secondary_text_color="#393939",
-    button_secondary_text_color_dark="#FFFFFF",
-    # background_fill_primary="#F7F7F7",
-    # background_fill_primary_dark="#1F1F1F",
-    block_title_text_color="*primary_500",
-    block_title_background_fill="*primary_100",
-    input_background_fill="#F6F6F6",
-)
app_modules/utils.py
CHANGED
@@ -9,16 +9,8 @@ from pathlib import Path
 
 import requests
 import torch
-from dotenv import find_dotenv, load_dotenv
 from tqdm import tqdm
 
-found_dotenv = find_dotenv(".env")
-if len(found_dotenv) == 0:
-    found_dotenv = find_dotenv(".env.example")
-print(f"loading env vars from: {found_dotenv}")
-load_dotenv(found_dotenv, override=False)
-# print(f"loaded env vars: {os.environ}")
-
 
 class LogRecord(logging.LogRecord):
     def getMessage(self):
test.py
CHANGED
@@ -7,36 +7,21 @@ from timeit import default_timer as timer
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import HumanMessage
 
+from app_modules.init import app_init
 from app_modules.llm_loader import LLMLoader
-from app_modules.utils import
+from app_modules.utils import get_device_types, print_llm_response
 
-user_question = "What's the capital city of Malaysia?"
-n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
 
+class TestLLMLoader: # (unittest.TestCase):
+    question = "What's the capital city of Malaysia?"
 
-hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
-print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
-print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
+    def run_test_case(self, llm_model_type, query):
+        n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
 
-
-
-
-
-    def reset(self):
-        self.texts = []
-
-    def get_standalone_question(self) -> str:
-        return self.texts[0].strip() if len(self.texts) > 0 else None
-
-    def on_llm_end(self, response, **kwargs) -> None:
-        """Run when chain ends running."""
-        print("\non_llm_end - response:")
-        print(response)
-        self.texts.append(response.generations[0][0].text)
-
-class TestLLMLoader(unittest.TestCase):
-    def run_test_case(self, llm_model_type, query):
+        hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
+        print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
+        print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
+
         llm_loader = LLMLoader(llm_model_type)
         start = timer()
         llm_loader.init(
@@ -53,16 +38,46 @@ class TestLLMLoader(unittest.TestCase):
         print(result)
 
     def test_openai(self):
-        self.run_test_case("openai", user_question)
+        self.run_test_case("openai", self.question)
+
+    def test_llamacpp(self):
+        self.run_test_case("llamacpp", self.question)
+
+    def test_gpt4all_j(self):
+        self.run_test_case("gpt4all-j", self.question)
+
+    def test_huggingface(self):
+        self.run_test_case("huggingface", self.question)
+
+
+class TestQAChain(unittest.TestCase):
+    qa_chain: any
+    question = "What's deep learning?"
+
+    def run_test_case(self, llm_model_type, query):
+        start = timer()
+        os.environ["LLM_MODEL_TYPE"] = llm_model_type
+        qa_chain = app_init()
+        end = timer()
+        print(f"App initialized in {end - start:.3f}s")
+
+        inputs = {"question": query, "chat_history": []}
+        result = qa_chain.call_chain(inputs, None)
+        end2 = timer()
+        print(f"Inference completed in {end2 - end:.3f}s")
+        print_llm_response(result)
+
+    def test_openai(self):
+        self.run_test_case("openai", self.question)
 
     def test_llamacpp(self):
-        self.run_test_case("llamacpp", user_question)
+        self.run_test_case("llamacpp", self.question)
 
     def test_gpt4all_j(self):
-        self.run_test_case("gpt4all-j", user_question)
+        self.run_test_case("gpt4all-j", self.question)
 
     def test_huggingface(self):
-        self.run_test_case("huggingface", user_question)
+        self.run_test_case("huggingface", self.question)
 
 
 if __name__ == "__main__":
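With `TestLLMLoader` detached from `unittest.TestCase`, running `python test.py` now exercises only the `TestQAChain` cases, each of which boots the full app via `app_init()` for its backend. A sketch of driving a single backend programmatically; the loader/runner calls are standard `unittest`, and the module name `test` refers to this file:

import unittest

# Load and run just the OpenAI-backed QA test instead of the whole suite.
suite = unittest.TestLoader().loadTestsFromName("test.TestQAChain.test_openai")
unittest.TextTestRunner(verbosity=2).run(suite)

The equivalent from the command line is `python test.py TestQAChain.test_openai`, since the file ends in `unittest.main()`-style dispatch.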