# Page-scrape header: author monsoon-nlp, commit 0d78964 ("query sgpt"), 2.41 kB.
import html
import os

import cohere
import gradio as gr
import numpy as np
import pinecone
import torch
from transformers import AutoModel, AutoTokenizer
# API clients. Keys / environment come from env vars (empty string when unset).
co = cohere.Client(os.environ.get('COHERE_API', ''))
pinecone.init(
    api_key=os.environ.get('PINECONE_API', ''),
    environment=os.environ.get('PINECONE_ENV', ''),
)

# Model + tokenizer used for the SGPT-style embedding computed in query().
model = AutoModel.from_pretrained('monsoon-nlp/gpt-nyc')
tokenizer = AutoTokenizer.from_pretrained('monsoon-nlp/gpt-nyc')

# Vector dimension of the "gptnyc" Pinecone index. query() sends Cohere 'large'
# embeddings to this index unpadded, so presumably this equals Cohere's embedding
# size — confirm if either the index or the Cohere model changes.
INDEX_DIM = 4096

# The SGPT embedding is pooled from the model's last hidden layer, so it has
# hidden_size entries. Right-pad it with zeros so it can query the same index.
# The padding length is derived from the model config rather than hard-coding
# 1024, so a different model size either adapts or fails here at startup.
_sgpt_dim = model.config.hidden_size
if _sgpt_dim > INDEX_DIM:
    raise ValueError(
        f'SGPT hidden size {_sgpt_dim} exceeds Pinecone index dimension {INDEX_DIM}'
    )
zos = np.zeros(INDEX_DIM - _sgpt_dim).tolist()
def list_me(matches):
    """Render Pinecone query matches as HTML ``<li>`` items linking to r/AskNYC.

    Each match must provide an ``'id'`` (used as the Reddit comments-page id in
    the link) and a ``'metadata'`` mapping with a ``'question'`` and, optionally,
    a ``'body'``.

    Args:
        matches: iterable of match mappings, e.g. ``index.query(...)['matches']``.

    Returns:
        The concatenated ``<li>...</li>`` fragments ('' when there are no
        matches); the caller is expected to wrap them in a ``<ul>``.
    """
    items = []
    for match in matches:
        metadata = match['metadata']
        # Strip '/mini' from the link only. Applying it to the whole HTML would
        # also mangle question/body text that happens to contain "/mini".
        url = ('https://reddit.com/r/AskNYC/comments/' + match['id']).replace('/mini', '/')
        # Question/body are user-generated Reddit text: escape them so they
        # cannot break (or inject into) the generated markup.
        item = (
            '<li><a target="_blank" href="' + html.escape(url) + '">'
            + html.escape(metadata['question'])
            + '</a>'
        )
        if 'body' in metadata:
            item += '<br/>' + html.escape(metadata['body'])
        items.append(item + '</li>')
    return ''.join(items)
def _sgpt_embed(text):
    """Return the SGPT-style embedding of *text* as a list of floats.

    Runs the module-level ``model`` on *text* and applies position-weighted mean
    pooling over the last hidden layer: the token at position i (counting from
    1) gets weight i, and padded positions are masked out via the attention mask.
    """
    batch_tokens = tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        # Only the last layer is pooled, so don't ask the model to also return
        # every intermediate hidden state.
        last_hidden_state = model(**batch_tokens, return_dict=True).last_hidden_state
    device = last_hidden_state.device
    # Weights 1..seq_len broadcast over batch and hidden dims: later tokens count more.
    weights = (
        torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
        .unsqueeze(0)
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
        .to(device)
    )
    input_mask_expanded = (
        batch_tokens["attention_mask"]
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
        .to(device)
    )
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
    sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
    return (sum_embeddings / sum_mask)[0].tolist()


def query(question):
    """Look up similar r/AskNYC posts for *question* and return them as HTML.

    The question is embedded twice and each embedding queries the "gptnyc"
    Pinecone index for its top 2 matches:
      * Cohere 'large' embedding, default namespace;
      * SGPT embedding (zero-padded by ``zos``), "mini" namespace.

    Args:
        question: free-text question from the Gradio text box.

    Returns:
        An HTML string with a "Cohere" and an "SGPT" heading, each followed by
        a ``<ul>`` of matches; a short prompt instead when *question* is blank.
    """
    # A blank question tokenizes to zero tokens, which would make the SGPT
    # pooling divide by a zero mask sum (NaN vector) — bail out early instead.
    if not (question or '').strip():
        return '<p>Please enter a question.</p>'

    # Cohere search
    response = co.embed(
        model='large',
        texts=[question],
    )
    index = pinecone.Index("gptnyc")
    closest = index.query(
        top_k=2,
        include_metadata=True,
        vector=response.embeddings[0],
    )

    # SGPT search: pad the embedding with zeros up to the index dimension.
    closest_sgpt = index.query(
        top_k=2,
        include_metadata=True,
        namespace="mini",
        vector=_sgpt_embed(question) + zos,
    )

    return (
        '<h3>Cohere</h3><ul>' + list_me(closest['matches']) + '</ul>'
        + '<h3>SGPT</h3><ul>' + list_me(closest_sgpt['matches']) + '</ul>'
    )
# Gradio UI: one text box in, the HTML produced by query() out.
iface = gr.Interface(fn=query, inputs="text", outputs="html")
# Start the web server (runs as soon as this module executes).
iface.launch()