Spaces:
Paused
Paused
sudharshan106
committed on
Commit
•
f0f2cc2
1
Parent(s):
ee73e95
rprec
Browse files- data/dataarxivfinal.csv +0 -0
- embedding_model_comp.npz +3 -0
- final.py +150 -0
- requirements.txt +15 -0
data/dataarxivfinal.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
embedding_model_comp.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc564d5cc9d3072bdf184398c81b8d37e08128a35d922cf3a86dfd42b147240b
|
3 |
+
size 25339666
|
final.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st
import cohere
import numpy as np
import pandas as pd
from qdrant_client.http import models
import qdrant_client
import easynmt

# --- Configuration -----------------------------------------------------------
# "small" selects Cohere's small embedding model; it must match the model that
# produced embedding_model_comp.npz or query/corpus vectors won't be comparable.
model_type = "small"

cohere_api_key = st.secrets["COHERE_API_KEY"]
QDRANT_URL = st.secrets["QDRANT_URL"]
QDRANT_API_KEY = st.secrets["QDRANT_API_KEY"]

collection_name = "my_collection"
distance = models.Distance.COSINE


# --- Heavy resources ---------------------------------------------------------
# Streamlit re-executes this whole script on every widget interaction. Cache
# the translation model, dataset, embeddings and API clients so each is
# created exactly once per process instead of on every rerun.

@st.cache_resource
def _load_translation_model():
    """Load the multilingual translation model (large download on first run)."""
    return easynmt.EasyNMT('m2m_100_418M')  # alternative: mbart50_en2m


@st.cache_data
def _load_dataset():
    """Load the arXiv article metadata (one title/abstract per row)."""
    return pd.read_csv('data/dataarxivfinal.csv')


@st.cache_data
def _load_embeddings():
    """Load precomputed abstract embeddings (array stored under key 'a')."""
    return np.load("embedding_model_comp.npz")['a']


@st.cache_resource
def _make_clients():
    """Create the Cohere and Qdrant API clients from the app secrets."""
    co = cohere.Client(api_key=cohere_api_key)
    qc = qdrant_client.QdrantClient(
        url=QDRANT_URL,
        api_key=QDRANT_API_KEY,
    )
    return co, qc


model_translation = _load_translation_model()
ds = _load_dataset()
print(ds.shape)  # debug: confirm the dataset loaded with the expected row count
embeddings = _load_embeddings()
cohere_client, client = _make_clients()
# Create the Qdrant collection and upload the precomputed embeddings.
# Triggered manually from the sidebar so the (slow) upload only runs on demand.
button_for_upload = st.sidebar.button('Load')
if button_for_upload:
    with st.spinner("Loading Models"):
        # recreate_collection drops any existing collection with this name,
        # so pressing Load again performs a clean, idempotent re-upload.
        client.recreate_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=embeddings.shape[1],
                distance=distance,
            ),
        )

        # Qdrant expects plain Python floats; ndarray.tolist() converts the
        # whole matrix in one C-level pass instead of per-row map(float, ...).
        vectors = embeddings.tolist()
        # Point ids are the row positions, matching the ds.iloc lookups used
        # later when rendering search results.
        ids = list(range(len(embeddings)))

        client.upload_collection(
            collection_name=collection_name,
            ids=ids,
            vectors=vectors,
            batch_size=128,
        )
# Sidebar mode switch: selects which of the four tools the main panel shows.
article_rec_type = st.sidebar.selectbox(
    "Recommend article type by",
    ("Article Name", "Article Content", "Article Translator", "Article Summarizer"),
)
def article_summarizer():
    """Summarize user-pasted text with Cohere's generation endpoint.

    Left column takes the input text; right column shows the generated
    summary once the Summarize button is pressed.
    """
    col1, col2 = st.columns(2)
    summarize_decision = st.button('Summarize')

    with col1:
        with st.expander("Input text"):
            prompt = st.text_area("Paste the sentence that needs to be Summarized")

    with col2:
        with st.expander("Summarized texts"):
            if summarize_decision:
                if not prompt.strip():
                    # Guard: the API rejects an empty prompt — tell the user
                    # instead of surfacing a raw API error.
                    st.warning("Please paste some text to summarize first.")
                    return
                response = cohere_client.generate(
                    model='xlarge',
                    prompt=prompt,
                    max_tokens=512,
                    temperature=0.6,   # mildly varied but mostly faithful output
                    k=0,               # disable top-k sampling
                    p=1,               # disable nucleus sampling
                    frequency_penalty=0,
                    presence_penalty=0,
                    stop_sequences=["--"],
                    truncate="end",    # clip over-long prompts at the end
                )
                # The first (and only requested) generation holds the summary.
                summary = response.generations[0].text
                st.write(summary)
# Display name -> ISO 639-1 code for the supported translation targets.
language_dict = {"Tamil": "ta", "Nepali": "ne", "Indonesian": "id", "Thai": "th",
                 "Spanish": "es", "Russian": "ru", "Turkish": "tr", "French": "fr"}


def article_translator():
    """Translate user-pasted text into a language chosen in the sidebar."""
    col1, col2 = st.columns(2)

    # Options come straight from language_dict (insertion order preserved),
    # so the menu and the code mapping can never drift apart.
    language = st.sidebar.selectbox(
        "Select Language",
        tuple(language_dict),
    )

    translate_decision = st.button('Translate')
    with col1:
        with st.expander("Input text"):
            text = st.text_area("Paste the sentence that needs to be Translated")

    with col2:
        with st.expander("Translated text"):
            if translate_decision:
                if not text.strip():
                    # Guard: avoid running the translation model on nothing.
                    st.warning("Please paste some text to translate first.")
                    return
                result = model_translation.translate(
                    text, target_lang=language_dict[language])
                st.write(result)
def article_name():
    """Recommend articles similar to a selected article.

    The chosen article's abstract is embedded with Cohere and used as the
    query vector for a Qdrant similarity search over the corpus.
    """
    title = st.selectbox('Article Name', options=tuple(ds['title'].values))
    top_k = st.slider("Number of recommendations", 1, 10, step=1)
    button = st.button('Predict')

    if button:
        # Use the selected article's abstract as the semantic query.
        query_to_ = ds[ds['title'] == title].head(1)['abstract'].values[0]
        query_vector = cohere_client.embed(
            [query_to_], model=model_type, truncate="RIGHT").embeddings[0]
        query_vector = list(map(float, query_vector))
        search_result = client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=top_k,
        )
        # Each hit's id is the dataframe row position assigned at upload time,
        # so it can be used directly with ds.iloc — no parallel index/score
        # lists needed.
        for hit in search_result:
            st.write(f"**{ds.iloc[hit.id]['title']}** score:{hit.score}")
def article_content():
    """Recommend articles semantically similar to free-form pasted text."""
    search_decision = st.button('Search')

    with st.expander("Input text"):
        query_to_ = st.text_area("Paste the Contents that need to be searched for")
        top_k = st.slider("Number of recommendations", 1, 10, step=1)

    if search_decision:
        if not query_to_.strip():
            # Guard: the embed endpoint rejects an empty document.
            st.warning("Please paste some text to search for first.")
            return
        query_vector = cohere_client.embed(
            [query_to_], model=model_type, truncate="RIGHT").embeddings[0]
        query_vector = list(map(float, query_vector))
        search_result = client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=top_k,
        )
        # Hit ids are dataframe row positions (assigned at upload time);
        # iterate hits directly instead of building parallel id/score lists.
        for hit in search_result:
            st.write(f"**{ds.iloc[hit.id]['title']}** score:{hit.score}")
# Route the main panel to the tool chosen in the sidebar. Anything not
# matched explicitly falls through to content-based search, mirroring the
# original if/elif/else chain.
_handlers = {
    'Article Name': article_name,
    'Article Translator': article_translator,
    "Article Summarizer": article_summarizer,
}
_handlers.get(article_rec_type, article_content)()
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cohere==3.10.0
|
2 |
+
EasyNMT==2.0.2
|
3 |
+
fasttext==0.9.2
|
4 |
+
nltk==3.8.1
|
5 |
+
numba==0.56.4
|
6 |
+
|
7 |
+
pandas==1.3.5
|
8 |
+
qdrant-client==1.0.5
|
9 |
+
regex==2022.10.31
|
10 |
+
sentencepiece==0.1.97
|
11 |
+
streamlit==1.20.0
|
12 |
+
tokenizers==0.13.2
|
13 |
+
torch==1.13.1
|
14 |
+
tqdm==4.65.0
|
15 |
+
transformers==4.27.1
|