import io
import os
import time

import google.generativeai as palm
import gradio as gr
import numpy as np
import pandas as pd
import redis
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.embeddings import GooglePalmEmbeddings
from langchain.llms import GooglePalm
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from redis.commands.search.field import TagField, TextField, VectorField
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

load_dotenv()

# Credentials are read from the environment (.env is loaded above) instead of
# being committed to source. The password that used to be hard-coded here is
# exposed in version control and should be rotated.
redis_conn = redis.Redis(
    host=os.environ.get(
        "REDIS_HOST", "redis-15860.c322.us-east-1-2.ec2.cloud.redislabs.com"
    ),
    port=int(os.environ.get("REDIS_PORT", "15860")),
    password=os.environ["REDIS_PASSWORD"],
)

# The same model embeds the indexed keywords (build_index) and the queries
# (_search_companies); they must match for KNN distances to be meaningful.
model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')

ITEM_KEYWORD_EMBEDDING_FIELD = 'item_keyword_vector'
TEXT_EMBEDDING_DIMENSION = 768
NUMBER_COMPANIES = 1000
NUMBER_PRODUCTS = NUMBER_COMPANIES  # older name, kept for existing importers
TOP_K = 3


# ---------------------------------------------------------------------------
# Index building (offline; not executed on import)
# ---------------------------------------------------------------------------

def load_vectors(client, company_metadata, vector_dict, vector_field_name):
    """Store each company record and its embedding in Redis as a hash.

    Keys have the form ``company:<row index>:<primary_key>``.

    Args:
        client: Redis connection to write through (one non-transactional
            pipeline, executed once at the end).
        company_metadata: ``{row_index: {column: value}}``, as produced by
            ``DataFrame.to_dict(orient='index')``.
        vector_dict: Embeddings indexable by the same row index.
        vector_field_name: Hash field that receives the float32 vector bytes.

    Note:
        The row dicts inside *company_metadata* are mutated: the vector field
        is added to each one before it is written.
    """
    pipe = client.pipeline(transaction=False)
    for index, item_metadata in company_metadata.items():
        try:
            key = 'company:' + str(index) + ':' + item_metadata['primary_key']
        except (KeyError, TypeError):
            # Missing or non-string primary_key (e.g. NaN from an empty CSV
            # cell) gives no usable key; skip the row instead of aborting.
            print(f'skipping company {index}: unusable primary_key')
            continue
        item_metadata[vector_field_name] = (
            vector_dict[index].astype(np.float32).tobytes()
        )
        pipe.hset(key, mapping=item_metadata)
    pipe.execute()


def create_flat_index(redis_conn, vector_field_name, number_of_vectors,
                      vector_dimensions=TEXT_EMBEDDING_DIMENSION,
                      distance_metric='L2'):
    """Create the default RediSearch index with a FLAT float32 vector field.

    Args:
        redis_conn: Redis connection.
        vector_field_name: Name of the hash field holding the vectors.
        number_of_vectors: Used as both INITIAL_CAP and BLOCK_SIZE.
        vector_dimensions: Embedding length; must match the model's output.
        distance_metric: RediSearch metric name ('L2', 'IP' or 'COSINE').
    """
    redis_conn.ft().create_index([
        VectorField(vector_field_name, "FLAT", {
            "TYPE": "FLOAT32",
            "DIM": vector_dimensions,
            "DISTANCE_METRIC": distance_metric,
            "INITIAL_CAP": number_of_vectors,
            "BLOCK_SIZE": number_of_vectors,
        }),
        TagField("company_l_id"),
        TextField("company_name"),
        TextField("item_keywords"),
        TagField("industry"),
    ])


def build_index(csv_path='coms3.csv', number_of_companies=NUMBER_COMPANIES):
    """Embed every company's ``item_keywords`` and (re)load the Redis index.

    Not run on import; call it manually when the dataset changes.

    WARNING: this calls ``flushall()``, which deletes every key on the Redis
    server, not only the keys written by this app.
    """
    df = pd.read_csv(csv_path)
    company_metadata = df.to_dict(orient='index')
    # Keyed by row index so load_vectors can look vectors up by the same key.
    item_keywords_vectors = {
        index: model.encode(row['item_keywords'])
        for index, row in tqdm(company_metadata.items())
    }
    print(f'Loading and indexing {number_of_companies} companies')
    redis_conn.flushall()
    create_flat_index(redis_conn, ITEM_KEYWORD_EMBEDDING_FIELD,
                      number_of_companies, TEXT_EMBEDDING_DIMENSION, 'COSINE')
    load_vectors(redis_conn, company_metadata, item_keywords_vectors,
                 ITEM_KEYWORD_EMBEDDING_FIELD)


# ---------------------------------------------------------------------------
# Prompts
# ---------------------------------------------------------------------------

# Step 1: turn the user's question into search keywords.
keyword_prompt = PromptTemplate(
    input_variables=["user_question"],
    template=(
        'Create comma separated company keywords to perform a query on a '
        'company dataset for this user input: {user_question}'
    ),
)

# Step 2: present the retrieved companies to the user.
template = """You are a chatbot. Be kind, detailed and nice. Present the given queried search result in a nice way as answer to the user input. dont ask questions back!
just take the given context

{chat_history}
Human: {user_question}
Chatbot: """

prompt = PromptTemplate(
    input_variables=["chat_history", "user_question"],
    template=template,
)

chat_history = ""


# ---------------------------------------------------------------------------
# Question answering
# ---------------------------------------------------------------------------

def _extract_keywords(llm, user_question):
    """Ask *llm* for comma-separated search keywords for *user_question*."""
    keyword_chain = LLMChain(llm=llm, prompt=keyword_prompt)
    return keyword_chain.run(user_question=user_question).strip()


def _search_companies(keywords, top_k=TOP_K):
    """Return the *top_k* Redis docs whose keyword vectors are nearest *keywords*."""
    query_vector = model.encode(keywords).astype(np.float32).tobytes()
    query = (
        Query(f'*=>[KNN {top_k} @{ITEM_KEYWORD_EMBEDDING_FIELD} '
              f'$vec_param AS vector_score]')
        .sort_by('vector_score')
        .paging(0, top_k)
        .return_fields('vector_score', 'item_name', 'item_id', 'item_keywords')
        .dialect(2)
    )
    results = redis_conn.ft().search(query, query_params={"vec_param": query_vector})
    return results.docs


def _format_results(docs):
    """Render search hits as ``"<redis key> <item_keywords>"`` paragraphs."""
    # getattr: a hash without item_keywords would otherwise raise AttributeError.
    return "".join(
        f"{doc.id} {getattr(doc, 'item_keywords', '')}\n\n\n" for doc in docs
    )


def answer(user_question):
    """Answer *user_question* using company records retrieved from Redis.

    Pipeline: ask the LLM for search keywords, embed them with ``model``,
    run a KNN query against the Redis vector index, then ask the LLM to
    present the hits as a chatbot reply.

    Args:
        user_question: Free text from the Gradio textbox.

    Returns:
        The chatbot's reply.

    Raises:
        KeyError: if the ``PALM`` environment variable is not set.
    """
    if not user_question or not user_question.strip():
        return "Please enter a question."

    llm = GooglePalm(temperature=0, google_api_key=os.environ['PALM'])
    keywords = _extract_keywords(llm, user_question)
    context = _format_results(_search_companies(keywords))

    # History is not carried between calls: a module-level
    # ConversationBufferMemory would be shared by every Gradio user.
    chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=False)
    return chat_chain.predict(
        chat_history=chat_history,
        user_question=f"{context} ---\n\n {user_question}",
    )


demo = gr.Interface(
    fn=answer,
    inputs=["text"],
    outputs=["text"],
    title="Ask Sonity",
)

demo.launch(share=True)