Teery committed on
Commit
c1efda8
·
1 Parent(s): 653fd0e
Files changed (3) hide show
  1. app.py +45 -0
  2. movies_2.csv +0 -0
  3. requirements.txt +71 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sklearn.metrics.pairwise import pairwise_distances, cosine_similarity
3
+ from scipy.spatial import distance
4
+ import pandas as pd
5
+ import numpy as np
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModel
8
+
9
@st.cache_resource
def _load_encoder(checkpoint: str = "cointegrated/rubert-tiny2"):
    """Load the tokenizer and encoder for ``checkpoint`` once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects alive between reruns so
    the weights are not re-read each time.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModel.from_pretrained(checkpoint),
    )


tokenizer, model = _load_encoder()

# NOTE(review): the commit lists movies_2.csv at the repository root, while
# this path expects it under Films_finder/ -- confirm the working directory.
films = pd.read_csv('Films_finder/movies_2.csv')
# Force string dtype so missing descriptions do not break the later
# ``.replace`` calls made on each description.
films['description'] = films['description'].astype(str)
14
+
15
def embed_bert_cls(text, model, tokenizer, max_length: int = 1024) -> np.ndarray:
    """Return the L2-normalised first-token embedding of ``text``.

    Args:
        text: Input passed straight to ``tokenizer``. If a batch is given,
            only the vector of its first item is returned.
        model: Encoder exposing ``.device`` and returning an object with a
            ``last_hidden_state`` tensor when called with the tokenizer output.
        tokenizer: Callable producing a mapping of PyTorch tensors.
        max_length: Truncation limit forwarded to the tokenizer. Defaults to
            1024, the value previously hard-coded here.

    Returns:
        A 1-D NumPy array holding the normalised embedding.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=max_length,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        model_output = model(**{k: v.to(model.device) for k, v in encoded.items()})
    # Position 0 along the sequence axis is taken as the sentence vector.
    cls_vectors = model_output.last_hidden_state[:, 0, :]
    cls_vectors = torch.nn.functional.normalize(cls_vectors)
    return cls_vectors[0].cpu().numpy()
22
@st.cache_resource
def for_embeded_list(series: pd.Series) -> np.ndarray:
    """Embed every text in ``series`` and stack the vectors into one array.

    Non-breaking spaces (``\\xa0``) are replaced with regular spaces before
    encoding. The result is memoised by ``st.cache_resource``, so the
    embeddings are not recomputed on each Streamlit rerun.

    Args:
        series: Texts to embed; each element must be a ``str``.

    Returns:
        A NumPy array with one embedding per element of ``series``, in the
        same order (the annotation previously claimed ``list``).
    """
    return np.array([
        embed_bert_cls(description.replace('\xa0', ' '), model, tokenizer)
        for description in series
    ])
25
# Embeddings of all film descriptions (cached by for_embeded_list).
embeded_list = for_embeded_list(films['description'])
text = st.text_input('Введите текст')
count_visible = st.number_input("Введите количество отображаемых элементов", 1, 10, step=1)
if text and count_visible:
    embeded_text = embed_bert_cls(text, model, tokenizer).reshape(1, -1)
    # pairwise_distances is called with its default metric, so despite the
    # variable name these are distances: smaller means a closer match.
    cossim = pairwise_distances(embeded_text, embeded_list)[0]
    # Sort once and reuse; previously argsort ran again for every field shown.
    order = cossim.argsort()
    ranked = films.iloc[order]
    # Never ask for more rows than the table holds.
    for i in range(min(count_visible, len(ranked))):
        row = ranked.iloc[i]
        col1, col2 = st.columns(2)
        with col1:
            # Columns are addressed by position; presumably 1 = image,
            # 2 = title, 3 = description -- TODO confirm against the CSV.
            st.header(row.iloc[2])
            st.write(row.iloc[3].replace('\xa0', ' '))
            # Report the score of the film actually displayed: cossim is in
            # CSV order, so it must be indexed through the sort order.
            st.write(f'Уверенность состовляет {cossim[order[i]]}')
        with col2:
            st.image(row.iloc[1])
    # The last row of the ranking is the film farthest from the query.
    st.header('Самый не подходящий запрос')
    worst = ranked.iloc[-1]
    col3, col4 = st.columns(2)
    with col3:
        st.header(worst.iloc[2])
        st.write(worst.iloc[3].replace('\xa0', ' '))
    with col4:
        st.image(worst.iloc[1])
movies_2.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.5
2
+ aiosignal==1.3.1
3
+ altair==5.1.1
4
+ async-timeout==4.0.3
5
+ attrs==23.1.0
6
+ blinker==1.6.2
7
+ cachetools==5.3.1
8
+ certifi==2023.7.22
9
+ charset-normalizer==3.2.0
10
+ click==8.1.7
11
+ datasets==2.14.5
12
+ dill==0.3.7
13
+ filelock==3.12.4
14
+ frozenlist==1.4.0
15
+ fsspec==2023.6.0
16
+ gitdb==4.0.10
17
+ GitPython==3.1.37
18
+ huggingface-hub==0.17.3
19
+ idna==3.4
20
+ importlib-metadata==6.8.0
21
+ Jinja2==3.1.2
22
+ joblib==1.3.2
23
+ jsonschema==4.19.1
24
+ jsonschema-specifications==2023.7.1
25
+ markdown-it-py==3.0.0
26
+ MarkupSafe==2.1.3
27
+ mdurl==0.1.2
28
+ mpmath==1.3.0
29
+ multidict==6.0.4
30
+ multiprocess==0.70.15
31
+ networkx==3.1
32
+ nltk==3.8.1
33
+ numpy==1.26.0
34
+ packaging==23.1
35
+ pandas==2.1.1
36
+ Pillow==9.5.0
37
+ protobuf==4.24.3
38
+ pyarrow==13.0.0
39
+ pydeck==0.8.1b0
40
+ Pygments==2.16.1
41
+ python-dateutil==2.8.2
42
+ pytz==2023.3.post1
43
+ PyYAML==6.0.1
44
+ referencing==0.30.2
45
+ regex==2023.8.8
46
+ requests==2.31.0
47
+ rich==13.5.3
48
+ rpds-py==0.10.3
49
+ scikit-learn==1.3.1
50
+ scipy==1.11.3
51
+ six==1.16.0
52
+ smmap==5.0.1
53
+ streamlit==1.27.0
54
+ sympy==1.12
55
+ tenacity==8.2.3
56
+ threadpoolctl==3.2.0
57
+ tokenizers==0.13.3
58
+ toml==0.10.2
59
+ toolz==0.12.0
60
+ torch==2.0.1
61
+ tornado==6.3.3
62
+ tqdm==4.66.1
63
+ transformers==4.28.0
64
+ typing_extensions==4.8.0
65
+ tzdata==2023.3
66
+ tzlocal==5.0.1
67
+ urllib3==2.0.5
68
+ validators==0.22.0
69
+ xxhash==3.3.0
70
+ yarl==1.9.2
71
+ zipp==3.17.0