Spaces:
Configuration error
Configuration error
Enrique Sanchez
committed on
Commit
•
969ccb5
1
Parent(s):
0624de9
improve testing and fixes for the workflow
Browse files- .github/workflows/main.yml +4 -12
- poetry.lock +86 -1
- pyproject.toml +1 -0
- src/summarization.py +29 -3
- tests/test_sentiments_and_topics.py +0 -34
- tests/test_summarization.py +69 -0
.github/workflows/main.yml
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# This workflow will install Python dependencies using Poetry, run tests and lint with a single version of Python using Ruff
|
2 |
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
3 |
|
4 |
-
name:
|
5 |
|
6 |
on:
|
7 |
push:
|
@@ -26,7 +26,7 @@ jobs:
|
|
26 |
ollama pull llama2
|
27 |
- uses: actions/checkout@v4
|
28 |
- name: Set up Python 3.10
|
29 |
-
uses: actions/setup-python@
|
30 |
with:
|
31 |
python-version: "3.10"
|
32 |
- name: Install Poetry
|
@@ -36,17 +36,9 @@ jobs:
|
|
36 |
- name: Install dependencies with Poetry
|
37 |
run: |
|
38 |
poetry install
|
39 |
-
lint:
|
40 |
-
needs: build
|
41 |
-
runs-on: ubuntu-latest
|
42 |
-
steps:
|
43 |
- name: Lint with ruff
|
44 |
run: |
|
45 |
-
poetry run ruff
|
46 |
-
test:
|
47 |
-
needs: lint
|
48 |
-
runs-on: ubuntu-latest
|
49 |
-
steps:
|
50 |
- name: Test with pytest
|
51 |
run: |
|
52 |
-
poetry run pytest --reruns 1 --reruns-delay 1
|
|
|
1 |
# This workflow will install Python dependencies using Poetry, run tests and lint with a single version of Python using Ruff
|
2 |
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
3 |
|
4 |
+
name: Sentiment Analysis for Voice
|
5 |
|
6 |
on:
|
7 |
push:
|
|
|
26 |
ollama pull llama2
|
27 |
- uses: actions/checkout@v4
|
28 |
- name: Set up Python 3.10
|
29 |
+
uses: actions/setup-python@v4
|
30 |
with:
|
31 |
python-version: "3.10"
|
32 |
- name: Install Poetry
|
|
|
36 |
- name: Install dependencies with Poetry
|
37 |
run: |
|
38 |
poetry install
|
|
|
|
|
|
|
|
|
39 |
- name: Lint with ruff
|
40 |
run: |
|
41 |
+
poetry run ruff --output-format=github .
|
|
|
|
|
|
|
|
|
42 |
- name: Test with pytest
|
43 |
run: |
|
44 |
+
poetry run pytest --reruns 1 --reruns-delay 1 --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html
|
poetry.lock
CHANGED
@@ -487,6 +487,73 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.6.1)", "types-Pill
|
|
487 |
test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
|
488 |
test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
|
489 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
490 |
[[package]]
|
491 |
name = "ctranslate2"
|
492 |
version = "3.24.0"
|
@@ -2461,6 +2528,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
|
2461 |
[package.extras]
|
2462 |
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
2463 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2464 |
[[package]]
|
2465 |
name = "pytest-rerunfailures"
|
2466 |
version = "13.0"
|
@@ -3819,4 +3904,4 @@ multidict = ">=4.0"
|
|
3819 |
[metadata]
|
3820 |
lock-version = "2.0"
|
3821 |
python-versions = ">=3.10,<3.12"
|
3822 |
-
content-hash = "
|
|
|
487 |
test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
|
488 |
test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
|
489 |
|
490 |
+
[[package]]
|
491 |
+
name = "coverage"
|
492 |
+
version = "7.4.1"
|
493 |
+
description = "Code coverage measurement for Python"
|
494 |
+
optional = false
|
495 |
+
python-versions = ">=3.8"
|
496 |
+
files = [
|
497 |
+
{file = "coverage-7.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:077d366e724f24fc02dbfe9d946534357fda71af9764ff99d73c3c596001bbd7"},
|
498 |
+
{file = "coverage-7.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0193657651f5399d433c92f8ae264aff31fc1d066deee4b831549526433f3f61"},
|
499 |
+
{file = "coverage-7.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d17bbc946f52ca67adf72a5ee783cd7cd3477f8f8796f59b4974a9b59cacc9ee"},
|
500 |
+
{file = "coverage-7.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3277f5fa7483c927fe3a7b017b39351610265308f5267ac6d4c2b64cc1d8d25"},
|
501 |
+
{file = "coverage-7.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dceb61d40cbfcf45f51e59933c784a50846dc03211054bd76b421a713dcdf19"},
|
502 |
+
{file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6008adeca04a445ea6ef31b2cbaf1d01d02986047606f7da266629afee982630"},
|
503 |
+
{file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c61f66d93d712f6e03369b6a7769233bfda880b12f417eefdd4f16d1deb2fc4c"},
|
504 |
+
{file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b9bb62fac84d5f2ff523304e59e5c439955fb3b7f44e3d7b2085184db74d733b"},
|
505 |
+
{file = "coverage-7.4.1-cp310-cp310-win32.whl", hash = "sha256:f86f368e1c7ce897bf2457b9eb61169a44e2ef797099fb5728482b8d69f3f016"},
|
506 |
+
{file = "coverage-7.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:869b5046d41abfea3e381dd143407b0d29b8282a904a19cb908fa24d090cc018"},
|
507 |
+
{file = "coverage-7.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b8ffb498a83d7e0305968289441914154fb0ef5d8b3157df02a90c6695978295"},
|
508 |
+
{file = "coverage-7.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3cacfaefe6089d477264001f90f55b7881ba615953414999c46cc9713ff93c8c"},
|
509 |
+
{file = "coverage-7.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d6850e6e36e332d5511a48a251790ddc545e16e8beaf046c03985c69ccb2676"},
|
510 |
+
{file = "coverage-7.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e961aa13b6d47f758cc5879383d27b5b3f3dcd9ce8cdbfdc2571fe86feb4dd"},
|
511 |
+
{file = "coverage-7.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dfd1e1b9f0898817babf840b77ce9fe655ecbe8b1b327983df485b30df8cc011"},
|
512 |
+
{file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6b00e21f86598b6330f0019b40fb397e705135040dbedc2ca9a93c7441178e74"},
|
513 |
+
{file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:536d609c6963c50055bab766d9951b6c394759190d03311f3e9fcf194ca909e1"},
|
514 |
+
{file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7ac8f8eb153724f84885a1374999b7e45734bf93a87d8df1e7ce2146860edef6"},
|
515 |
+
{file = "coverage-7.4.1-cp311-cp311-win32.whl", hash = "sha256:f3771b23bb3675a06f5d885c3630b1d01ea6cac9e84a01aaf5508706dba546c5"},
|
516 |
+
{file = "coverage-7.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:9d2f9d4cc2a53b38cabc2d6d80f7f9b7e3da26b2f53d48f05876fef7956b6968"},
|
517 |
+
{file = "coverage-7.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f68ef3660677e6624c8cace943e4765545f8191313a07288a53d3da188bd8581"},
|
518 |
+
{file = "coverage-7.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23b27b8a698e749b61809fb637eb98ebf0e505710ec46a8aa6f1be7dc0dc43a6"},
|
519 |
+
{file = "coverage-7.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3424c554391dc9ef4a92ad28665756566a28fecf47308f91841f6c49288e66"},
|
520 |
+
{file = "coverage-7.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e0860a348bf7004c812c8368d1fc7f77fe8e4c095d661a579196a9533778e156"},
|
521 |
+
{file = "coverage-7.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe558371c1bdf3b8fa03e097c523fb9645b8730399c14fe7721ee9c9e2a545d3"},
|
522 |
+
{file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3468cc8720402af37b6c6e7e2a9cdb9f6c16c728638a2ebc768ba1ef6f26c3a1"},
|
523 |
+
{file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:02f2edb575d62172aa28fe00efe821ae31f25dc3d589055b3fb64d51e52e4ab1"},
|
524 |
+
{file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ca6e61dc52f601d1d224526360cdeab0d0712ec104a2ce6cc5ccef6ed9a233bc"},
|
525 |
+
{file = "coverage-7.4.1-cp312-cp312-win32.whl", hash = "sha256:ca7b26a5e456a843b9b6683eada193fc1f65c761b3a473941efe5a291f604c74"},
|
526 |
+
{file = "coverage-7.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:85ccc5fa54c2ed64bd91ed3b4a627b9cce04646a659512a051fa82a92c04a448"},
|
527 |
+
{file = "coverage-7.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8bdb0285a0202888d19ec6b6d23d5990410decb932b709f2b0dfe216d031d218"},
|
528 |
+
{file = "coverage-7.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:918440dea04521f499721c039863ef95433314b1db00ff826a02580c1f503e45"},
|
529 |
+
{file = "coverage-7.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:379d4c7abad5afbe9d88cc31ea8ca262296480a86af945b08214eb1a556a3e4d"},
|
530 |
+
{file = "coverage-7.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b094116f0b6155e36a304ff912f89bbb5067157aff5f94060ff20bbabdc8da06"},
|
531 |
+
{file = "coverage-7.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2f5968608b1fe2a1d00d01ad1017ee27efd99b3437e08b83ded9b7af3f6f766"},
|
532 |
+
{file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:10e88e7f41e6197ea0429ae18f21ff521d4f4490aa33048f6c6f94c6045a6a75"},
|
533 |
+
{file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a4a3907011d39dbc3e37bdc5df0a8c93853c369039b59efa33a7b6669de04c60"},
|
534 |
+
{file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d224f0c4c9c98290a6990259073f496fcec1b5cc613eecbd22786d398ded3ad"},
|
535 |
+
{file = "coverage-7.4.1-cp38-cp38-win32.whl", hash = "sha256:23f5881362dcb0e1a92b84b3c2809bdc90db892332daab81ad8f642d8ed55042"},
|
536 |
+
{file = "coverage-7.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:a07f61fc452c43cd5328b392e52555f7d1952400a1ad09086c4a8addccbd138d"},
|
537 |
+
{file = "coverage-7.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8e738a492b6221f8dcf281b67129510835461132b03024830ac0e554311a5c54"},
|
538 |
+
{file = "coverage-7.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:46342fed0fff72efcda77040b14728049200cbba1279e0bf1188f1f2078c1d70"},
|
539 |
+
{file = "coverage-7.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9641e21670c68c7e57d2053ddf6c443e4f0a6e18e547e86af3fad0795414a628"},
|
540 |
+
{file = "coverage-7.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aeb2c2688ed93b027eb0d26aa188ada34acb22dceea256d76390eea135083950"},
|
541 |
+
{file = "coverage-7.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d12c923757de24e4e2110cf8832d83a886a4cf215c6e61ed506006872b43a6d1"},
|
542 |
+
{file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0491275c3b9971cdbd28a4595c2cb5838f08036bca31765bad5e17edf900b2c7"},
|
543 |
+
{file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8dfc5e195bbef80aabd81596ef52a1277ee7143fe419efc3c4d8ba2754671756"},
|
544 |
+
{file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1a78b656a4d12b0490ca72651fe4d9f5e07e3c6461063a9b6265ee45eb2bdd35"},
|
545 |
+
{file = "coverage-7.4.1-cp39-cp39-win32.whl", hash = "sha256:f90515974b39f4dea2f27c0959688621b46d96d5a626cf9c53dbc653a895c05c"},
|
546 |
+
{file = "coverage-7.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:64e723ca82a84053dd7bfcc986bdb34af8d9da83c521c19d6b472bc6880e191a"},
|
547 |
+
{file = "coverage-7.4.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:32a8d985462e37cfdab611a6f95b09d7c091d07668fdc26e47a725ee575fe166"},
|
548 |
+
{file = "coverage-7.4.1.tar.gz", hash = "sha256:1ed4b95480952b1a26d863e546fa5094564aa0065e1e5f0d4d0041f293251d04"},
|
549 |
+
]
|
550 |
+
|
551 |
+
[package.dependencies]
|
552 |
+
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
|
553 |
+
|
554 |
+
[package.extras]
|
555 |
+
toml = ["tomli"]
|
556 |
+
|
557 |
[[package]]
|
558 |
name = "ctranslate2"
|
559 |
version = "3.24.0"
|
|
|
2528 |
[package.extras]
|
2529 |
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
2530 |
|
2531 |
+
[[package]]
|
2532 |
+
name = "pytest-cov"
|
2533 |
+
version = "4.1.0"
|
2534 |
+
description = "Pytest plugin for measuring coverage."
|
2535 |
+
optional = false
|
2536 |
+
python-versions = ">=3.7"
|
2537 |
+
files = [
|
2538 |
+
{file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
|
2539 |
+
{file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
|
2540 |
+
]
|
2541 |
+
|
2542 |
+
[package.dependencies]
|
2543 |
+
coverage = {version = ">=5.2.1", extras = ["toml"]}
|
2544 |
+
pytest = ">=4.6"
|
2545 |
+
|
2546 |
+
[package.extras]
|
2547 |
+
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
|
2548 |
+
|
2549 |
[[package]]
|
2550 |
name = "pytest-rerunfailures"
|
2551 |
version = "13.0"
|
|
|
3904 |
[metadata]
|
3905 |
lock-version = "2.0"
|
3906 |
python-versions = ">=3.10,<3.12"
|
3907 |
+
content-hash = "29545ccc27b7856e223d0565b7fb4f22a4ff6cd3f8b1960b83e3875ab100a871"
|
pyproject.toml
CHANGED
@@ -22,6 +22,7 @@ pytest-rerunfailures = "^13.0"
|
|
22 |
|
23 |
[tool.poetry.group.dev.dependencies]
|
24 |
pytest = "^8.0.0"
|
|
|
25 |
|
26 |
[build-system]
|
27 |
requires = ["poetry-core"]
|
|
|
22 |
|
23 |
[tool.poetry.group.dev.dependencies]
|
24 |
pytest = "^8.0.0"
|
25 |
+
pytest-cov = "^4.1.0"
|
26 |
|
27 |
[build-system]
|
28 |
requires = ["poetry-core"]
|
src/summarization.py
CHANGED
@@ -13,7 +13,7 @@ def topics_for_text(file_conv: str, llm: str = "mistral") -> str:
|
|
13 |
Returns:
|
14 |
str: The topics for the text.
|
15 |
"""
|
16 |
-
prompt_template = """The next text is a conversation between two people about a topic. Your task is to summarize the conversation in only a list of words that describe the conversation. The list of words should be separated by a comma. The conversation is the following:
|
17 |
"{text}"
|
18 |
TOPICS:"""
|
19 |
|
@@ -21,10 +21,36 @@ def topics_for_text(file_conv: str, llm: str = "mistral") -> str:
|
|
21 |
llm = Ollama(model=llm)
|
22 |
|
23 |
loader = TextLoader(file_conv)
|
24 |
-
|
25 |
# Define StuffDocumentsChain
|
26 |
chain = load_summarize_chain(
|
27 |
llm, chain_type="stuff", prompt=prompt, input_key="text"
|
28 |
)
|
29 |
-
result = chain({"text":
|
30 |
return [element.strip().lower() for element in result["output_text"].split(", ")]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
Returns:
|
14 |
str: The topics for the text.
|
15 |
"""
|
16 |
+
prompt_template = """The next text is a conversation between two people about a topic. Your task is to summarize the conversation in only a list of words that describe the conversation. Keep the list short (max 5 items) and every topic has to be described by one word. The list of words should be separated by a comma. The conversation is the following:
|
17 |
"{text}"
|
18 |
TOPICS:"""
|
19 |
|
|
|
21 |
llm = Ollama(model=llm)
|
22 |
|
23 |
loader = TextLoader(file_conv)
|
24 |
+
doc = loader.load()
|
25 |
# Define StuffDocumentsChain
|
26 |
chain = load_summarize_chain(
|
27 |
llm, chain_type="stuff", prompt=prompt, input_key="text"
|
28 |
)
|
29 |
+
result = chain.invoke({"text": doc}, return_only_outputs=True)
|
30 |
return [element.strip().lower() for element in result["output_text"].split(", ")]
|
31 |
+
|
32 |
+
|
33 |
+
def summarize(file_conv: str, llm: str = "mistral") -> str:
|
34 |
+
"""Summarize a conversation.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
file_conv (str): The file with the conversation.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
str: The summary of the conversation.
|
41 |
+
"""
|
42 |
+
prompt_template = """The next text is a conversation between two people about a topic. Your task is to summarize the conversation in one sentence. The conversation is the following:
|
43 |
+
"{text}"
|
44 |
+
SUMMARY:"""
|
45 |
+
|
46 |
+
prompt = PromptTemplate.from_template(prompt_template)
|
47 |
+
llm = Ollama(model=llm)
|
48 |
+
|
49 |
+
loader = TextLoader(file_conv)
|
50 |
+
doc = loader.load()
|
51 |
+
# Define StuffDocumentsChain
|
52 |
+
chain = load_summarize_chain(
|
53 |
+
llm, chain_type="stuff", prompt=prompt, input_key="text"
|
54 |
+
)
|
55 |
+
result = chain.invoke({"text": doc}, return_only_outputs=True)
|
56 |
+
return result["output_text"]
|
tests/test_sentiments_and_topics.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import tempfile
|
3 |
-
|
4 |
-
import pytest
|
5 |
-
|
6 |
-
from src.generate_conversation import generate_conversation
|
7 |
-
from src.sentiment import analyze_sentiment
|
8 |
-
from src.summarization import topics_for_text
|
9 |
-
|
10 |
-
SENTIMENTS = ["neutral", "sadness", "neutral", "joy"]
|
11 |
-
TOPICS = ["ai", "soccer", "art", "science"]
|
12 |
-
|
13 |
-
|
14 |
-
@pytest.mark.parametrize("sentiment", SENTIMENTS)
|
15 |
-
@pytest.mark.parametrize("topic", TOPICS)
|
16 |
-
def test_generate_conversation(sentiment, topic):
|
17 |
-
# Call the function
|
18 |
-
conversation = generate_conversation(topic, sentiment, llm="llama2")
|
19 |
-
|
20 |
-
new_sentiment = analyze_sentiment(conversation)
|
21 |
-
# Assert that the conversation has a positive sentiment
|
22 |
-
|
23 |
-
assert (
|
24 |
-
sentiment in s for s in new_sentiment
|
25 |
-
), "Sentiment is not in the index of new_sentiment"
|
26 |
-
|
27 |
-
# Save the conversation to a temporary text file
|
28 |
-
with tempfile.NamedTemporaryFile(mode="w", delete=False) as file:
|
29 |
-
file.write(conversation)
|
30 |
-
temp_filepath = file.name
|
31 |
-
|
32 |
-
new_topics = topics_for_text(temp_filepath)
|
33 |
-
os.remove(temp_filepath)
|
34 |
-
assert any(topic in word for word in new_topics)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_summarization.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
|
4 |
+
import pytest
|
5 |
+
from langchain_community.llms import Ollama
|
6 |
+
from langchain_core.output_parsers import JsonOutputParser
|
7 |
+
from langchain_core.prompts import PromptTemplate
|
8 |
+
|
9 |
+
from src.generate_conversation import generate_conversation
|
10 |
+
from src.summarization import summarize
|
11 |
+
from src.summarization import topics_for_text
|
12 |
+
|
13 |
+
# SENTIMENTS = ["neutral", "sadness", "joy"]
|
14 |
+
# TOPICS = ["ai", "soccer", "art"]
|
15 |
+
SENTIMENTS = ["joy"]
|
16 |
+
TOPICS = ["ai"]
|
17 |
+
|
18 |
+
|
19 |
+
@pytest.mark.parametrize("sentiment", SENTIMENTS)
|
20 |
+
@pytest.mark.parametrize("topic", TOPICS)
|
21 |
+
def test_topics_conversation(sentiment, topic):
|
22 |
+
# Call the function
|
23 |
+
conversation = generate_conversation(topic, sentiment, llm="llama2")
|
24 |
+
|
25 |
+
# Save the conversation to a temporary text file
|
26 |
+
with tempfile.NamedTemporaryFile(mode="w", delete=False) as file:
|
27 |
+
file.write(conversation)
|
28 |
+
temp_filepath = file.name
|
29 |
+
|
30 |
+
new_topics = topics_for_text(temp_filepath)
|
31 |
+
os.remove(temp_filepath)
|
32 |
+
assert any(topic in word for word in new_topics)
|
33 |
+
|
34 |
+
|
35 |
+
@pytest.mark.parametrize("sentiment", SENTIMENTS)
|
36 |
+
@pytest.mark.parametrize("topic", TOPICS)
|
37 |
+
def test_summary_conversation(sentiment, topic):
|
38 |
+
# Call the function
|
39 |
+
conversation = generate_conversation(topic, sentiment, llm="llama2")
|
40 |
+
|
41 |
+
# Save the conversation to a temporary text file
|
42 |
+
with tempfile.NamedTemporaryFile(mode="w", delete=False) as file:
|
43 |
+
file.write(conversation)
|
44 |
+
temp_filepath = file.name
|
45 |
+
|
46 |
+
summary = summarize(temp_filepath)
|
47 |
+
os.remove(temp_filepath)
|
48 |
+
|
49 |
+
assert summary != ""
|
50 |
+
|
51 |
+
model = Ollama(model="llama2")
|
52 |
+
prompt = PromptTemplate(
|
53 |
+
template="""You will read a summary of a conversation with a sentiment and a topic. Your task is to analyze the conversation and the summary and returns a json object where the key is summary, topic, sentiment and the value is True if the sentiment and the topic are correct and False otherwise. The conversation is the following: {conversation} The summary is the following: {summary}, the topic is {topic} and the sentiment is the following: {sentiment}
|
54 |
+
JSON:""",
|
55 |
+
input_variables=["conversation", "summary", "topic", "sentiment"],
|
56 |
+
)
|
57 |
+
output_parser = JsonOutputParser()
|
58 |
+
chain = prompt | model | output_parser
|
59 |
+
result = chain.invoke(
|
60 |
+
{
|
61 |
+
"conversation": conversation,
|
62 |
+
"topic": topic,
|
63 |
+
"summary": summary,
|
64 |
+
"sentiment": sentiment,
|
65 |
+
}
|
66 |
+
)
|
67 |
+
assert result["summary"], "The summary is not correct"
|
68 |
+
assert result["topic"], "The topic is not correct"
|
69 |
+
assert result["sentiment"], "The sentiment is not correct"
|