sudharshan106 commited on
Commit
f0f2cc2
1 Parent(s): ee73e95
data/dataarxivfinal.csv ADDED
The diff for this file is too large to render. See raw diff
 
embedding_model_comp.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc564d5cc9d3072bdf184398c81b8d37e08128a35d922cf3a86dfd42b147240b
3
+ size 25339666
final.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cohere
3
+ import numpy as np
4
+ import pandas as pd
5
+ from qdrant_client.http import models
6
+ # import warnings
7
+ # warnings.filterwarnings('ignore')
8
+ import qdrant_client
9
+ import easynmt
10
+ # from config import CONFIG
11
+
12
+ model_translation = easynmt.EasyNMT('m2m_100_418M')# mbart50_en2m
13
+
14
+ model_type = "small"
15
+
16
+ cohere_api_key = st.secrets["COHERE_API_KEY"]
17
+ QDRANT_URL = st.secrets["QDRANT_URL"]
18
+ QDRANT_API_KEY = st.secrets["QDRANT_API_KEY"]
19
+
20
+ ds = pd.read_csv('data/dataarxivfinal.csv')
21
+ print(ds.shape)
22
+ cohere_client = cohere.Client(api_key=cohere_api_key)
23
+ embeddings = np.load("embedding_model_comp.npz")['a']
24
+ collection_name = "my_collection"
25
+ distance = models.Distance.COSINE
26
+
27
+ client = qdrant_client.QdrantClient(
28
+ url= QDRANT_URL,
29
+ api_key=QDRANT_API_KEY,
30
+ )
31
+
32
+ # Create Qdrant collection and upload the Embeddings
33
+ button_for_upload = st.sidebar.button('Load')
34
+ if button_for_upload:
35
+
36
+ with st.spinner("Loading Models"):
37
+ collection_id = client.recreate_collection(collection_name = collection_name,
38
+ vectors_config= models.VectorParams(size=embeddings.shape[1], distance=distance))
39
+
40
+
41
+ vectors=[list(map(float, vector)) for vector in embeddings]
42
+
43
+ ids = []
44
+ for i, j in enumerate(embeddings):
45
+ ids.append(i)
46
+
47
+ client.upload_collection(
48
+ collection_name=collection_name,
49
+ ids=ids,
50
+ vectors=vectors,
51
+ batch_size=128
52
+ )
53
+
54
+ article_rec_type = st.sidebar.selectbox(
55
+ "Recommend article type by",
56
+ ( "Article Name", "Article Content", "Article Translator", "Article Summarizer")
57
+ )
58
+
59
+ def article_summarizer():
60
+ col1, col2 = st.columns(2)
61
+ summarize_decision = st.button('Summarize')
62
+
63
+ with col1:
64
+ with st.expander("Input text"):
65
+ prompt = st.text_area("Paste the sentence that needs to be Summarized")
66
+
67
+ with col2:
68
+ with st.expander("Summarized texts"):
69
+ if summarize_decision:
70
+ response = cohere_client.generate(
71
+ model='xlarge',
72
+ prompt = prompt,
73
+ max_tokens=512,
74
+ temperature=0.6,
75
+ k=0,
76
+ p=1,
77
+ frequency_penalty=0,
78
+ presence_penalty=0,
79
+ stop_sequences=["--"],truncate="end"
80
+ )
81
+
82
+ summary = response.generations[0].text
83
+ st.write(summary)
84
+
85
+ language_dict = {"Tamil":"ta", "Nepali":"ne", "Indonesian":"id", "Thai":"th","Spanish":"es", "Russian":"ru", "Turkish":"tr", "French":"fr"}
86
+ def article_translator():
87
+ col1, col2 = st.columns(2)
88
+
89
+ language = st.sidebar.selectbox(
90
+ "Select Language",
91
+ ( "Tamil", "Nepali", "Indonesian", "Thai","Spanish", "Russian", "Turkish", "French")
92
+ )
93
+
94
+ translate_decision = st.button('Translate')
95
+ with col1:
96
+ with st.expander("Input text"):
97
+ text = st.text_area("Paste the sentence that needs to be Translated")
98
+
99
+ with col2:
100
+ with st.expander("Translated text"):
101
+ if translate_decision:
102
+ result = model_translation.translate(text, target_lang=language_dict[language])
103
+ st.write(result)
104
+
105
+
106
+ def article_name():
107
+ title = st.selectbox('Article Name', options=tuple(ds['title'].values))
108
+ top_k = st.slider("Number of recommendations", 1, 10, step=1)
109
+ button = st.button('Predict')
110
+
111
+ if button:
112
+
113
+ query_to_ = ds[ds['title']==title].head(1)['abstract'].values[0]
114
+ query_vector = cohere_client.embed([query_to_], model=model_type, truncate="RIGHT").embeddings[0]
115
+ query_vector = list(map(float, query_vector))
116
+ search_result = client.search(collection_name=collection_name, query_vector=query_vector,limit=top_k)
117
+ similar_text_indices = [hit.id for hit in search_result]
118
+
119
+ score_ = [record.score for record in search_result]
120
+
121
+ for j,i in enumerate(ds.iloc[similar_text_indices].iterrows()):
122
+ st.write(f"**{i[1]['title']}** score:{score_[j]}")
123
+
124
+ def article_content():
125
+ search_decision = st.button('Search')
126
+
127
+ with st.expander("Input text"):
128
+ query_to_ = st.text_area("Paste the Contents that need to be searched for")
129
+ top_k = st.slider("Number of recommendations", 1, 10, step=1)
130
+
131
+ if search_decision:
132
+ query_vector = cohere_client.embed([query_to_], model=model_type, truncate="RIGHT").embeddings[0]
133
+ query_vector = list(map(float, query_vector))
134
+ search_result = client.search(collection_name=collection_name, query_vector=query_vector,limit=top_k)
135
+ similar_text_indices = [hit.id for hit in search_result]
136
+
137
+ score_ = [record.score for record in search_result]
138
+
139
+ for j,i in enumerate(ds.iloc[similar_text_indices].iterrows()):
140
+ st.write(f"**{i[1]['title']}** score:{score_[j]}")
141
+
142
+
143
+ if article_rec_type=='Article Name':
144
+ article_name()
145
+ elif article_rec_type == 'Article Translator':
146
+ article_translator()
147
+ elif article_rec_type == "Article Summarizer":
148
+ article_summarizer()
149
+ else:
150
+ article_content()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cohere==3.10.0
2
+ EasyNMT==2.0.2
3
+ fasttext==0.9.2
4
+ nltk==3.8.1
5
+ numba==0.56.4
6
+
7
+ pandas==1.3.5
8
+ qdrant-client==1.0.5
9
+ regex==2022.10.31
10
+ sentencepiece==0.1.97
11
+ streamlit==1.20.0
12
+ tokenizers==0.13.2
13
+ torch==1.13.1
14
+ tqdm==4.65.0
15
+ transformers==4.27.1