File size: 3,422 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os

import pytest
from injector import Injector

from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.llm import QWenService
from taskweaver.llm.ollama import OllamaService
from taskweaver.llm.openai import OpenAIService
from taskweaver.llm.sentence_transformer import SentenceTransformerService

# True only when the GITHUB_ACTIONS environment variable is exactly "true";
# the skipif markers in this module use it to disable the tests on CI.
IN_GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS") == "true"


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_sentence_transformer_embedding():
    """Embed two sentences with SentenceTransformerService and check the result sizes.

    The service is built through an Injector bound to an in-memory
    AppConfigSource selecting the "all-mpnet-base-v2" model. The test expects
    one embedding per input sentence, each with 768 entries.
    """
    injector = Injector([])
    settings = {
        "llm.embedding_api_type": "sentence_transformer",
        "llm.embedding_model": "all-mpnet-base-v2",
    }
    config_source = AppConfigSource(config=settings)
    injector.binder.bind(AppConfigSource, to=config_source)
    service = injector.create_object(SentenceTransformerService)

    sentences = ["This is a test sentence.", "This is another test sentence."]
    vectors = service.get_embeddings(sentences)

    # One embedding per sentence, each of the expected dimensionality.
    assert len(vectors) == 2
    for vector in vectors:
        assert len(vector) == 768


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_openai_embedding():
    """Embed two sentences with OpenAIService and check the result sizes.

    The API key is taken from the ``OPENAI_API_KEY`` environment variable so
    the test can be run without editing this file (and without committing a
    secret). When the variable is unset or empty the test is skipped instead
    of being run with an empty key.

    The test expects one embedding per input sentence, each with 1536 entries.
    """
    api_key = os.getenv("OPENAI_API_KEY", "")
    if not api_key:
        pytest.skip("Set the OPENAI_API_KEY environment variable to run this test.")

    app_injector = Injector()
    app_config = AppConfigSource(
        config={
            "llm.embedding_api_type": "openai",
            "llm.embedding_model": "text-embedding-ada-002",
            "llm.api_key": api_key,
        },
    )
    app_injector.binder.bind(AppConfigSource, to=app_config)
    openai_service = app_injector.create_object(OpenAIService)

    text_list = ["This is a test sentence.", "This is another test sentence."]
    embeddings = openai_service.get_embeddings(text_list)

    # One embedding per sentence, each of the expected dimensionality.
    assert len(embeddings) == 2
    assert len(embeddings[0]) == 1536
    assert len(embeddings[1]) == 1536


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_ollama_embedding():
    """Embed two sentences with OllamaService and check the result sizes.

    The service is built through an Injector bound to an in-memory
    AppConfigSource selecting the "llama2" model. The test expects one
    embedding per input sentence, each with 4096 entries.
    """
    injector = Injector()
    settings = {
        "llm.embedding_api_type": "ollama",
        "llm.embedding_model": "llama2",
    }
    config_source = AppConfigSource(config=settings)
    injector.binder.bind(AppConfigSource, to=config_source)
    service = injector.create_object(OllamaService)

    sentences = ["This is a test sentence.", "This is another test sentence."]
    vectors = service.get_embeddings(sentences)

    # One embedding per sentence, each of the expected dimensionality.
    assert len(vectors) == 2
    for vector in vectors:
        assert len(vector) == 4096


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_qwen_embedding():
    """Embed two sentences with QWenService and check the result sizes.

    The API key is taken from the ``DASHSCOPE_API_KEY`` environment variable
    so the test can be run without editing this file (and without committing
    a secret). When the variable is unset or empty the test is skipped
    instead of being run with an empty key.

    The test expects one embedding per input sentence, each with 1536 entries.
    """
    api_key = os.getenv("DASHSCOPE_API_KEY", "")
    if not api_key:
        pytest.skip("Set the DASHSCOPE_API_KEY environment variable to run this test.")

    app_injector = Injector()
    app_config = AppConfigSource(
        config={
            "llm.embedding_api_type": "qwen",
            "llm.embedding_model": "text-embedding-v1",
            "llm.api_key": api_key,
        },
    )
    app_injector.binder.bind(AppConfigSource, to=app_config)
    qwen_service = app_injector.create_object(QWenService)

    text_list = ["This is a test sentence.", "This is another test sentence."]
    embeddings = qwen_service.get_embeddings(text_list)

    # One embedding per sentence, each of the expected dimensionality.
    assert len(embeddings) == 2
    assert len(embeddings[0]) == 1536
    assert len(embeddings[1]) == 1536