A-Roucher committed
Commit 0bff0fd
1 Parent(s): 63db6ac

feat: add requirements

Files changed (2)
  1. app.py +5 -4
  2. requirements.txt +3 -0
app.py CHANGED
@@ -7,14 +7,15 @@ st.write(x, 'squared is', x * x)
 
 st.sidebar.text_input("Type your quote here")
 
-dataset = datasets.load_dataset('A-Roucher/english_historical_quotes')['train']
+dataset = datasets.load_dataset('A-Roucher/english_historical_quotes', download_mode="force_redownload")
+dataset = dataset['train']
 
 model_name = "sentence-transformers/all-MiniLM-L6-v2" # BAAI/bge-small-en-v1.5" # "Cohere/Cohere-embed-english-light-v3.0" # "sentence-transformers/all-MiniLM-L6-v2"
 
 encoder = SentenceTransformer(model_name)
 embeddings = encoder.encode(
     dataset["quote"],
-    batch_size=8,
+    batch_size=4,
     show_progress_bar=True,
     convert_to_numpy=True,
     normalize_embeddings=True,
@@ -38,9 +39,9 @@ sentence_embedding = encoder.encode([sentence])
 from sentence_transformers.util import semantic_search
 
 # hits = semantic_search(sentence_embedding, dataset_embeddings[:, :], top_k=5)
-author_indexes = range(1000)
+author_indexes = list(range(1000))
 hits = semantic_search(sentence_embedding, dataset_embeddings[author_indexes, :], top_k=5)
-
+st.write(hits)
 list_hits = [author_indexes[i['corpus_id']] for i in hits[0]]
 st.write(dataset_embeddings.select([12676, 4967, 2612, 8884, 4797]))
 
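For context, a minimal self-contained sketch of the flow this diff touches, assembled from the lines above; the query string and the final print are illustrative, not taken from app.py:

# Sketch assembled from the diff above; assumes the libraries pinned in
# requirements.txt. Query text and the print at the end are illustrative.
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import semantic_search

# download_mode="force_redownload" bypasses a stale cache, as in new line 10.
dataset = load_dataset('A-Roucher/english_historical_quotes', download_mode="force_redownload")
dataset = dataset['train']

encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Embed every quote; halving batch_size (8 -> 4) lowers the peak memory per batch.
dataset_embeddings = encoder.encode(
    dataset["quote"],
    batch_size=4,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,
)

# Embed one query and search only the first 1000 rows of the corpus.
sentence_embedding = encoder.encode(["History is written by the victors"])
author_indexes = list(range(1000))  # a concrete list indexes both the numpy array and itself below
hits = semantic_search(sentence_embedding, dataset_embeddings[author_indexes, :], top_k=5)

# corpus_id is a position within the searched slice; map it back to a dataset row.
list_hits = [author_indexes[hit['corpus_id']] for hit in hits[0]]
print(dataset.select(list_hits)["quote"])

semantic_search returns, per query, a list of {'corpus_id', 'score'} dicts sorted by cosine similarity; since only a slice of the embeddings is searched here, each corpus_id must be mapped back through author_indexes to a dataset row, which is what the list_hits line in the diff does.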
requirements.txt ADDED
@@ -0,0 +1,3 @@
+datasets==2.5.2
+sentence_transformers==2.2.2
+streamlit==1.28.1
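On a Streamlit Space this file is installed automatically at build time; locally the same environment comes from pip install -r requirements.txt. A quick sanity check that an environment matches the pins above:

# Confirm installed versions against the pins in requirements.txt.
import datasets, sentence_transformers, streamlit
print(datasets.__version__)               # expected: 2.5.2
print(sentence_transformers.__version__)  # expected: 2.2.2
print(streamlit.__version__)              # expected: 1.28.1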