File size: 2,032 Bytes
7f9d16c 815128e c2cb992 7f9d16c c2cb992 815128e 7f9d16c 815128e 7f9d16c c2cb992 815128e c2cb992 815128e 7f9d16c 815128e 7f9d16c c2cb992 7f9d16c 815128e 7f9d16c 815128e c2cb992 815128e c2cb992 815128e c2cb992 815128e 7f9d16c c2cb992 815128e 7f9d16c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# project/test.py
import os
import unittest
from timeit import default_timer as timer
from typing import Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import HumanMessage

from app_modules.llm_loader import LLMLoader
from app_modules.utils import *
# Shared test fixtures: the question posed to every backend, plus runtime
# settings derived from the environment.
user_question = "What's the capital city of Malaysia?"

# Thread count for the loader; falls back to 4 when the env var is unset/empty.
n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")

# Device selection for embeddings vs. the HF pipeline — presumably returns
# strings like "cpu"/"cuda"/"mps" (from app_modules.utils; verify there).
hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()

for _label, _device in (
    ("hf_embeddings_device_type", hf_embeddings_device_type),
    ("hf_pipeline_device_type", hf_pipeline_device_type),
):
    print(f"{_label}: {_device}")
class MyCustomHandler(BaseCallbackHandler):
    """Callback handler that records the text of each LLM response.

    Captures `response.generations[0][0].text` on every LLM completion so a
    test can later inspect the generated standalone question.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear captured texts so the handler can be reused across calls."""
        # texts: generation strings appended by on_llm_end, in call order
        self.texts = []

    def get_standalone_question(self) -> Optional[str]:
        """Return the first captured text, stripped, or None if none captured.

        Fix: the original annotation claimed `-> str`, but the method returns
        None whenever no LLM response has been recorded yet.
        """
        return self.texts[0].strip() if self.texts else None

    def on_llm_end(self, response, **kwargs) -> None:
        """Run when chain ends running; record the first generation's text."""
        print("\non_llm_end - response:")
        print(response)
        self.texts.append(response.generations[0][0].text)
class TestLLMLoader(unittest.TestCase):
    """Smoke tests that load each supported LLM backend and run one query."""

    def run_test_case(self, llm_model_type, query):
        """Load the model of the given type, send `query`, and print timings."""
        llm_loader = LLMLoader(llm_model_type)

        load_start = timer()
        llm_loader.init(
            n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type
        )
        load_end = timer()
        print(f"Model loaded in {load_end - load_start:.3f}s")

        # OpenAI chat models take a message list; other backends take raw text.
        prompt = [HumanMessage(content=query)] if llm_model_type == "openai" else query
        result = llm_loader.llm(prompt)
        infer_end = timer()
        print(f"Inference completed in {infer_end - load_end:.3f}s")
        print(result)

    def test_openai(self):
        self.run_test_case("openai", user_question)

    def test_llamacpp(self):
        self.run_test_case("llamacpp", user_question)

    def test_gpt4all_j(self):
        self.run_test_case("gpt4all-j", user_question)

    def test_huggingface(self):
        self.run_test_case("huggingface", user_question)
if __name__ == "__main__":
    # Discover and run all TestLLMLoader tests when executed directly.
    unittest.main()
|