File size: 1,522 Bytes
ef048c4
28b86d4
541438b
9fcc182
08f492d
bcf5ebf
5162987
086dfb1
55bb617
9314300
 
086dfb1
 
 
2e9749b
 
086dfb1
55bb617
086dfb1
4baf634
3d3f363
086dfb1
b10dd8a
08412aa
3bd360f
b10dd8a
f587a83
 
9dc457e
529b072
3fef0e6
2a57255
 
 
 
 
 
 
 
 
086dfb1
8a3c87a
3fef0e6
 
c58105e
e0765b5
 
d68574d
e0765b5
 
9473952
4af53c3
50fae99
6812814
 
a9b3ec0
7cdebdb
b7b8dac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import streamlit as st
import os
import requests
import pandas as pd
import boto3
import AWSHandler
import pinecone
from sentence_transformers import SentenceTransformer

# Credentials and service settings come from environment variables; a variable
# that is not set yields None.
aws_access_key = os.environ.get("aws_access_key")
aws_secret_key = os.environ.get("aws_secret_key")
pinecone_api_key = os.environ.get("pinecone_api_key")
pinecone_environment = os.environ.get("pinecone_environment")

# S3 client built from the AWS key pair read above.
s3 = boto3.client(
    's3',
    aws_access_key_id=aws_access_key,
    aws_secret_access_key=aws_secret_key,
)

# Sentence-embedding model that search_index() uses to encode query text.
model = SentenceTransformer('all-mpnet-base-v2')

# Lay the page out across the full browser width.
st.set_page_config(layout="wide")



def search_index(query):
    """Return the ten closest matches for ``query`` from the "scotus" index.

    The query text is embedded with the module-level ``model`` and sent to
    the Pinecone index. Metadata is requested for each match; the stored
    vector values are not.

    Args:
        query: Free-text search string.

    Returns:
        The ``'matches'`` entry of the query response converted with
        ``to_dict()``.
    """
    pinecone.init(api_key=pinecone_api_key, environment=pinecone_environment)
    index = pinecone.Index("scotus")
    # NOTE(review): encoding a one-element list and calling .tolist() looks
    # like it produces a nested list ([[...]]) rather than a flat vector --
    # confirm this is the shape index.query expects for ``vector``.
    vector = model.encode([query]).tolist()
    # Run the query exactly once and reuse the response; issuing the same
    # request again just to log it doubles the work for every search.
    response = index.query(
        vector=vector,
        top_k=10,
        include_values=False,
        include_metadata=True,
    )
    return response.to_dict()['matches']
    

# CSS override for lists inside Streamlit markdown containers: sets
# list-style-position to "inside" so bullets sit within the content box.
_LIST_STYLE_CSS = '''
<style>
[data-testid="stMarkdownContainer"] ul{
    list-style-position: inside;
}
</style>
'''
st.markdown(_LIST_STYLE_CSS, unsafe_allow_html=True)


# Search box: each match is rendered as one markdown bullet, followed by the
# case name in italics when the match carries one.
search = st.text_input('Search', placeholder='First Amendment cannot protect libel')
# Skip blank or whitespace-only input so no embedding/index query is run for it.
if search.strip():
    for match in search_index(query=search):
        metadata = match['metadata']
        bullet = "- " + metadata['Text']
        name = metadata.get('Name', "")
        # Only append the parenthetical when a name exists; with an empty name
        # the markup would render as a stray literal "(__)".
        if name:
            bullet += ' (_' + name + '_)'
        st.markdown(bullet)