menikev committed
Commit ec4a718
1 Parent(s): d2ed505

Upload full_inference_pipeline.ipynb

notebook/full_inference_pipeline.ipynb ADDED
@@ -0,0 +1,989 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "source": [
20
+ "! pip install faknow sentence-transformers chromadb\n"
21
+ ],
22
+ "metadata": {
23
+ "colab": {
24
+ "base_uri": "https://localhost:8080/"
25
+ },
26
+ "id": "83T0FpMEgAK7",
27
+ "outputId": "4efafed8-69d4-4575-b473-825e6931b4c5"
28
+ },
29
+ "execution_count": 27,
30
+ "outputs": [
31
+ {
32
+ "output_type": "stream",
33
+ "name": "stdout",
34
+ "text": [
35
+ "Requirement already satisfied: faknow in /usr/local/lib/python3.10/dist-packages (0.0.3)\n",
36
+ "Requirement already satisfied: sentence-transformers in /usr/local/lib/python3.10/dist-packages (2.6.1)\n",
37
+ "Requirement already satisfied: chromadb in /usr/local/lib/python3.10/dist-packages (0.4.24)\n",
38
+ "Requirement already satisfied: transformers>=4.26.1 in /usr/local/lib/python3.10/dist-packages (from faknow) (4.38.2)\n",
39
+ "Requirement already satisfied: numpy>=1.23.4 in /usr/local/lib/python3.10/dist-packages (from faknow) (1.25.2)\n",
40
+ "Requirement already satisfied: pandas>=1.5.2 in /usr/local/lib/python3.10/dist-packages (from faknow) (1.5.3)\n",
41
+ "Requirement already satisfied: scikit-learn>=1.1.3 in /usr/local/lib/python3.10/dist-packages (from faknow) (1.2.2)\n",
42
+ "Requirement already satisfied: tensorboard>=2.10.0 in /usr/local/lib/python3.10/dist-packages (from faknow) (2.15.2)\n",
43
+ "Requirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.10/dist-packages (from faknow) (4.66.2)\n",
44
+ "Requirement already satisfied: jieba>=0.42.1 in /usr/local/lib/python3.10/dist-packages (from faknow) (0.42.1)\n",
45
+ "Requirement already satisfied: gensim>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from faknow) (4.3.2)\n",
46
+ "Requirement already satisfied: pillow>=9.3.0 in /usr/local/lib/python3.10/dist-packages (from faknow) (9.4.0)\n",
47
+ "Requirement already satisfied: nltk>=3.7 in /usr/local/lib/python3.10/dist-packages (from faknow) (3.8.1)\n",
48
+ "Requirement already satisfied: sphinx-markdown-tables>=0.0.17 in /usr/local/lib/python3.10/dist-packages (from faknow) (0.0.17)\n",
49
+ "Requirement already satisfied: torch>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (2.2.1+cu121)\n",
50
+ "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (1.11.4)\n",
51
+ "Requirement already satisfied: huggingface-hub>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from sentence-transformers) (0.20.3)\n",
52
+ "Requirement already satisfied: build>=1.0.3 in /usr/local/lib/python3.10/dist-packages (from chromadb) (1.2.1)\n",
53
+ "Requirement already satisfied: requests>=2.28 in /usr/local/lib/python3.10/dist-packages (from chromadb) (2.31.0)\n",
54
+ "Requirement already satisfied: pydantic>=1.9 in /usr/local/lib/python3.10/dist-packages (from chromadb) (2.6.4)\n",
55
+ "Requirement already satisfied: chroma-hnswlib==0.7.3 in /usr/local/lib/python3.10/dist-packages (from chromadb) (0.7.3)\n",
56
+ "Requirement already satisfied: fastapi>=0.95.2 in /usr/local/lib/python3.10/dist-packages (from chromadb) (0.110.0)\n",
57
+ "Requirement already satisfied: uvicorn[standard]>=0.18.3 in /usr/local/lib/python3.10/dist-packages (from chromadb) (0.29.0)\n",
58
+ "Requirement already satisfied: posthog>=2.4.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (3.5.0)\n",
59
+ "Requirement already satisfied: typing-extensions>=4.5.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (4.10.0)\n",
60
+ "Requirement already satisfied: pulsar-client>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (3.4.0)\n",
61
+ "Requirement already satisfied: onnxruntime>=1.14.1 in /usr/local/lib/python3.10/dist-packages (from chromadb) (1.17.1)\n",
62
+ "Requirement already satisfied: opentelemetry-api>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (1.24.0)\n",
63
+ "Requirement already satisfied: opentelemetry-exporter-otlp-proto-grpc>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (1.24.0)\n",
64
+ "Requirement already satisfied: opentelemetry-instrumentation-fastapi>=0.41b0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (0.45b0)\n",
65
+ "Requirement already satisfied: opentelemetry-sdk>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (1.24.0)\n",
66
+ "Requirement already satisfied: tokenizers>=0.13.2 in /usr/local/lib/python3.10/dist-packages (from chromadb) (0.15.2)\n",
67
+ "Requirement already satisfied: pypika>=0.48.9 in /usr/local/lib/python3.10/dist-packages (from chromadb) (0.48.9)\n",
68
+ "Requirement already satisfied: overrides>=7.3.1 in /usr/local/lib/python3.10/dist-packages (from chromadb) (7.7.0)\n",
69
+ "Requirement already satisfied: importlib-resources in /usr/local/lib/python3.10/dist-packages (from chromadb) (6.4.0)\n",
70
+ "Requirement already satisfied: grpcio>=1.58.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (1.62.1)\n",
71
+ "Requirement already satisfied: bcrypt>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from chromadb) (4.1.2)\n",
72
+ "Requirement already satisfied: typer>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (0.9.4)\n",
73
+ "Requirement already satisfied: kubernetes>=28.1.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (29.0.0)\n",
74
+ "Requirement already satisfied: tenacity>=8.2.3 in /usr/local/lib/python3.10/dist-packages (from chromadb) (8.2.3)\n",
75
+ "Requirement already satisfied: PyYAML>=6.0.0 in /usr/local/lib/python3.10/dist-packages (from chromadb) (6.0.1)\n",
76
+ "Requirement already satisfied: mmh3>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from chromadb) (4.1.0)\n",
77
+ "Requirement already satisfied: orjson>=3.9.12 in /usr/local/lib/python3.10/dist-packages (from chromadb) (3.10.0)\n",
78
+ "Requirement already satisfied: packaging>=19.1 in /usr/local/lib/python3.10/dist-packages (from build>=1.0.3->chromadb) (24.0)\n",
79
+ "Requirement already satisfied: pyproject_hooks in /usr/local/lib/python3.10/dist-packages (from build>=1.0.3->chromadb) (1.0.0)\n",
80
+ "Requirement already satisfied: tomli>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from build>=1.0.3->chromadb) (2.0.1)\n",
81
+ "Requirement already satisfied: starlette<0.37.0,>=0.36.3 in /usr/local/lib/python3.10/dist-packages (from fastapi>=0.95.2->chromadb) (0.36.3)\n",
82
+ "Requirement already satisfied: smart-open>=1.8.1 in /usr/local/lib/python3.10/dist-packages (from gensim>=4.2.0->faknow) (6.4.0)\n",
83
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence-transformers) (3.13.3)\n",
84
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence-transformers) (2023.6.0)\n",
85
+ "Requirement already satisfied: certifi>=14.05.14 in /usr/local/lib/python3.10/dist-packages (from kubernetes>=28.1.0->chromadb) (2024.2.2)\n",
86
+ "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from kubernetes>=28.1.0->chromadb) (1.16.0)\n",
87
+ "Requirement already satisfied: python-dateutil>=2.5.3 in /usr/local/lib/python3.10/dist-packages (from kubernetes>=28.1.0->chromadb) (2.8.2)\n",
88
+ "Requirement already satisfied: google-auth>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from kubernetes>=28.1.0->chromadb) (2.27.0)\n",
89
+ "Requirement already satisfied: websocket-client!=0.40.0,!=0.41.*,!=0.42.*,>=0.32.0 in /usr/local/lib/python3.10/dist-packages (from kubernetes>=28.1.0->chromadb) (1.7.0)\n",
90
+ "Requirement already satisfied: requests-oauthlib in /usr/local/lib/python3.10/dist-packages (from kubernetes>=28.1.0->chromadb) (1.4.1)\n",
91
+ "Requirement already satisfied: oauthlib>=3.2.2 in /usr/local/lib/python3.10/dist-packages (from kubernetes>=28.1.0->chromadb) (3.2.2)\n",
92
+ "Requirement already satisfied: urllib3>=1.24.2 in /usr/local/lib/python3.10/dist-packages (from kubernetes>=28.1.0->chromadb) (2.0.7)\n",
93
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk>=3.7->faknow) (8.1.7)\n",
94
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk>=3.7->faknow) (1.3.2)\n",
95
+ "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk>=3.7->faknow) (2023.12.25)\n",
96
+ "Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.14.1->chromadb) (15.0.1)\n",
97
+ "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.14.1->chromadb) (24.3.25)\n",
98
+ "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.14.1->chromadb) (3.20.3)\n",
99
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.14.1->chromadb) (1.12)\n",
100
+ "Requirement already satisfied: deprecated>=1.2.6 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-api>=1.2.0->chromadb) (1.2.14)\n",
101
+ "Requirement already satisfied: importlib-metadata<=7.0,>=6.0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-api>=1.2.0->chromadb) (7.0.0)\n",
102
+ "Requirement already satisfied: googleapis-common-protos~=1.52 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb) (1.63.0)\n",
103
+ "Requirement already satisfied: opentelemetry-exporter-otlp-proto-common==1.24.0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb) (1.24.0)\n",
104
+ "Requirement already satisfied: opentelemetry-proto==1.24.0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb) (1.24.0)\n",
105
+ "Requirement already satisfied: opentelemetry-instrumentation-asgi==0.45b0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (0.45b0)\n",
106
+ "Requirement already satisfied: opentelemetry-instrumentation==0.45b0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (0.45b0)\n",
107
+ "Requirement already satisfied: opentelemetry-semantic-conventions==0.45b0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (0.45b0)\n",
108
+ "Requirement already satisfied: opentelemetry-util-http==0.45b0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (0.45b0)\n",
109
+ "Requirement already satisfied: setuptools>=16.0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-instrumentation==0.45b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (67.7.2)\n",
110
+ "Requirement already satisfied: wrapt<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-instrumentation==0.45b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (1.14.1)\n",
111
+ "Requirement already satisfied: asgiref~=3.0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-instrumentation-asgi==0.45b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (3.8.1)\n",
112
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.5.2->faknow) (2023.4)\n",
113
+ "Requirement already satisfied: monotonic>=1.5 in /usr/local/lib/python3.10/dist-packages (from posthog>=2.4.0->chromadb) (1.6)\n",
114
+ "Requirement already satisfied: backoff>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from posthog>=2.4.0->chromadb) (2.2.1)\n",
115
+ "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.9->chromadb) (0.6.0)\n",
116
+ "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.9->chromadb) (2.16.3)\n",
117
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.28->chromadb) (3.3.2)\n",
118
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.28->chromadb) (3.6)\n",
119
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.1.3->faknow) (3.4.0)\n",
120
+ "Requirement already satisfied: markdown>=3.4 in /usr/local/lib/python3.10/dist-packages (from sphinx-markdown-tables>=0.0.17->faknow) (3.6)\n",
121
+ "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.10.0->faknow) (1.4.0)\n",
122
+ "Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.10.0->faknow) (1.2.0)\n",
123
+ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.10.0->faknow) (0.7.2)\n",
124
+ "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.10.0->faknow) (3.0.1)\n",
125
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (3.2.1)\n",
126
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (3.1.3)\n",
127
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.105)\n",
128
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.105)\n",
129
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.105)\n",
130
+ "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (8.9.2.26)\n",
131
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.3.1)\n",
132
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (11.0.2.54)\n",
133
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (10.3.2.106)\n",
134
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (11.4.5.107)\n",
135
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.0.106)\n",
136
+ "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (2.19.3)\n",
137
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (12.1.105)\n",
138
+ "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence-transformers) (2.2.0)\n",
139
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.11.0->sentence-transformers) (12.4.99)\n",
140
+ "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.26.1->faknow) (0.4.2)\n",
141
+ "Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (0.14.0)\n",
142
+ "Requirement already satisfied: httptools>=0.5.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (0.6.1)\n",
143
+ "Requirement already satisfied: python-dotenv>=0.13 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (1.0.1)\n",
144
+ "Requirement already satisfied: uvloop!=0.15.0,!=0.15.1,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (0.19.0)\n",
145
+ "Requirement already satisfied: watchfiles>=0.13 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (0.21.0)\n",
146
+ "Requirement already satisfied: websockets>=10.4 in /usr/local/lib/python3.10/dist-packages (from uvicorn[standard]>=0.18.3->chromadb) (12.0)\n",
147
+ "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (5.3.3)\n",
148
+ "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (0.4.0)\n",
149
+ "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (4.9)\n",
150
+ "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata<=7.0,>=6.0->opentelemetry-api>=1.2.0->chromadb) (3.18.1)\n",
151
+ "Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.37.0,>=0.36.3->fastapi>=0.95.2->chromadb) (3.7.1)\n",
152
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard>=2.10.0->faknow) (2.1.5)\n",
153
+ "Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->onnxruntime>=1.14.1->chromadb) (10.0)\n",
154
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->onnxruntime>=1.14.1->chromadb) (1.3.0)\n",
155
+ "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.37.0,>=0.36.3->fastapi>=0.95.2->chromadb) (1.3.1)\n",
156
+ "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.37.0,>=0.36.3->fastapi>=0.95.2->chromadb) (1.2.0)\n",
157
+ "Requirement already satisfied: pyasn1<0.7.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (0.6.0)\n"
158
+ ]
159
+ }
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "source": [],
165
+ "metadata": {
166
+ "id": "kG2sAMShgAOV"
167
+ },
168
+ "execution_count": 27,
169
+ "outputs": []
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "source": [
174
+ "import pandas as pd\n",
175
+ "import os\n",
176
+ "import chromadb\n",
177
+ "from chromadb.utils import embedding_functions\n",
178
+ "import math\n",
179
+ "\n",
180
+ "\n",
181
+ "\n",
182
+ "\n",
183
+ "\n",
184
+ "def create_domain_identification_database(vdb_path: str,collection_name:str , df: pd.DataFrame) -> None:\n",
185
+ " \"\"\"This function processes the dataframe into the required format, and then creates the following collection in a ChromaDB instance:\n",
186
+ " 1. domain_identification_collection - Contains the input text embeddings, with the other columns stored as metadata\n",
187
+ "\n",
188
+ " Args:\n",
189
+ " vdb_path (str): Relative path of the location of the ChromaDB instance.\n",
190
+ " collection_name (str): Name of the database collection.\n",
191
+ " df (pd.DataFrame): domain identification dataset with query, domain and label columns.\n",
192
+ "\n",
193
+ " \"\"\"\n",
194
+ "\n",
195
+ " # set the storage location of the persistent ChromaDB instance\n",
196
+ " chroma_client = chromadb.PersistentClient(path=vdb_path)\n",
197
+ "\n",
198
+ " # load the sentence-transformers embedding function from Hugging Face\n",
199
+ " embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=\"sentence-transformers/LaBSE\")\n",
200
+ "\n",
201
+ " #creating the collection\n",
202
+ " domain_identification_collection = chroma_client.create_collection(\n",
203
+ " name=collection_name,\n",
204
+ " embedding_function=embedding_function,\n",
205
+ " )\n",
206
+ "\n",
207
+ "\n",
208
+ " # the main text \"query\" that will be embedded\n",
209
+ " domain_identification_documents = [row.query for row in df.itertuples()]\n",
210
+ "\n",
211
+ " # the metadata\n",
212
+ " domain_identification_metadata = [\n",
213
+ " {\"domain\": row.domain , \"label\": row.label}\n",
214
+ " for row in df.itertuples()\n",
215
+ " ]\n",
216
+ "\n",
217
+ " #index\n",
218
+ " domain_ids = [\"domain_id \" + str(row.Index) for row in df.itertuples()]\n",
219
+ "\n",
220
+ "\n",
221
+ " length = len(df)\n",
222
+ " num_iteration = length / 166\n",
223
+ " num_iteration = math.ceil(num_iteration)\n",
224
+ "\n",
225
+ " start = 0\n",
226
+ " # start adding the vectors in chunks of 166 documents per call\n",
227
+ " for i in range(num_iteration):\n",
228
+ " if i == num_iteration - 1 :\n",
229
+ " domain_identification_collection.add(documents=domain_identification_documents[start:], metadatas=domain_identification_metadata[start:], ids=domain_ids[start:])\n",
230
+ " else:\n",
231
+ " end = start + 166\n",
232
+ " domain_identification_collection.add(documents=domain_identification_documents[start:end], metadatas=domain_identification_metadata[start:end], ids=domain_ids[start:end])\n",
233
+ " start = end\n",
234
+ " return None\n",
235
+ "\n",
236
+ "\n",
237
+ "\n",
238
+ "def delete_collection_from_vector_db(vdb_path: str, collection_name: str) -> None:\n",
239
+ " \"\"\"Deletes a particular collection from the persistent ChromaDB instance.\n",
240
+ "\n",
241
+ " Args:\n",
242
+ " vdb_path (str): Path of the persistent ChromaDB instance.\n",
243
+ " collection_name (str): Name of the collection to be deleted.\n",
244
+ " \"\"\"\n",
245
+ " chroma_client = chromadb.PersistentClient(path=vdb_path)\n",
246
+ " chroma_client.delete_collection(collection_name)\n",
247
+ " return None\n",
248
+ "\n",
249
+ "\n",
250
+ "def list_collections_from_vector_db(vdb_path: str) -> None:\n",
251
+ " \"\"\"Lists all the available collections from the persistent ChromaDB instance.\n",
252
+ "\n",
253
+ " Args:\n",
254
+ " vdb_path (str): Path of the persistent ChromaDB instance.\n",
255
+ " \"\"\"\n",
256
+ " chroma_client = chromadb.PersistentClient(path=vdb_path)\n",
257
+ " print(chroma_client.list_collections())\n",
258
+ "\n",
259
+ "\n",
260
+ "def get_collection_from_vector_db(\n",
261
+ " vdb_path: str, collection_name: str\n",
262
+ ") -> chromadb.Collection:\n",
263
+ " \"\"\"Fetches a particular ChromaDB collection object from the persistent ChromaDB instance.\n",
264
+ "\n",
265
+ " Args:\n",
266
+ " vdb_path (str): Path of the persistent ChromaDB instance.\n",
267
+ " collection_name (str): Name of the collection which needs to be retrieved.\n",
268
+ " \"\"\"\n",
269
+ " chroma_client = chromadb.PersistentClient(path=vdb_path)\n",
270
+ "\n",
271
+ " huggingface_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=\"sentence-transformers/LaBSE\")\n",
272
+ "\n",
273
+ "\n",
274
+ "\n",
275
+ "\n",
276
+ " collection = chroma_client.get_collection(\n",
277
+ " name=collection_name, embedding_function=huggingface_ef\n",
278
+ " )\n",
279
+ "\n",
280
+ " return collection\n",
281
+ "\n",
282
+ "\n",
283
+ "def retrieval( input_text : str,\n",
284
+ " num_results : int,\n",
285
+ " collection: chromadb.Collection ):\n",
286
+ "\n",
287
+ " \"\"\"Fetches the domain name from the collection based on semantic similarity.\n",
288
+ "\n",
289
+ " Args:\n",
290
+ " input_text : the received text, which can be news, posts, or tweets\n",
291
+ " num_results : number of examples to fetch from the collection\n",
292
+ " collection : the collection retrieved from the database to fetch examples from\n",
293
+ "\n",
294
+ " \"\"\"\n",
295
+ "\n",
296
+ "\n",
297
+ " fetched_domain = collection.query(\n",
298
+ " query_texts = [input_text],\n",
299
+ " n_results = num_results,\n",
300
+ " )\n",
301
+ "\n",
302
+ " # extract the domain name, label and distance from the fetched results\n",
303
+ "\n",
304
+ " domain = fetched_domain[\"metadatas\"][0][0][\"domain\"]\n",
305
+ " label = fetched_domain[\"metadatas\"][0][0][\"label\"]\n",
306
+ " distance = fetched_domain[\"distances\"][0][0]\n",
307
+ "\n",
308
+ " return domain , label , distance"
309
+ ],
310
+ "metadata": {
311
+ "id": "-_UqusZqgAQP"
312
+ },
313
+ "execution_count": 28,
314
+ "outputs": []
315
+ },
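+ {
+ "cell_type": "code",
+ "source": [
+ "# A minimal sketch of how the helper functions above might be exercised end to end.\n",
+ "# The toy dataframe, the \"demo_vdb\" path and the \"demo_domains\" collection name are\n",
+ "# illustrative assumptions, not values used elsewhere in this notebook.\n",
+ "toy_df = pd.DataFrame({\n",
+ "    \"query\": [\"tamil new year celebrations\", \"women in parliament\"],\n",
+ "    \"domain\": [\"tamil\", \"women\"],\n",
+ "    \"label\": [2, 0],\n",
+ "})\n",
+ "\n",
+ "create_domain_identification_database(\"demo_vdb\", \"demo_domains\", toy_df)\n",
+ "demo_collection = get_collection_from_vector_db(\"demo_vdb\", \"demo_domains\")\n",
+ "print(retrieval(\"celebrating the tamil new year\", 1, demo_collection))\n",
+ "delete_collection_from_vector_db(\"demo_vdb\", \"demo_domains\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },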
316
+ {
317
+ "cell_type": "code",
318
+ "source": [
319
+ "from transformers import pipeline\n",
320
+ "\n",
321
+ "\n",
322
+ "\n",
323
+ "\n",
324
+ "\n",
325
+ "\n",
326
+ "def english_information_extraction(text: str):\n",
327
+ "\n",
328
+ "\n",
329
+ "\n",
330
+ "\n",
331
+ " zeroshot_classifier = pipeline(\"zero-shot-classification\", model=\"MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33\")\n",
332
+ "\n",
333
+ " hypothesis_template_domain = \"This text is about {}\"\n",
334
+ " domain_classes = [\"women\" , \"muslims\" , \"tamil\" , \"sinhala\" , \"other\"]\n",
335
+ " domains_output= zeroshot_classifier(text, domain_classes , hypothesis_template=hypothesis_template_domain, multi_label=False)\n",
336
+ "\n",
337
+ " sentiment_discrimination_prompt = f\"the content of this text about {domains_output['labels'][0]} \"\n",
338
+ " hypothesis_template_sentiment = \"is {} sentiment\"\n",
339
+ " hypothesis_template_sentiment = sentiment_discrimination_prompt + hypothesis_template_sentiment\n",
340
+ "\n",
341
+ " sentiment_classes = [\"positive\" ,\"neutral\", \"negative\"]\n",
342
+ " sentiment_output= zeroshot_classifier(text, sentiment_classes , hypothesis_template=hypothesis_template_sentiment, multi_label=False)\n",
343
+ "\n",
344
+ " hypothesis_template_discrimination = \"is {}\"\n",
345
+ " hypothesis_template_discrimination = sentiment_discrimination_prompt + hypothesis_template_discrimination\n",
346
+ "\n",
347
+ " discrimination_classes = [\"hateful\" , \"not hateful\"]\n",
348
+ "\n",
349
+ " discrimination_output= zeroshot_classifier(text, discrimination_classes , hypothesis_template=hypothesis_template_discrimination, multi_label=False)\n",
350
+ "\n",
351
+ " domain_label , domain_score = domains_output[\"labels\"][0] , domains_output[\"scores\"][0]\n",
352
+ " sentiment_label , sentiment_score = sentiment_output[\"labels\"][0] , sentiment_output[\"scores\"][0]\n",
353
+ " discrimination_label , discrimination_score = discrimination_output[\"labels\"][0] , discrimination_output[\"scores\"][0]\n",
354
+ "\n",
355
+ " return {\"domain_label\" : domain_label,\n",
356
+ " \"domain_score\" : domain_score,\n",
357
+ " \"sentiment_label\" : sentiment_label,\n",
358
+ " \"sentiment_score\" : sentiment_score,\n",
359
+ " \"discrimination_label\" : discrimination_label,\n",
360
+ " \"discrimination_score\": discrimination_score}\n",
361
+ "\n",
362
+ "\n",
363
+ "\n"
364
+ ],
365
+ "metadata": {
366
+ "id": "G9EL047MfDDY"
367
+ },
368
+ "execution_count": 29,
369
+ "outputs": []
370
+ },
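+ {
+ "cell_type": "code",
+ "source": [
+ "# A small usage sketch for the zero-shot extraction above; the sample sentence is an\n",
+ "# illustrative assumption. The call returns the domain, sentiment and discrimination\n",
+ "# labels together with their scores.\n",
+ "sample_text = \"Women should have equal access to education.\"\n",
+ "print(english_information_extraction(sample_text))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },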
371
+ {
372
+ "cell_type": "code",
373
+ "source": [],
374
+ "metadata": {
375
+ "id": "jmzyvmLQgASa"
376
+ },
377
+ "execution_count": 29,
378
+ "outputs": []
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "source": [
383
+ "#the model\n",
384
+ "from typing import List, Optional, Tuple\n",
385
+ "\n",
386
+ "import torch\n",
387
+ "from torch import Tensor\n",
388
+ "from torch import nn\n",
389
+ "from transformers import RobertaModel\n",
390
+ "\n",
391
+ "from faknow.model.layers.layer import TextCNNLayer\n",
392
+ "from faknow.model.model import AbstractModel\n",
393
+ "from faknow.data.process.text_process import TokenizerFromPreTrained\n",
394
+ "import pandas as pd\n",
395
+ "import gdown\n",
396
+ "import os\n",
397
+ "\n",
398
+ "class _MLP(nn.Module):\n",
399
+ " def __init__(self,\n",
400
+ " input_dim: int,\n",
401
+ " embed_dims: List[int],\n",
402
+ " dropout_rate: float,\n",
403
+ " output_layer=True):\n",
404
+ " super().__init__()\n",
405
+ " layers = list()\n",
406
+ " for embed_dim in embed_dims:\n",
407
+ " layers.append(nn.Linear(input_dim, embed_dim))\n",
408
+ " layers.append(nn.BatchNorm1d(embed_dim))\n",
409
+ " layers.append(nn.ReLU())\n",
410
+ " layers.append(nn.Dropout(p=dropout_rate))\n",
411
+ " input_dim = embed_dim\n",
412
+ " if output_layer:\n",
413
+ " layers.append(torch.nn.Linear(input_dim, 1))\n",
414
+ " self.mlp = torch.nn.Sequential(*layers)\n",
415
+ "\n",
416
+ " def forward(self, x: Tensor) -> Tensor:\n",
417
+ " \"\"\"\n",
418
+ "\n",
419
+ " Args:\n",
420
+ " x (Tensor): shared feature from domain and text, shape=(batch_size, embed_dim)\n",
421
+ "\n",
422
+ " \"\"\"\n",
423
+ " return self.mlp(x)\n",
424
+ "\n",
425
+ "\n",
426
+ "class _MaskAttentionLayer(torch.nn.Module):\n",
427
+ " \"\"\"\n",
428
+ " Compute attention layer\n",
429
+ " \"\"\"\n",
430
+ " def __init__(self, input_size: int):\n",
431
+ " super(_MaskAttentionLayer, self).__init__()\n",
432
+ " self.attention_layer = torch.nn.Linear(input_size, 1)\n",
433
+ "\n",
434
+ " def forward(self,\n",
435
+ " inputs: Tensor,\n",
436
+ " mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:\n",
437
+ " weights = self.attention_layer(inputs).view(-1, inputs.size(1))\n",
438
+ " if mask is not None:\n",
439
+ " weights = weights.masked_fill(mask == 0, float(\"-inf\"))\n",
440
+ " weights = torch.softmax(weights, dim=-1).unsqueeze(1)\n",
441
+ " outputs = torch.matmul(weights, inputs).squeeze(1)\n",
442
+ " return outputs, weights\n",
443
+ "\n",
444
+ "\n",
445
+ "class MDFEND(AbstractModel):\n",
446
+ " r\"\"\"\n",
447
+ " MDFEND: Multi-domain Fake News Detection, CIKM 2021\n",
448
+ " paper: https://dl.acm.org/doi/10.1145/3459637.3482139\n",
449
+ " code: https://github.com/kennqiang/MDFEND-Weibo21\n",
450
+ " \"\"\"\n",
451
+ " def __init__(self,\n",
452
+ " pre_trained_bert_name: str,\n",
453
+ " domain_num: int,\n",
454
+ " mlp_dims: Optional[List[int]] = None,\n",
455
+ " dropout_rate=0.2,\n",
456
+ " expert_num=5):\n",
457
+ " \"\"\"\n",
458
+ "\n",
459
+ " Args:\n",
460
+ " pre_trained_bert_name (str): the name or local path of pre-trained bert model\n",
461
+ " domain_num (int): total number of all domains\n",
462
+ " mlp_dims (List[int]): a list of the dimensions of the MLP layers; if None, [384] is used as the default\n",
463
+ " dropout_rate (float): rate of the Dropout layer, default=0.2\n",
464
+ " expert_num (int): number of experts (TextCNNLayer instances), default=5\n",
465
+ " \"\"\"\n",
466
+ " super(MDFEND, self).__init__()\n",
467
+ " self.domain_num = domain_num\n",
468
+ " self.expert_num = expert_num\n",
469
+ " self.bert = RobertaModel.from_pretrained(\n",
470
+ " pre_trained_bert_name).requires_grad_(False)\n",
471
+ " self.embedding_size = self.bert.config.hidden_size\n",
472
+ " self.loss_func = nn.BCELoss()\n",
473
+ " if mlp_dims is None:\n",
474
+ " mlp_dims = [384]\n",
475
+ "\n",
476
+ " filter_num = 64\n",
477
+ " filter_sizes = [1, 2, 3, 5, 10]\n",
478
+ " experts = [\n",
479
+ " TextCNNLayer(self.embedding_size, filter_num, filter_sizes)\n",
480
+ " for _ in range(self.expert_num)\n",
481
+ " ]\n",
482
+ " self.experts = nn.ModuleList(experts)\n",
483
+ "\n",
484
+ " self.gate = nn.Sequential(\n",
485
+ " nn.Linear(self.embedding_size * 2, mlp_dims[-1]), nn.ReLU(),\n",
486
+ " nn.Linear(mlp_dims[-1], self.expert_num), nn.Softmax(dim=1))\n",
487
+ "\n",
488
+ " self.attention = _MaskAttentionLayer(self.embedding_size)\n",
489
+ "\n",
490
+ " self.domain_embedder = nn.Embedding(num_embeddings=self.domain_num,\n",
491
+ " embedding_dim=self.embedding_size)\n",
492
+ " self.classifier = _MLP(320, mlp_dims, dropout_rate)\n",
493
+ "\n",
494
+ " def forward(self, token_id: Tensor, mask: Tensor,\n",
495
+ " domain: Tensor) -> Tensor:\n",
496
+ " \"\"\"\n",
497
+ "\n",
498
+ " Args:\n",
499
+ " token_id (Tensor): token ids from bert tokenizer, shape=(batch_size, max_len)\n",
500
+ " mask (Tensor): mask from bert tokenizer, shape=(batch_size, max_len)\n",
501
+ " domain (Tensor): domain id, shape=(batch_size,)\n",
502
+ "\n",
503
+ " Returns:\n",
504
+ " FloatTensor: the prediction of being fake, shape=(batch_size,)\n",
505
+ " \"\"\"\n",
506
+ " text_embedding = self.bert(token_id,\n",
507
+ " attention_mask=mask).last_hidden_state\n",
508
+ " attention_feature, _ = self.attention(text_embedding, mask)\n",
509
+ "\n",
510
+ " domain_embedding = self.domain_embedder(domain.view(-1, 1)).squeeze(1)\n",
511
+ "\n",
512
+ " gate_input = torch.cat([domain_embedding, attention_feature], dim=-1)\n",
513
+ " gate_output = self.gate(gate_input)\n",
514
+ "\n",
515
+ " shared_feature = 0\n",
516
+ " for i in range(self.expert_num):\n",
517
+ " expert_feature = self.experts[i](text_embedding)\n",
518
+ " shared_feature += (expert_feature * gate_output[:, i].unsqueeze(1))\n",
519
+ "\n",
520
+ " label_pred = self.classifier(shared_feature)\n",
521
+ "\n",
522
+ " return torch.sigmoid(label_pred.squeeze(1))\n",
523
+ "\n",
524
+ " def calculate_loss(self, data) -> Tensor:\n",
525
+ " \"\"\"\n",
526
+ " calculate loss via BCELoss\n",
527
+ "\n",
528
+ " Args:\n",
529
+ " data (dict): batch data dict\n",
530
+ "\n",
531
+ " Returns:\n",
532
+ " loss (Tensor): loss value\n",
533
+ " \"\"\"\n",
534
+ "\n",
535
+ " token_ids = data['text']['token_id']\n",
536
+ " masks = data['text']['mask']\n",
537
+ " domains = data['domain']\n",
538
+ " labels = data['label']\n",
539
+ " output = self.forward(token_ids, masks, domains)\n",
540
+ " return self.loss_func(output, labels.float())\n",
541
+ "\n",
542
+ " def predict(self, data_without_label) -> Tensor:\n",
543
+ " \"\"\"\n",
544
+ " predict the probability of being fake news\n",
545
+ "\n",
546
+ " Args:\n",
547
+ " data_without_label (Dict[str, Any]): batch data dict\n",
548
+ "\n",
549
+ " Returns:\n",
550
+ " Tensor: one-hot probability, shape=(batch_size, 2)\n",
551
+ " \"\"\"\n",
552
+ "\n",
553
+ " token_ids = data_without_label['text']['token_id']\n",
554
+ " masks = data_without_label['text']['mask']\n",
555
+ " domains = data_without_label['domain']\n",
556
+ "\n",
557
+ " # shape=(n,), data = 1 or 0\n",
558
+ " round_pred = torch.round(self.forward(token_ids, masks,\n",
559
+ " domains)).long()\n",
560
+ " # after one hot: shape=(n,2), data = [0,1] or [1,0]\n",
561
+ " one_hot_pred = torch.nn.functional.one_hot(round_pred, num_classes=2)\n",
562
+ " return one_hot_pred\n",
563
+ "\n",
564
+ "\n",
565
+ "def download_from_gdrive(file_id, output_path):\n",
566
+ " output = os.path.join(output_path)\n",
567
+ "\n",
568
+ " # Check if the file already exists\n",
569
+ " if not os.path.exists(output):\n",
570
+ " gdown.download(id=file_id, output=output, quiet=False)\n",
571
+ "\n",
572
+ "\n",
573
+ " return output\n",
574
+ "\n",
575
+ "\n",
576
+ "\n",
577
+ "def loading_model_and_tokenizer():\n",
578
+ " max_len, bert = 160, 'FacebookAI/xlm-roberta-base'\n",
579
+ " #https://drive.google.com/file/d/1--6GB3Ff81sILwtuvVTuAW3shGW_5VWC/view\n",
580
+ "\n",
581
+ " file_id = \"1--6GB3Ff81sILwtuvVTuAW3shGW_5VWC\"\n",
582
+ "\n",
583
+ " model_path = '/content/drive/MyDrive/models/last-epoch-model-2024-03-17-01_00_32_1.pth'\n",
584
+ "\n",
585
+ " MODEL_SAVE_PATH = download_from_gdrive(file_id, model_path)\n",
586
+ " domain_num = 4\n",
587
+ "\n",
588
+ "\n",
589
+ "\n",
590
+ " tokenizer = TokenizerFromPreTrained(max_len, bert)\n",
591
+ "\n",
592
+ " model = MDFEND(bert, domain_num , expert_num=12 , mlp_dims = [3010, 2024 ,1012 ,606 , 400])\n",
593
+ "\n",
594
+ " model.load_state_dict(torch.load(f=MODEL_SAVE_PATH , map_location=torch.device('cpu')))\n",
595
+ "\n",
596
+ " model.requires_grad_(False)\n",
597
+ "\n",
598
+ " return tokenizer , model"
599
+ ],
600
+ "metadata": {
601
+ "id": "A4zYbG-AmxQd"
602
+ },
603
+ "execution_count": 51,
604
+ "outputs": []
605
+ },
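+ {
+ "cell_type": "code",
+ "source": [
+ "# A shape-only sketch of the MDFEND forward signature with random tensors, assuming the\n",
+ "# same 'FacebookAI/xlm-roberta-base' backbone used by loading_model_and_tokenizer; the\n",
+ "# output is not a meaningful prediction, it just illustrates the expected shapes.\n",
+ "demo_model = MDFEND('FacebookAI/xlm-roberta-base', domain_num=4, expert_num=12,\n",
+ "                    mlp_dims=[3010, 2024, 1012, 606, 400])\n",
+ "demo_ids = torch.randint(5, 100, (2, 16))   # (batch_size, seq_len) token ids\n",
+ "demo_mask = torch.ones_like(demo_ids)       # attention mask\n",
+ "demo_domains = torch.tensor([0, 1])         # domain ids, each < domain_num\n",
+ "print(demo_model.forward(demo_ids, demo_mask, demo_domains).shape)  # torch.Size([2])"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },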
606
+ {
607
+ "cell_type": "code",
608
+ "source": [
609
+ "import pandas as pd\n",
610
+ "import torch\n",
611
+ "def preparing_data(text:str , domain: int):\n",
612
+ " \"\"\"Wraps the input text and its domain id into a small dataframe for inference.\n",
613
+ "\n",
614
+ " A dummy row is included because the model cannot run inference on a single example.\n",
615
+ "\n",
616
+ " Args:\n",
617
+ " text (_str_): input text from the user\n",
618
+ " domain (_int_): output domain from domain identification pipeline\n",
619
+ "\n",
620
+ " Returns:\n",
621
+ " _DataFrame_: dataframe containing the texts and domains\n",
622
+ " \"\"\"\n",
623
+ " # start from a dictionary that already holds a dummy example;\n",
624
+ " # the model can't run inference on a single example, so this dummy row must be included\n",
625
+ " dict_data = {\n",
626
+ " 'text': ['hello world' ] ,\n",
627
+ " 'domain': [0] ,\n",
628
+ " }\n",
629
+ "\n",
630
+ " dict_data[\"text\"].append(text)\n",
631
+ " dict_data[\"domain\"].append(domain)\n",
632
+ " # Convert the dictionary to a DataFrame\n",
633
+ " df = pd.DataFrame(dict_data)\n",
634
+ "\n",
635
+ " # return the dataframe\n",
636
+ " return df\n",
637
+ "\n",
638
+ "\n",
639
+ "def loading_data(tokenizer , df: pd.DataFrame ):\n",
640
+ " ids = []\n",
641
+ " masks = []\n",
642
+ " domain_list = []\n",
643
+ "\n",
644
+ " texts = df[\"text\"]\n",
645
+ " domains= df[\"domain\"]\n",
646
+ "\n",
647
+ "\n",
648
+ " for i in range(len(df)):\n",
649
+ " text = texts[i]\n",
650
+ " token = tokenizer(text)\n",
651
+ " ids.append(token[\"token_id\"])\n",
652
+ " masks.append(token[\"mask\"])\n",
653
+ " domain_list.append(domains[i])\n",
654
+ "\n",
655
+ " input_ids = torch.cat(ids , dim=0)\n",
656
+ " input_masks = torch.cat(masks ,dim = 0)\n",
657
+ " input_domains = torch.tensor(domain_list)\n",
658
+ "\n",
659
+ "\n",
660
+ " return input_ids , input_masks , input_domains"
661
+ ],
662
+ "metadata": {
663
+ "id": "63oO220bidnk"
664
+ },
665
+ "execution_count": 31,
666
+ "outputs": []
667
+ },
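+ {
+ "cell_type": "code",
+ "source": [
+ "# A quick sketch of preparing_data; the text and domain id below are illustrative.\n",
+ "# Row 0 is the fixed dummy example ('hello world', 0) and row 1 carries the actual pair.\n",
+ "print(preparing_data(\"sample input text\", 3))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },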
668
+ {
669
+ "cell_type": "code",
670
+ "source": [
671
+ "import torch\n",
672
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
673
+ "\n",
674
+ "def language_identification(texts):\n",
675
+ " text = [\n",
676
+ " texts,\n",
677
+ "\n",
678
+ " ]\n",
679
+ "\n",
680
+ " model_ckpt = \"papluca/xlm-roberta-base-language-detection\"\n",
681
+ " tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
682
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)\n",
683
+ "\n",
684
+ " inputs = tokenizer(text, padding=True, truncation=True, return_tensors=\"pt\")\n",
685
+ "\n",
686
+ " with torch.no_grad():\n",
687
+ " logits = model(**inputs).logits\n",
688
+ "\n",
689
+ " preds = torch.softmax(logits, dim=-1)\n",
690
+ "\n",
691
+ " # Map raw predictions to languages\n",
692
+ " id2lang = model.config.id2label\n",
693
+ " vals, idxs = torch.max(preds, dim=1)\n",
694
+ " lang_dict = {id2lang[k.item()]: v.item() for k, v in zip(idxs, vals)}\n",
695
+ " return lang_dict"
696
+ ],
697
+ "metadata": {
698
+ "id": "mBrwFI_wPxtU"
699
+ },
700
+ "execution_count": 32,
701
+ "outputs": []
702
+ },
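+ {
+ "cell_type": "code",
+ "source": [
+ "# A small usage sketch for the language detector; the sentence is illustrative. The\n",
+ "# result maps the top predicted language code to its probability, e.g. {'en': ...}.\n",
+ "print(language_identification(\"this is an english sentence\"))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },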
703
+ {
704
+ "cell_type": "code",
705
+ "source": [
706
+ "from google.colab import drive\n",
707
+ "drive.mount('/content/drive')"
708
+ ],
709
+ "metadata": {
710
+ "colab": {
711
+ "base_uri": "https://localhost:8080/"
712
+ },
713
+ "id": "yuFVY6cZidqI",
714
+ "outputId": "766ef226-ad9a-444c-eff8-d02923ff1b7d"
715
+ },
716
+ "execution_count": 33,
717
+ "outputs": [
718
+ {
719
+ "output_type": "stream",
720
+ "name": "stdout",
721
+ "text": [
722
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
723
+ ]
724
+ }
725
+ ]
726
+ },
727
+ {
728
+ "cell_type": "code",
729
+ "source": [
730
+ "\n",
731
+ "def run_pipeline(input_text:str):\n",
732
+ "\n",
733
+ " language_dict = language_identification(input_text)\n",
734
+ " language_code = next(iter(language_dict))\n",
735
+ "\n",
736
+ " if language_code == \"en\":\n",
737
+ "\n",
738
+ " output_english = english_information_extraction(input_text)\n",
739
+ "\n",
740
+ " return output_english\n",
741
+ "\n",
742
+ " else:\n",
743
+ "\n",
744
+ "\n",
745
+ " num_results = 1\n",
746
+ " path = \"/content/drive/MyDrive/general_domains/vector_database\"\n",
747
+ " collection_name = \"general_domains\"\n",
748
+ "\n",
749
+ "\n",
750
+ " collection = get_collection_from_vector_db(path , collection_name)\n",
751
+ "\n",
752
+ " domain , label_domain , distance = retrieval(input_text , num_results , collection )\n",
753
+ "\n",
754
+ " # if the nearest example is too far away, treat the domain as undetermined\n",
+ " if distance > 1.45:\n",
755
+ " domain = \"undetermined\"\n",
756
+ "\n",
757
+ " tokenizer , model = loading_model_and_tokenizer()\n",
758
+ "\n",
759
+ " df = preparing_data(input_text , label_domain)\n",
760
+ "\n",
761
+ " input_ids , input_masks , input_domains = loading_data(tokenizer , df )\n",
762
+ "\n",
763
+ " labels = []\n",
764
+ " outputs = []\n",
765
+ " with torch.no_grad():\n",
766
+ "\n",
767
+ " pred = model.forward(input_ids, input_masks , input_domains)\n",
768
+ " labels.append([])\n",
769
+ "\n",
770
+ " for output in pred:\n",
771
+ " number = output.item()\n",
772
+ " label = int(1) if number >= 0.5 else int(0)\n",
773
+ " labels[-1].append(label)\n",
774
+ " outputs.append(pred)\n",
775
+ "\n",
776
+ " discrimination_class = [\"discriminative\" if i == int(1) else \"not discriminative\" for i in labels[0]]\n",
777
+ "\n",
778
+ "\n",
779
+ " return { \"domain_label\" :domain ,\n",
780
+ " \"domain_score\":distance ,\n",
781
+ " \"discrimination_label\" : discrimination_class[-1],\n",
782
+ " \"discrimination_score\" : outputs[0][1:].item(),\n",
783
+ " }\n",
784
+ "\n",
785
+ "\n",
786
+ "\n",
787
+ "\n",
788
+ "\n",
789
+ "\n",
790
+ "\n",
791
+ "\n",
792
+ "\n",
793
+ "\n",
794
+ "\n",
795
+ "\n",
796
+ "\n",
797
+ "\n"
798
+ ],
799
+ "metadata": {
800
+ "id": "HlBJF4NQgAVy"
801
+ },
802
+ "execution_count": 34,
803
+ "outputs": []
804
+ },
805
+ {
806
+ "cell_type": "code",
807
+ "source": [
808
+ "input_text_1 = input(\"input text:\")\n",
809
+ "\n",
810
+ "output_1 = run_pipeline( input_text_1)"
811
+ ],
812
+ "metadata": {
813
+ "colab": {
814
+ "base_uri": "https://localhost:8080/"
815
+ },
816
+ "id": "1BVBXyRDnDg4",
817
+ "outputId": "de9a4f3e-4ad4-4d03-8d51-b3df05daa685"
818
+ },
819
+ "execution_count": 35,
820
+ "outputs": [
821
+ {
822
+ "name": "stdout",
823
+ "output_type": "stream",
824
+ "text": [
825
+ "input text:muslims loves their prophet muhammed\n"
826
+ ]
827
+ }
828
+ ]
829
+ },
830
+ {
831
+ "cell_type": "code",
832
+ "source": [
833
+ "output_1"
834
+ ],
835
+ "metadata": {
836
+ "colab": {
837
+ "base_uri": "https://localhost:8080/"
838
+ },
839
+ "id": "TnnB40tEnIHI",
840
+ "outputId": "752bc3bd-93ac-46be-9d1a-308c6fc267ed"
841
+ },
842
+ "execution_count": 36,
843
+ "outputs": [
844
+ {
845
+ "output_type": "execute_result",
846
+ "data": {
847
+ "text/plain": [
848
+ "{'domain_label': 'muslims',\n",
849
+ " 'domain_score': 0.9989225268363953,\n",
850
+ " 'sentiment_label': 'positive',\n",
851
+ " 'sentiment_score': 0.9239600300788879,\n",
852
+ " 'discrimination_label': 'not hateful',\n",
853
+ " 'discrimination_score': 0.9917498826980591}"
854
+ ]
855
+ },
856
+ "metadata": {},
857
+ "execution_count": 36
858
+ }
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "code",
863
+ "source": [
864
+ "input_text_2 = input(\"input text:\")\n",
865
+ "\n",
866
+ "output_2 = run_pipeline( input_text_2)"
867
+ ],
868
+ "metadata": {
869
+ "id": "LBAvmrE1QxM3",
870
+ "colab": {
871
+ "base_uri": "https://localhost:8080/"
872
+ },
873
+ "outputId": "45056e2c-701c-40c0-9a04-36710cc1bdbd"
874
+ },
875
+ "execution_count": 54,
876
+ "outputs": [
877
+ {
878
+ "name": "stdout",
879
+ "output_type": "stream",
880
+ "text": [
881
+ "input text:මුස්ලිම්වරු ඔවුන්ගේ අනාගතවක්තෘ මුහම්මද්ට ආදරෙයි\n"
882
+ ]
883
+ },
884
+ {
885
+ "output_type": "stream",
886
+ "name": "stderr",
887
+ "text": [
888
+ "Downloading...\n",
889
+ "From (original): https://drive.google.com/uc?id=1--6GB3Ff81sILwtuvVTuAW3shGW_5VWC\n",
890
+ "From (redirected): https://drive.google.com/uc?id=1--6GB3Ff81sILwtuvVTuAW3shGW_5VWC&confirm=t&uuid=4bc00ac8-29e3-458b-a64d-c0f583a18df7\n",
891
+ "To: /content/drive/MyDrive/models/last-epoch-model-2024-03-17-01_00_32_1.pth\n",
892
+ "100%|██████████| 1.20G/1.20G [00:17<00:00, 69.1MB/s]\n",
893
+ "You are using a model of type xlm-roberta to instantiate a model of type roberta. This is not supported for all configurations of models and can yield errors.\n"
894
+ ]
895
+ }
896
+ ]
897
+ },
898
+ {
899
+ "cell_type": "code",
900
+ "source": [
901
+ "output_2"
902
+ ],
903
+ "metadata": {
904
+ "colab": {
905
+ "base_uri": "https://localhost:8080/"
906
+ },
907
+ "id": "ienC5lZvYjcu",
908
+ "outputId": "25eb47ee-f219-4ce0-915b-5fd3acb54414"
909
+ },
910
+ "execution_count": 55,
911
+ "outputs": [
912
+ {
913
+ "output_type": "execute_result",
914
+ "data": {
915
+ "text/plain": [
916
+ "{'domain_label': 'muslims',\n",
917
+ " 'domain_score': 0.9477148933517974,\n",
918
+ " 'discrimination_label': 'not discriminative',\n",
919
+ " 'discrimination_score': 0.016480498015880585}"
920
+ ]
921
+ },
922
+ "metadata": {},
923
+ "execution_count": 55
924
+ }
925
+ ]
926
+ },
927
+ {
928
+ "cell_type": "code",
929
+ "source": [
930
+ "input_text_3 = input(\"input text:\")\n",
931
+ "\n",
932
+ "output_3 = run_pipeline( input_text_3)"
933
+ ],
934
+ "metadata": {
935
+ "colab": {
936
+ "base_uri": "https://localhost:8080/"
937
+ },
938
+ "id": "kCe3FS5lYyQ7",
939
+ "outputId": "5ec7d2fd-3aa9-4e35-b4bf-2d1db4777aba"
940
+ },
941
+ "execution_count": 56,
942
+ "outputs": [
943
+ {
944
+ "name": "stdout",
945
+ "output_type": "stream",
946
+ "text": [
947
+ "input text:முஸ்லீம்கள் தங்கள் தீர்க்கதரிசியை நேசிக்கிறார்கள்\n"
948
+ ]
949
+ },
950
+ {
951
+ "output_type": "stream",
952
+ "name": "stderr",
953
+ "text": [
954
+ "You are using a model of type xlm-roberta to instantiate a model of type roberta. This is not supported for all configurations of models and can yield errors.\n"
955
+ ]
956
+ }
957
+ ]
958
+ },
959
+ {
960
+ "cell_type": "code",
961
+ "source": [
962
+ "output_3"
963
+ ],
964
+ "metadata": {
965
+ "colab": {
966
+ "base_uri": "https://localhost:8080/"
967
+ },
968
+ "id": "4gCBAROLaDNK",
969
+ "outputId": "dd50be33-030c-4ea4-d2ca-5cd513eb3f0b"
970
+ },
971
+ "execution_count": 57,
972
+ "outputs": [
973
+ {
974
+ "output_type": "execute_result",
975
+ "data": {
976
+ "text/plain": [
977
+ "{'domain_label': 'muslims',\n",
978
+ " 'domain_score': 0.9295339941122466,\n",
979
+ " 'discrimination_label': 'not discriminative',\n",
980
+ " 'discrimination_score': 0.011930261738598347}"
981
+ ]
982
+ },
983
+ "metadata": {},
984
+ "execution_count": 57
985
+ }
986
+ ]
987
+ }
988
+ ]
989
+ }