MARI-posa commited on
Commit
2753f31
·
1 Parent(s): 645ec55

Update stri.py

Browse files
Files changed (1) hide show
  1. stri.py +8 -5
stri.py CHANGED
@@ -17,10 +17,10 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
18
 
19
  # Загрузка датасета и аннотаций к книгам
20
- books = pd.read_csv('books_6000.csv')
21
  books.dropna(inplace=True)
22
 
23
- books = books[books['annotation'].apply(lambda x: len(x.split()) >= 10)]
24
  books.drop_duplicates(subset='title', keep='first', inplace=True)
25
  books = books.reset_index(drop=True)
26
 
@@ -39,7 +39,7 @@ for i in ['author', 'title', 'annotation']:
39
  annot = books['annotation']
40
 
41
  # Получение эмбеддингов аннотаций каждой книги в датасете
42
- max_len = 128
43
 
44
  # Определение запроса пользователя
45
  query = st.text_input("Введите запрос")
@@ -58,9 +58,11 @@ if st.button('Сгенерировать'):
58
  query_padded = torch.tensor(query_padded, dtype=torch.long)
59
  query_mask = torch.tensor(query_mask, dtype=torch.long)
60
 
61
- with torch.no_grad():
62
  query_embedding = model(query_padded.unsqueeze(0), query_mask.unsqueeze(0))
63
- query_embedding = query_embedding[0][:, 0, :]
 
 
64
 
65
  # Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
66
  cosine_similarities = torch.nn.functional.cosine_similarity(
@@ -83,4 +85,5 @@ if st.button('Сгенерировать'):
83
  response = requests.get(image_url)
84
  image = Image.open(BytesIO(response.content))
85
  cols[0].image(image)
 
86
  cols[1].write("---")
 
17
  model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
18
 
19
  # Загрузка датасета и аннотаций к книгам
20
+ books = pd.read_csv('all+.csv')
21
  books.dropna(inplace=True)
22
 
23
+ books = books[books['annotation'].apply(lambda x: len(x.split()) >= 40)]
24
  books.drop_duplicates(subset='title', keep='first', inplace=True)
25
  books = books.reset_index(drop=True)
26
 
 
39
  annot = books['annotation']
40
 
41
  # Получение эмбеддингов аннотаций каждой книги в датасете
42
+ max_len = 256
43
 
44
  # Определение запроса пользователя
45
  query = st.text_input("Введите запрос")
 
58
  query_padded = torch.tensor(query_padded, dtype=torch.long)
59
  query_mask = torch.tensor(query_mask, dtype=torch.long)
60
 
61
+ with torch.inference_mode():
62
  query_embedding = model(query_padded.unsqueeze(0), query_mask.unsqueeze(0))
63
+ query_embedding = query_embedding[0][:,0,:]
64
+ query_embedding = torch.nn.functional.normalize(query_embedding)
65
+
66
 
67
  # Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
68
  cosine_similarities = torch.nn.functional.cosine_similarity(
 
85
  response = requests.get(image_url)
86
  image = Image.open(BytesIO(response.content))
87
  cols[0].image(image)
88
+ cols[0].write(cosine_similarities[i]:.2f)
89
  cols[1].write("---")