from time import time
from datetime import datetime, date, timedelta
from typing import Iterable

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Qdrant
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue, Range
from langchain.chains import RetrievalQA
from openai.error import InvalidRequestError
from langchain.chat_models import ChatOpenAI

from config import DB_CONFIG
from model import Issue


@st.cache_resource
def load_embeddings():
    """Load the multilingual-e5-large embedding model, on GPU when available."""
    model_name = "intfloat/multilingual-e5-large"
    model_kwargs = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
    encode_kwargs = {"normalize_embeddings": False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    return embeddings


@st.cache_resource
def llm_model(model="gpt-3.5-turbo", temperature=0.2):
    llm = ChatOpenAI(model=model, temperature=temperature)
    return llm


@st.cache_resource
def load_vicuna_model():
    """Load Vicuna-13B in 8-bit precision; returns (None, None) without a GPU."""
    if torch.cuda.is_available():
        model_name = "lmsys/vicuna-13b-v1.5"
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        return tokenizer, model
    else:
        return None, None


EMBEDDINGS = load_embeddings()
LLM = llm_model()
VICUNA_TOKENIZER, VICUNA_MODEL = load_vicuna_model()


@st.cache_resource
def _get_vicuna_llm(temperature=0.2) -> HuggingFacePipeline | None:
    """Wrap the cached Vicuna model in a LangChain pipeline, if it was loaded."""
    if VICUNA_MODEL is not None:
        pipe = pipeline(
            "text-generation",
            model=VICUNA_MODEL,
            tokenizer=VICUNA_TOKENIZER,
            max_new_tokens=1024,
            temperature=temperature,
        )
        llm = HuggingFacePipeline(pipeline=pipe)
    else:
        llm = None
    return llm


VICUNA_LLM = _get_vicuna_llm()


def make_filter_obj(options: list[dict]):
    """Build a Qdrant Filter from option dicts holding a "value" or a "range"."""
    must = []
    for option in options:
        if "value" in option:
            must.append(
                FieldCondition(
                    key=option["key"], match=MatchValue(value=option["value"])
                )
            )
        elif "range" in option:
            range_ = option["range"]
            must.append(
                FieldCondition(
                    key=option["key"],
                    range=Range(
                        gt=range_.get("gt"),
                        gte=range_.get("gte"),
                        lt=range_.get("lt"),
                        lte=range_.get("lte"),
                    ),
                )
            )
    filter = Filter(must=must)
    return filter
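
# A minimal usage sketch (kept as a comment so app behavior is unchanged):
# combining an exact match with a range condition. The metadata keys mirror
# those used in _get_query_str_filter below; the epoch timestamps are
# illustrative values only.
#
#   example_filter = make_filter_obj(
#       options=[
#           {"key": "metadata.repo_name", "value": "cpython"},
#           {"key": "metadata.created_at", "range": {"gte": 1640995200, "lte": 1643673600}},
#       ]
#   )
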
def get_similar(query: str, filter: Filter):
    """Run a filtered similarity search against the Qdrant collection."""
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client, collection_name=db_collection_name, embeddings=EMBEDDINGS
    )
    docs = db.similarity_search_with_score(
        query,
        k=20,
        filter=filter,
    )
    return docs
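
# Sketch of consuming the results: similarity_search_with_score returns
# (Document, score) pairs. The query text is hypothetical, and the filter is
# built as in the sketch above.
#
#   for doc, score in get_similar("query: how to build on Windows", example_filter):
#       print(round(score, 3), doc.metadata.get("url"))
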
def get_retrieval_qa(filter: Filter, llm):
    """Build a RetrievalQA chain over the filtered Qdrant retriever."""
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client, collection_name=db_collection_name, embeddings=EMBEDDINGS
    )
    retriever = db.as_retriever(
        search_kwargs={
            "filter": filter,
        }
    )
    result = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return result
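
# Sketch of using the chain directly: with return_source_documents=True the
# result dict holds both the answer and the retrieved documents. The question
# text is hypothetical.
#
#   qa = get_retrieval_qa(example_filter, LLM)
#   result = qa("Why does the test suite fail on Windows?")
#   print(result["result"])
#   print([doc.metadata.get("url") for doc in result["source_documents"]])
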
def _get_related_url(metadata) -> Iterable[str]:
    """Yield one HTML snippet per unique source URL, with its creation date."""
    urls = set()
    for m in metadata:
        url = m["url"]
        if url in urls:
            continue
        urls.add(url)
        created_at = datetime.fromtimestamp(m["created_at"])
        yield f"<p>URL: {url} (created: {created_at:%Y-%m-%d})</p>"
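
# For a metadata entry such as {"url": "https://github.com/python/cpython/issues/1",
# "created_at": 1640995200} (illustrative values), this yields roughly:
#
#   <p>URL: https://github.com/python/cpython/issues/1 (created: 2022-01-01)</p>
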
def _get_query_str_filter(
    query: str,
    repo_name: str,
    query_options: str,
    start_date: date,
    end_date: date,
    include_comments: bool,
) -> tuple[str, Filter]:
    """Compose the final query string and the metadata Filter for a search."""
    options = [{"key": "metadata.repo_name", "value": repo_name}]
    if start_date is not None and end_date is not None:
        options.append(
            {
                "key": "metadata.created_at",
                "range": {
                    "gte": int(datetime.fromisoformat(str(start_date)).timestamp()),
                    "lte": int(
                        datetime.fromisoformat(
                            str(end_date + timedelta(days=1))
                        ).timestamp()
                    ),
                },
            }
        )
    if not include_comments:
        options.append({"key": "metadata.type_", "value": "issue"})
    filter = make_filter_obj(options=options)
    if query_options == "Empty":
        query_options = ""
    query_str = f"{query_options}{query}"
    return query_str, filter
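
# Sketch of the composed output: the "query: " prefix follows the
# multilingual-e5 convention of marking search queries, and the range is
# widened by one day so the end date is inclusive. Values are illustrative.
#
#   query_str, filter = _get_query_str_filter(
#       "build error", "cpython", "query: ", date(2023, 1, 1), date(2023, 6, 30), True
#   )
#   # query_str == "query: build error"
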
def run_qa(
    llm,
    query: str,
    repo_name: str,
    query_options: str,
    start_date: date,
    end_date: date,
    include_comments: bool,
) -> tuple[str, str]:
    """Answer a question via RetrievalQA; return (answer, related-URL HTML)."""
    now = time()
    query_str, filter = _get_query_str_filter(
        query, repo_name, query_options, start_date, end_date, include_comments
    )
    qa = get_retrieval_qa(filter, llm)
    try:
        result = qa(query_str)
    except InvalidRequestError as e:
        return "No answer was found. Please try a different question.", str(e)
    else:
        metadata = [s.metadata for s in result["source_documents"]]
        sec_html = f"<p>Execution time: {(time() - now):.2f} seconds</p>"
        html = "<div>" + sec_html + "\n".join(_get_related_url(metadata)) + "</div>"
    return result["result"], html


def run_search(
    query: str,
    repo_name: str,
    query_options: str,
    start_date: date,
    end_date: date,
    include_comments: bool,
) -> Iterable[tuple[Issue, float, str]]:
    """Yield (Issue, score, text) triples for documents similar to the query."""
    query_str, filter = _get_query_str_filter(
        query, repo_name, query_options, start_date, end_date, include_comments
    )
    docs = get_similar(query_str, filter)
    for doc, score in docs:
        text = doc.page_content
        metadata = doc.metadata
        issue = Issue(
            repo_name=repo_name,
            id=metadata.get("id"),
            title=metadata.get("title"),
            created_at=metadata.get("created_at"),
            user=metadata.get("user"),
            url=metadata.get("url"),
            labels=metadata.get("labels"),
            type_=metadata.get("type_"),
        )
        yield issue, score, text


with st.form("my_form"):
    st.title("GitHub Issue Search")
    query = st.text_input(label="query")
    repo_name = st.radio(
        options=[
            "cpython",
            "pyvista",
            "plone",
            "volto",
            "plone.restapi",
            "nvda",
            "nvdajp",
            "cocoa",
        ],
        label="Repo name",
    )
    query_options = st.radio(
        options=[
            "query: ",
            "query: passage: ",
            "Empty",
        ],
        label="Query options",
    )
    date_min = date(2022, 1, 1)
    date_max = date.today()
    date_col1, date_col2 = st.columns(2)
    start_date = date_col1.date_input(
        label="Select a start date",
        value=date_min,
        format="YYYY-MM-DD",
    )
    end_date = date_col2.date_input(
        label="Select an end date",
        value=date_max,
        format="YYYY-MM-DD",
    )
    include_comments = st.checkbox(label="Include Issue comments", value=True)

    submit_col1, submit_col2 = st.columns(2)
    searched = submit_col1.form_submit_button("Search")
    if searched:
        st.divider()
        st.header("Search Results")
        st.divider()
        with st.spinner("Searching..."):
            results = run_search(
                query, repo_name, query_options, start_date, end_date, include_comments
            )
            for issue, score, text in results:
                title = issue.title
                url = issue.url
                id_ = issue.id
                score = round(score, 3)
                created_at = datetime.fromtimestamp(issue.created_at)
                user = issue.user
                labels = issue.labels
                is_comment = issue.type_ == "comment"
                with st.container():
                    if not is_comment:
                        st.subheader(f"#{id_} - {title}")
                    else:
                        st.subheader(f"comment on {title}")
                    st.write(url)
                    st.write(text)
                    st.write(
                        "score:", score, "Date:", created_at.date(), "User:", user
                    )
                    st.write(f"{labels=}")
                    st.divider()
    qa_searched = submit_col2.form_submit_button("QA Search by OpenAI")
    if qa_searched:
        st.divider()
        st.header("QA Search Results by OpenAI GPT-3.5")
        st.divider()
        with st.spinner("QA Searching..."):
            results = run_qa(
                LLM,
                query,
                repo_name,
                query_options,
                start_date,
                end_date,
                include_comments,
            )
            answer, html = results
            with st.container():
                st.write(answer)
                st.markdown(html, unsafe_allow_html=True)
                st.divider()
    if torch.cuda.is_available():
        qa_searched_vicuna = submit_col2.form_submit_button("QA Search by Vicuna")
        if qa_searched_vicuna:
            st.divider()
            st.header("QA Search Results by Vicuna-13b-v1.5")
            st.divider()
            with st.spinner("QA Searching..."):
                results = run_qa(
                    VICUNA_LLM,
                    query,
                    repo_name,
                    query_options,
                    start_date,
                    end_date,
                    include_comments,
                )
                answer, html = results
                with st.container():
                    st.write(answer)
                    st.markdown(html, unsafe_allow_html=True)
                    st.divider()