Enrique Sanchez committed
Commit 969ccb5
1 parent: 0624de9

Improve testing and fix the workflow

.github/workflows/main.yml CHANGED
@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies using Poetry, run tests and lint with a single version of Python using Ruff
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
 
-name: Python application
+name: Sentiment Analysis for Voice
 
 on:
   push:
@@ -26,7 +26,7 @@ jobs:
         ollama pull llama2
     - uses: actions/checkout@v4
     - name: Set up Python 3.10
-      uses: actions/setup-python@v3
+      uses: actions/setup-python@v4
       with:
         python-version: "3.10"
     - name: Install Poetry
@@ -36,17 +36,9 @@ jobs:
     - name: Install dependencies with Poetry
       run: |
         poetry install
-  lint:
-    needs: build
-    runs-on: ubuntu-latest
-    steps:
     - name: Lint with ruff
       run: |
-        poetry run ruff check
+        poetry run ruff --output-format=github .
-  test:
-    needs: lint
-    runs-on: ubuntu-latest
-    steps:
     - name: Test with pytest
       run: |
-        poetry run pytest --reruns 1 --reruns-delay 1
+        poetry run pytest --reruns 1 --reruns-delay 1 --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html
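Note: the pytest step above now writes a junit report and coverage output inside the single build job. As a minimal illustration (not part of the commit), the junit file could be summarised locally as sketched below; the junit/test-results.xml path comes from the workflow flags, everything else assumes pytest's standard junit attribute names.

    # Hedged sketch: summarise the junit report written by the CI pytest step.
    # Assumes the junit/test-results.xml path used in the workflow above and the
    # standard attributes pytest emits (tests, failures, errors, skipped).
    import xml.etree.ElementTree as ET

    root = ET.parse("junit/test-results.xml").getroot()
    for suite in root.iter("testsuite"):
        print(
            f"{suite.get('name', 'pytest')}: {suite.get('tests')} tests, "
            f"{suite.get('failures')} failures, {suite.get('errors')} errors, "
            f"{suite.get('skipped')} skipped"
        )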
poetry.lock CHANGED
@@ -487,6 +487,73 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.6.1)", "types-Pill
 test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
 test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
 
+[[package]]
+name = "coverage"
+version = "7.4.1"
+description = "Code coverage measurement for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "coverage-7.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:077d366e724f24fc02dbfe9d946534357fda71af9764ff99d73c3c596001bbd7"},
+    {file = "coverage-7.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0193657651f5399d433c92f8ae264aff31fc1d066deee4b831549526433f3f61"},
+    {file = "coverage-7.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d17bbc946f52ca67adf72a5ee783cd7cd3477f8f8796f59b4974a9b59cacc9ee"},
+    {file = "coverage-7.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3277f5fa7483c927fe3a7b017b39351610265308f5267ac6d4c2b64cc1d8d25"},
+    {file = "coverage-7.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dceb61d40cbfcf45f51e59933c784a50846dc03211054bd76b421a713dcdf19"},
+    {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6008adeca04a445ea6ef31b2cbaf1d01d02986047606f7da266629afee982630"},
+    {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c61f66d93d712f6e03369b6a7769233bfda880b12f417eefdd4f16d1deb2fc4c"},
+    {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b9bb62fac84d5f2ff523304e59e5c439955fb3b7f44e3d7b2085184db74d733b"},
+    {file = "coverage-7.4.1-cp310-cp310-win32.whl", hash = "sha256:f86f368e1c7ce897bf2457b9eb61169a44e2ef797099fb5728482b8d69f3f016"},
+    {file = "coverage-7.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:869b5046d41abfea3e381dd143407b0d29b8282a904a19cb908fa24d090cc018"},
+    {file = "coverage-7.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b8ffb498a83d7e0305968289441914154fb0ef5d8b3157df02a90c6695978295"},
+    {file = "coverage-7.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3cacfaefe6089d477264001f90f55b7881ba615953414999c46cc9713ff93c8c"},
+    {file = "coverage-7.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d6850e6e36e332d5511a48a251790ddc545e16e8beaf046c03985c69ccb2676"},
+    {file = "coverage-7.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e961aa13b6d47f758cc5879383d27b5b3f3dcd9ce8cdbfdc2571fe86feb4dd"},
+    {file = "coverage-7.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dfd1e1b9f0898817babf840b77ce9fe655ecbe8b1b327983df485b30df8cc011"},
+    {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6b00e21f86598b6330f0019b40fb397e705135040dbedc2ca9a93c7441178e74"},
+    {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:536d609c6963c50055bab766d9951b6c394759190d03311f3e9fcf194ca909e1"},
+    {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7ac8f8eb153724f84885a1374999b7e45734bf93a87d8df1e7ce2146860edef6"},
+    {file = "coverage-7.4.1-cp311-cp311-win32.whl", hash = "sha256:f3771b23bb3675a06f5d885c3630b1d01ea6cac9e84a01aaf5508706dba546c5"},
+    {file = "coverage-7.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:9d2f9d4cc2a53b38cabc2d6d80f7f9b7e3da26b2f53d48f05876fef7956b6968"},
+    {file = "coverage-7.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f68ef3660677e6624c8cace943e4765545f8191313a07288a53d3da188bd8581"},
+    {file = "coverage-7.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23b27b8a698e749b61809fb637eb98ebf0e505710ec46a8aa6f1be7dc0dc43a6"},
+    {file = "coverage-7.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3424c554391dc9ef4a92ad28665756566a28fecf47308f91841f6c49288e66"},
+    {file = "coverage-7.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e0860a348bf7004c812c8368d1fc7f77fe8e4c095d661a579196a9533778e156"},
+    {file = "coverage-7.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe558371c1bdf3b8fa03e097c523fb9645b8730399c14fe7721ee9c9e2a545d3"},
+    {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3468cc8720402af37b6c6e7e2a9cdb9f6c16c728638a2ebc768ba1ef6f26c3a1"},
+    {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:02f2edb575d62172aa28fe00efe821ae31f25dc3d589055b3fb64d51e52e4ab1"},
+    {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ca6e61dc52f601d1d224526360cdeab0d0712ec104a2ce6cc5ccef6ed9a233bc"},
+    {file = "coverage-7.4.1-cp312-cp312-win32.whl", hash = "sha256:ca7b26a5e456a843b9b6683eada193fc1f65c761b3a473941efe5a291f604c74"},
+    {file = "coverage-7.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:85ccc5fa54c2ed64bd91ed3b4a627b9cce04646a659512a051fa82a92c04a448"},
+    {file = "coverage-7.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8bdb0285a0202888d19ec6b6d23d5990410decb932b709f2b0dfe216d031d218"},
+    {file = "coverage-7.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:918440dea04521f499721c039863ef95433314b1db00ff826a02580c1f503e45"},
+    {file = "coverage-7.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:379d4c7abad5afbe9d88cc31ea8ca262296480a86af945b08214eb1a556a3e4d"},
+    {file = "coverage-7.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b094116f0b6155e36a304ff912f89bbb5067157aff5f94060ff20bbabdc8da06"},
+    {file = "coverage-7.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2f5968608b1fe2a1d00d01ad1017ee27efd99b3437e08b83ded9b7af3f6f766"},
+    {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:10e88e7f41e6197ea0429ae18f21ff521d4f4490aa33048f6c6f94c6045a6a75"},
+    {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a4a3907011d39dbc3e37bdc5df0a8c93853c369039b59efa33a7b6669de04c60"},
+    {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d224f0c4c9c98290a6990259073f496fcec1b5cc613eecbd22786d398ded3ad"},
+    {file = "coverage-7.4.1-cp38-cp38-win32.whl", hash = "sha256:23f5881362dcb0e1a92b84b3c2809bdc90db892332daab81ad8f642d8ed55042"},
+    {file = "coverage-7.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:a07f61fc452c43cd5328b392e52555f7d1952400a1ad09086c4a8addccbd138d"},
+    {file = "coverage-7.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8e738a492b6221f8dcf281b67129510835461132b03024830ac0e554311a5c54"},
+    {file = "coverage-7.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:46342fed0fff72efcda77040b14728049200cbba1279e0bf1188f1f2078c1d70"},
+    {file = "coverage-7.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9641e21670c68c7e57d2053ddf6c443e4f0a6e18e547e86af3fad0795414a628"},
+    {file = "coverage-7.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aeb2c2688ed93b027eb0d26aa188ada34acb22dceea256d76390eea135083950"},
+    {file = "coverage-7.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d12c923757de24e4e2110cf8832d83a886a4cf215c6e61ed506006872b43a6d1"},
+    {file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0491275c3b9971cdbd28a4595c2cb5838f08036bca31765bad5e17edf900b2c7"},
+    {file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8dfc5e195bbef80aabd81596ef52a1277ee7143fe419efc3c4d8ba2754671756"},
+    {file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1a78b656a4d12b0490ca72651fe4d9f5e07e3c6461063a9b6265ee45eb2bdd35"},
+    {file = "coverage-7.4.1-cp39-cp39-win32.whl", hash = "sha256:f90515974b39f4dea2f27c0959688621b46d96d5a626cf9c53dbc653a895c05c"},
+    {file = "coverage-7.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:64e723ca82a84053dd7bfcc986bdb34af8d9da83c521c19d6b472bc6880e191a"},
+    {file = "coverage-7.4.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:32a8d985462e37cfdab611a6f95b09d7c091d07668fdc26e47a725ee575fe166"},
+    {file = "coverage-7.4.1.tar.gz", hash = "sha256:1ed4b95480952b1a26d863e546fa5094564aa0065e1e5f0d4d0041f293251d04"},
+]
+
+[package.dependencies]
+tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
+
+[package.extras]
+toml = ["tomli"]
+
 [[package]]
 name = "ctranslate2"
 version = "3.24.0"
@@ -2461,6 +2528,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
 [package.extras]
 testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
 
+[[package]]
+name = "pytest-cov"
+version = "4.1.0"
+description = "Pytest plugin for measuring coverage."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
+    {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
+]
+
+[package.dependencies]
+coverage = {version = ">=5.2.1", extras = ["toml"]}
+pytest = ">=4.6"
+
+[package.extras]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
+
 [[package]]
 name = "pytest-rerunfailures"
 version = "13.0"
@@ -3819,4 +3904,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.12"
-content-hash = "5a7bfcc29d469fc3c5a947a9d0d2679d925d032a8474a5cc09a4ea258f4c6c6b"
+content-hash = "29545ccc27b7856e223d0565b7fb4f22a4ff6cd3f8b1960b83e3875ab100a871"
pyproject.toml CHANGED
@@ -22,6 +22,7 @@ pytest-rerunfailures = "^13.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.0.0"
+pytest-cov = "^4.1.0"
 
 [build-system]
 requires = ["poetry-core"]
src/summarization.py CHANGED
@@ -13,7 +13,7 @@ def topics_for_text(file_conv: str, llm: str = "mistral") -> str:
     Returns:
         str: The topics for the text.
     """
-    prompt_template = """The next text is a conversation between two people about a topic. Your task is to summarize the conversation in only a list of words that describe the conversation. The list of words should be separated by a comma. The conversation is the following:
+    prompt_template = """The next text is a conversation between two people about a topic. Your task is to summarize the conversation in only a list of words that describe the conversation. Keep the list short (max 5 items) and every topic has to be described by one word. The list of words should be separated by a comma. The conversation is the following:
     "{text}"
     TOPICS:"""
 
@@ -21,10 +21,36 @@ def topics_for_text(file_conv: str, llm: str = "mistral") -> str:
     llm = Ollama(model=llm)
 
     loader = TextLoader(file_conv)
-    docs = loader.load()
+    doc = loader.load()
     # Define StuffDocumentsChain
     chain = load_summarize_chain(
         llm, chain_type="stuff", prompt=prompt, input_key="text"
     )
-    result = chain({"text": docs}, return_only_outputs=True)
+    result = chain.invoke({"text": doc}, return_only_outputs=True)
     return [element.strip().lower() for element in result["output_text"].split(", ")]
+
+
+def summarize(file_conv: str, llm: str = "mistral") -> str:
+    """Summarize a conversation.
+
+    Args:
+        file_conv (str): The file with the conversation.
+
+    Returns:
+        str: The summary of the conversation.
+    """
+    prompt_template = """The next text is a conversation between two people about a topic. Your task is to summarize the conversation in one sentence. The conversation is the following:
+    "{text}"
+    SUMMARY:"""
+
+    prompt = PromptTemplate.from_template(prompt_template)
+    llm = Ollama(model=llm)
+
+    loader = TextLoader(file_conv)
+    doc = loader.load()
+    # Define StuffDocumentsChain
+    chain = load_summarize_chain(
+        llm, chain_type="stuff", prompt=prompt, input_key="text"
+    )
+    result = chain.invoke({"text": doc}, return_only_outputs=True)
+    return result["output_text"]
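Note: after this change the module exposes two entry points, topics_for_text and summarize. A minimal usage sketch (not part of the commit; it assumes a local Ollama server with the default "mistral" model pulled, and "conversation.txt" is a hypothetical plain-text transcript):

    # Hedged usage sketch for the updated src/summarization.py.
    from src.summarization import summarize, topics_for_text

    transcript = "conversation.txt"  # hypothetical transcript file

    topics = topics_for_text(transcript)  # list of at most 5 lower-cased one-word topics
    summary = summarize(transcript)       # single-sentence summary string

    print(topics)
    print(summary)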
tests/test_sentiments_and_topics.py DELETED
@@ -1,34 +0,0 @@
-import os
-import tempfile
-
-import pytest
-
-from src.generate_conversation import generate_conversation
-from src.sentiment import analyze_sentiment
-from src.summarization import topics_for_text
-
-SENTIMENTS = ["neutral", "sadness", "neutral", "joy"]
-TOPICS = ["ai", "soccer", "art", "science"]
-
-
-@pytest.mark.parametrize("sentiment", SENTIMENTS)
-@pytest.mark.parametrize("topic", TOPICS)
-def test_generate_conversation(sentiment, topic):
-    # Call the function
-    conversation = generate_conversation(topic, sentiment, llm="llama2")
-
-    new_sentiment = analyze_sentiment(conversation)
-    # Assert that the conversation has a positive sentiment
-
-    assert (
-        sentiment in s for s in new_sentiment
-    ), "Sentiment is not in the index of new_sentiment"
-
-    # Save the conversation to a temporary text file
-    with tempfile.NamedTemporaryFile(mode="w", delete=False) as file:
-        file.write(conversation)
-        temp_filepath = file.name
-
-    new_topics = topics_for_text(temp_filepath)
-    os.remove(temp_filepath)
-    assert any(topic in word for word in new_topics)
tests/test_summarization.py ADDED
@@ -0,0 +1,69 @@
+import os
+import tempfile
+
+import pytest
+from langchain_community.llms import Ollama
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.prompts import PromptTemplate
+
+from src.generate_conversation import generate_conversation
+from src.summarization import summarize
+from src.summarization import topics_for_text
+
+# SENTIMENTS = ["neutral", "sadness", "joy"]
+# TOPICS = ["ai", "soccer", "art"]
+SENTIMENTS = ["joy"]
+TOPICS = ["ai"]
+
+
+@pytest.mark.parametrize("sentiment", SENTIMENTS)
+@pytest.mark.parametrize("topic", TOPICS)
+def test_topics_conversation(sentiment, topic):
+    # Call the function
+    conversation = generate_conversation(topic, sentiment, llm="llama2")
+
+    # Save the conversation to a temporary text file
+    with tempfile.NamedTemporaryFile(mode="w", delete=False) as file:
+        file.write(conversation)
+        temp_filepath = file.name
+
+    new_topics = topics_for_text(temp_filepath)
+    os.remove(temp_filepath)
+    assert any(topic in word for word in new_topics)
+
+
+@pytest.mark.parametrize("sentiment", SENTIMENTS)
+@pytest.mark.parametrize("topic", TOPICS)
+def test_summary_conversation(sentiment, topic):
+    # Call the function
+    conversation = generate_conversation(topic, sentiment, llm="llama2")
+
+    # Save the conversation to a temporary text file
+    with tempfile.NamedTemporaryFile(mode="w", delete=False) as file:
+        file.write(conversation)
+        temp_filepath = file.name
+
+    summary = summarize(temp_filepath)
+    os.remove(temp_filepath)
+
+    assert summary != ""
+
+    model = Ollama(model="llama2")
+    prompt = PromptTemplate(
+        template="""You will read a summary of a conversation with a sentiment and a topic. Your task is to analyze the conversation and the summary and returns a json object where the key is summary, topic, sentiment and the value is True if the sentiment and the topic are correct and False otherwise. The conversation is the following: {conversation} The summary is the following: {summary}, the topic is {topic} and the sentiment is the following: {sentiment}
+        JSON:""",
+        input_variables=["conversation", "summary", "topic", "sentiment"],
+    )
+    output_parser = JsonOutputParser()
+    chain = prompt | model | output_parser
+    result = chain.invoke(
+        {
+            "conversation": conversation,
+            "topic": topic,
+            "summary": summary,
+            "sentiment": sentiment,
+        }
+    )
+    assert result["summary"], "The summary is not correct"
+    assert result["topic"], "The topic is not correct"
+    assert result["sentiment"], "The sentiment is not correct"
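Note: to run only this new module locally the way the workflow does, a hedged sketch using pytest's Python API (assumes the dev dependencies are installed and an Ollama server with the "llama2" model pulled is running):

    # Hedged sketch: run just tests/test_summarization.py with the same retry
    # flags the CI step uses (provided by pytest-rerunfailures).
    import pytest

    pytest.main(["tests/test_summarization.py", "--reruns", "1", "--reruns-delay", "1"])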