# sonitycom/demo.py — uploaded by rairo, commit 4001eb5 ("Update demo.py")
import numpy as np
import pandas as pd
import time
from sentence_transformers import SentenceTransformer
from redis.commands.search.field import VectorField
from redis.commands.search.field import TextField
from redis.commands.search.field import TagField
from redis.commands.search.query import Query
import redis
from tqdm import tqdm
import google.generativeai as palm
import pandas as pd
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import os
import gradio as gr
import io
from langchain.llms import GooglePalm
import pandas as pd
#from yolopandas import pd
from langchain.embeddings import GooglePalmEmbeddings
from langchain.memory import ConversationBufferMemory
from dotenv import load_dotenv
# Pull environment variables (e.g. the PALM API key used in answer()) from a
# local .env file before anything reads os.environ.
load_dotenv()

# Connection to the hosted Redis instance that serves as the vector store.
# SECURITY NOTE(review): these credentials were hard-coded in source; they now
# fall back to the original literals (so existing deployments keep working)
# but can be overridden via REDIS_HOST / REDIS_PORT / REDIS_PASSWORD.
# The committed password should be rotated.
redis_conn = redis.Redis(
    host=os.environ.get('REDIS_HOST', 'redis-15860.c322.us-east-1-2.ec2.cloud.redislabs.com'),
    port=int(os.environ.get('REDIS_PORT', '15860')),
    password=os.environ.get('REDIS_PASSWORD', 'PVnvSZI5nISPsrxxhCHZF3pfZWI7YAIG'))
# ---------------------------------------------------------------------------
# One-time indexing script, DISABLED by wrapping it in a module-level string
# literal (the string is evaluated and discarded at import time). It loads
# company metadata from coms3.csv, embeds the 'item_keywords' column with the
# same SentenceTransformer model used at query time, and bulk-loads vectors
# plus metadata into a RediSearch FLAT vector index via HSET pipelines.
# Re-enable only when rebuilding the index: note it calls flushall() on the
# shared Redis instance, wiping everything stored there.
# NOTE(review): indentation inside the string looks flattened by a paste —
# restore it before re-enabling.
# ---------------------------------------------------------------------------
'''
df = pd.read_csv("coms3.csv")
print(list(df))
print(df['item_keywords'].sample(2))
company_metadata = df.to_dict(orient='index')
model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
item_keywords = [company_metadata[i]['item_keywords'] for i in company_metadata.keys()]
item_keywords_vectors = []
for sentence in tqdm(item_keywords):
s = model.encode(sentence)
item_keywords_vectors.append(s)
print(company_metadata[0])
def load_vectors(client, company_metadata, vector_dict, vector_field_name):
p = client.pipeline(transaction=False)
for index in company_metadata.keys():
#hash key
#print(index)
#print(company_metadata[index]['company_l_id'])
try:
key=str('company:'+ str(index)+ ':' + company_metadata[index]['primary_key'])
except:
print(key)
continue
#hash values
item_metadata = company_metadata[index]
item_keywords_vector = vector_dict[index].astype(np.float32).tobytes()
item_metadata[vector_field_name]=item_keywords_vector
# HSET
p.hset(key,mapping=item_metadata)
p.execute()
def create_flat_index (redis_conn,vector_field_name,number_of_vectors, vector_dimensions=512, distance_metric='L2'):
redis_conn.ft().create_index([
VectorField(vector_field_name, "FLAT", {"TYPE": "FLOAT32", "DIM": vector_dimensions, "DISTANCE_METRIC": distance_metric, "INITIAL_CAP": number_of_vectors, "BLOCK_SIZE":number_of_vectors }),
TagField("company_l_id"),
TextField("company_name"),
TextField("item_keywords"),
TagField("industry")
])
ITEM_KEYWORD_EMBEDDING_FIELD='item_keyword_vector'
TEXT_EMBEDDING_DIMENSION=768
NUMBER_COMPANIES=1000
print ('Loading and Indexing + ' + str(NUMBER_COMPANIES) + 'companies')
#flush all data
redis_conn.flushall()
#create flat index & load vectors
create_flat_index(redis_conn, ITEM_KEYWORD_EMBEDDING_FIELD,NUMBER_COMPANIES,TEXT_EMBEDDING_DIMENSION,'COSINE')
load_vectors(redis_conn,company_metadata,item_keywords_vectors,ITEM_KEYWORD_EMBEDDING_FIELD)
'''
# Query-time embedding model; must match the model used when the index was
# built (see the disabled indexing block above).
model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
# Name of the Redis hash field holding each company's keyword embedding.
ITEM_KEYWORD_EMBEDDING_FIELD='item_keyword_vector'
# Dimensionality of all-distilroberta-v1 sentence embeddings.
TEXT_EMBEDDING_DIMENSION=768
# NOTE(review): defined but never referenced below — presumably the intended
# index size (the disabled block uses NUMBER_COMPANIES=1000); confirm and
# remove if truly unused.
NUMBER_PRODUCTS=1000
# Prompt for the keyword-extraction step.
# BUG FIX: the original bound this template to ``prompt`` and immediately
# overwrote it with the chat prompt below, leaving it dead code; it also
# declared input_variables=["company_description"] while the template string
# contained no placeholder at all, so the user's input could never be
# interpolated. It now has its own name and a real {user_question} slot.
keyword_prompt = PromptTemplate(
    input_variables=["user_question"],
    template=('Create comma seperated company keywords to perform a query '
              'on a company dataset for this user input: {user_question}')
)

# Chat prompt: presents the retrieved search context back to the user.
# {chat_history} is filled by ConversationBufferMemory in answer().
template = """You are a chatbot. Be kind, detailed and nice. Present the given queried search result in a nice way as answer to the user input. dont ask questions back! just take the given context
{chat_history}
Human: {user_question}
Chatbot:
"""
prompt = PromptTemplate(
    input_variables=["chat_history", "user_question"],
    template=template
)

# Transcript passed to the chains; never appended to anywhere in this file,
# so it stays empty across calls.
chat_history= ""
def answer(user_input):
    """Answer a user question via vector search over the company index.

    Pipeline: extract search keywords from the question with the LLM, embed
    them, KNN-search the Redis index, then ask the LLM to present the
    retrieved context as a friendly reply.

    Parameters: user_input (str) -- raw question from the Gradio textbox.
    Returns: str -- the chatbot's answer.
    Raises: KeyError if the PALM environment variable is unset.
    """
    llm = GooglePalm(temperature=0, google_api_key=os.environ['PALM'])

    # BUG FIX: the original built this chain on the module-level chatbot
    # ``prompt``, so the "create comma separated keywords" instruction was
    # never used at all. Use a dedicated prompt that actually interpolates
    # the user's input.
    keyword_prompt = PromptTemplate(
        input_variables=["user_question"],
        template=('Create comma seperated company keywords to perform a '
                  'query on a company dataset for this user input: '
                  '{user_question}')
    )
    keyword_chain = LLMChain(llm=llm, prompt=keyword_prompt)
    keywords = keyword_chain.run({'user_question': user_input})

    # Embed the extracted keywords and KNN-search the index.
    top_k = 3
    query_vector = model.encode(keywords).astype(np.float32).tobytes()
    # NOTE(review): 'item_name'/'item_id' are not fields of the index created
    # by the disabled indexing block (it defines company_name/company_l_id);
    # they come back empty — confirm the intended field names.
    q = (Query(f'*=>[KNN {top_k} @{ITEM_KEYWORD_EMBEDDING_FIELD} $vec_param AS vector_score]')
         .sort_by('vector_score')
         .paging(0, top_k)
         .return_fields('vector_score', 'item_name', 'item_id', 'item_keywords')
         .dialect(2))
    results = redis_conn.ft().search(q, query_params={"vec_param": query_vector})

    # Concatenate retrieved docs into one context string; tolerate documents
    # missing the item_keywords field instead of raising AttributeError.
    full_result_string = ''.join(
        doc.id + ' ' + getattr(doc, 'item_keywords', '') + "\n\n\n"
        for doc in results.docs
    )

    # NOTE(review): memory is recreated on every call, so the bot never sees
    # earlier turns; kept as-is to preserve existing behavior.
    memory = ConversationBufferMemory(memory_key="chat_history")
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        verbose=False,
        memory=memory,
    )
    print(full_result_string)
    return llm_chain.predict(user_question=f"{full_result_string} ---\n\n {user_input}")
# Minimal Gradio UI: one text box in, one text box out, backed by answer().
demo = gr.Interface(
    fn=answer,
    inputs=["text"],
    outputs=["text"],
    title="Ask Sonity",
)
# share=True publishes a temporary public gradio.live URL in addition to the
# local server.
demo.launch(share=True)