from typing import Iterable import streamlit as st import torch from langchain.embeddings import HuggingFaceEmbeddings from langchain.vectorstores import Qdrant from qdrant_client import QdrantClient from qdrant_client.http.models import Filter, FieldCondition, MatchValue from config import DB_CONFIG @st.cache_resource def load_embeddings(): model_name = "intfloat/multilingual-e5-large" model_kwargs = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"} encode_kwargs = {"normalize_embeddings": False} embeddings = HuggingFaceEmbeddings( model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs, ) return embeddings EMBEDDINGS = load_embeddings() def make_filter_obj(options: list[dict[str]]): must = [] for option in options: must.append( FieldCondition(key=option["key"], match=MatchValue(value=option["value"])) ) filter = Filter(must=must) return filter def get_similay(query: str, filter: Filter): db_url, db_api_key, db_collection_name = DB_CONFIG client = QdrantClient(url=db_url, api_key=db_api_key) db = Qdrant( client=client, collection_name=db_collection_name, embeddings=EMBEDDINGS ) docs = db.similarity_search_with_score( query, k=20, filter=filter, ) return docs def main( query: str, repo_name: str, query_options: str, ) -> Iterable[tuple[str, tuple[str, str]]]: options = [{"key": "metadata.repo_name", "value": repo_name}] if query_options == "Empty": query_options = "" query_str = f"{query_options}{query}" filter = make_filter_obj(options=options) docs = get_similay(query_str, filter) for doc, score in docs: text = doc.page_content metadata = doc.metadata # print(metadata) title = metadata.get("title") url = metadata.get("url") id_ = metadata.get("id") is_comment = metadata.get("type_") == "comment" yield title, url, id_, text, score, is_comment with st.form("my_form"): st.title("GitHub Issue Search") query = st.text_input(label="query") repo_name = st.radio( options=[ "cpython", "pyvista", "plone", "volto", "plone.restapi", "nvda", "nvdajp", "cocoa", ], label="Repo name", ) query_options = st.radio( options=[ "query: ", "query: passage: ", "Empty", ], label="Query options", ) submitted = st.form_submit_button("Submit") if submitted: st.divider() st.header("Search Results") st.divider() with st.spinner("Searching..."): results = main(query, repo_name, query_options) for title, url, id_, text, score, is_comment in results: with st.container(): if not is_comment: st.subheader(f"#{id_} - {title}") else: st.subheader(f"comment with {title}") st.write(url) st.write(text) st.write(score) # st.markdown(html, unsafe_allow_html=True) st.divider()