File size: 6,163 Bytes
b585c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import pytest

from tests.utils import wrap_test_forked, get_llama
from src.enums import DocumentSubset


@wrap_test_forked
def test_cli(monkeypatch):
    """Run one CLI generation with the 'gptj' base model and check the answer.

    ``builtins.input`` is patched so any prompt returns a fixed question.
    ``main`` is then called with ``cli=True`` and ``cli_loop=False`` and must
    return exactly one generation containing the expected phrase.
    """
    def fake_input(_):
        # Stand-in for interactive input: always reply with the same question.
        return "What is the Earth?"

    monkeypatch.setattr('builtins.input', fake_input)

    # Imported inside the test body rather than at module level.
    from src.gen import main
    cli_options = dict(base_model='gptj', cli=True, cli_loop=False, score_model='None')
    all_generations = main(**cli_options)

    assert len(all_generations) == 1
    answer = all_generations[0]
    assert "The Earth is a planet in our solar system" in answer


@pytest.mark.parametrize("base_model", ['gptj', 'gpt4all_llama'])
@wrap_test_forked
def test_cli_langchain(base_model, monkeypatch):
    """Run one CLI generation in 'UserData' langchain mode for each base model.

    The user path comes from ``make_user_path_test`` and ``builtins.input`` is
    patched to return a fixed question. The single returned generation must
    name the expected image file and contain at least one accepted phrase.
    """
    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    def fake_input(_):
        # Stand-in for interactive input: always reply with the same question.
        return "What is the cat doing?"

    monkeypatch.setattr('builtins.input', fake_input)

    from src.gen import main
    run_options = dict(
        base_model=base_model,
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=docs_dir,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    all_generations = main(**run_options)

    print(all_generations)
    assert len(all_generations) == 1
    assert "pexels-evg-kowalievska-1170986_small.jpg" in all_generations[0]

    # Any one of these phrases counts as an acceptable answer.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "what the cat is doing",
        "question about a cat",
        "The prompt asks for an answer to a question",
        "The prompt asks what the cat in the scenario is doing",
        "The prompt asks why H2O.ai",
    )
    assert any(phrase in all_generations[0] for phrase in accepted_phrases)


@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_langchain_llamacpp(monkeypatch):
    """Run one CLI generation with the 'llama' base model in 'UserData' mode.

    ``get_llama`` supplies the prompt type and model path, the user path comes
    from ``make_user_path_test``, and ``builtins.input`` is patched to return a
    fixed question. The single returned generation must name the expected
    image file and contain at least one accepted phrase.
    """
    prompt_type, full_path = get_llama()

    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    def fake_input(_):
        # Stand-in for interactive input: always reply with the same question.
        return "What is the cat doing?"

    monkeypatch.setattr('builtins.input', fake_input)

    from src.gen import main
    run_options = dict(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        model_path_llama=full_path,
        prompt_type=prompt_type,
        user_path=docs_dir,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    all_generations = main(**run_options)

    print(all_generations)
    assert len(all_generations) == 1
    assert "pexels-evg-kowalievska-1170986_small.jpg" in all_generations[0]

    # Any one of these phrases counts as an acceptable answer.
    accepted_phrases = (
        "The cat is sitting on a window seat and looking out the window",
        "staring out the window at the city skyline",
        "The cat is likely relaxing and enjoying",
        "The cat is sitting on a window seat and looking out",
        "cat in the image is",
        "The cat is sitting on a window",
        "The cat is sitting and looking out the window at the view of the city outside.",
        "cat is sitting on a window sill",
    )
    assert any(phrase in all_generations[0] for phrase in accepted_phrases)


@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_llamacpp(monkeypatch):
    """Run one CLI generation with the 'llama' base model, langchain 'Disabled'.

    ``get_llama`` supplies the prompt type and model path, and
    ``builtins.input`` is patched to return a fixed question. The single
    returned generation must contain at least one accepted phrase.
    """
    prompt_type, full_path = get_llama()

    def fake_input(_):
        # Stand-in for interactive input: always reply with the same question.
        return "Who are you?"

    monkeypatch.setattr('builtins.input', fake_input)

    from src.gen import main
    # The same mode is passed both as the active mode and as the only listed mode.
    langchain_mode = 'Disabled'
    run_options = dict(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode=langchain_mode,
        prompt_type=prompt_type,
        model_path_llama=full_path,
        user_path=None,
        langchain_modes=[langchain_mode],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    all_generations = main(**run_options)

    print(all_generations)
    assert len(all_generations) == 1

    # Any one of these phrases counts as an acceptable answer.
    accepted_phrases = (
        "I'm a software engineer with a passion for building scalable",
        "how can I assist",
        "am a virtual assistant",
        "My name is John.",
        "I am a student",
        "I'm LLaMA",
    )
    assert any(phrase in all_generations[0] for phrase in accepted_phrases)


@wrap_test_forked
def test_cli_h2ogpt(monkeypatch):
    """Run one CLI generation with the h2ogpt-oig-oasst1-512-6_9b base model.

    ``builtins.input`` is patched to return a fixed question. ``main`` must
    return exactly one generation containing one of the accepted phrases.
    """
    def fake_input(_):
        # Stand-in for interactive input: always reply with the same question.
        return "What is the Earth?"

    monkeypatch.setattr('builtins.input', fake_input)

    from src.gen import main
    all_generations = main(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
    )

    assert len(all_generations) == 1

    # Either phrase counts as an acceptable answer.
    accepted_phrases = (
        "The Earth is a planet in the Solar System",
        "The Earth is the third planet",
    )
    assert any(phrase in all_generations[0] for phrase in accepted_phrases)


@wrap_test_forked
def test_cli_langchain_h2ogpt(monkeypatch):
    """Run one CLI generation in 'UserData' mode with the h2ogpt 6.9b model.

    The user path comes from ``make_user_path_test`` and ``builtins.input`` is
    patched to return a fixed question. The single returned generation must
    name the expected image file and contain one of the accepted phrases.
    """
    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    def fake_input(_):
        # Stand-in for interactive input: always reply with the same question.
        return "What is the cat doing?"

    monkeypatch.setattr('builtins.input', fake_input)

    from src.gen import main
    run_options = dict(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=docs_dir,
        langchain_modes=['UserData', 'MyData'],
        document_subset=DocumentSubset.Relevant.name,
        verbose=True,
    )
    all_generations = main(**run_options)

    print(all_generations)
    assert len(all_generations) == 1
    assert "pexels-evg-kowalievska-1170986_small.jpg" in all_generations[0]

    # Either phrase counts as an acceptable answer.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
    )
    assert any(phrase in all_generations[0] for phrase in accepted_phrases)