# Spaces: Runtime error
import os | |
import cohere | |
import gradio as gr | |
import numpy as np | |
import pinecone | |
import torch | |
from transformers import AutoModel, AutoTokenizer | |
# Credentials come from environment variables; an unset variable yields ''.
_cohere_key = os.environ.get('COHERE_API', '')
_pinecone_key = os.environ.get('PINECONE_API', '')
_pinecone_env = os.environ.get('PINECONE_ENV', '')

# Cohere client used for the first embedding lookup in query().
co = cohere.Client(_cohere_key)

# Configure the Pinecone client once for the whole module.
pinecone.init(api_key=_pinecone_key, environment=_pinecone_env)

# Model and tokenizer are loaded from the same checkpoint name.
_checkpoint = 'monsoon-nlp/gpt-nyc'
model = AutoModel.from_pretrained(_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(_checkpoint)

# List of 3072 zeros appended to the SGPT embedding before it is sent to
# Pinecone (see query()).
# NOTE(review): presumably pads a 1024-dim SGPT vector up to a 4096-dim
# index -- confirm against the index configuration.
zos = np.zeros(shape=(4096 - 1024,)).tolist()
def list_me(matches):
    """Render search matches as HTML ``<li>`` items linking to r/AskNYC threads.

    Args:
        matches: Iterable of match mappings. Each must provide ``'id'`` (a
            string) and ``'metadata'`` containing ``'question'``; a ``'body'``
            entry in the metadata is optional.

    Returns:
        The concatenated ``<li>...</li>`` fragments as one string, or ``''``
        when ``matches`` is empty.
    """
    import html  # function-scope import keeps this block self-contained

    base_url = 'https://reddit.com/r/AskNYC/comments/'
    fragments = []
    for match in matches:
        metadata = match['metadata']

        # Strip the '/mini' marker from the link only. Running the replace
        # over the whole rendered string would also rewrite question/body
        # text that happens to contain "/mini".
        # NOTE(review): presumably ids in the "mini" namespace carry a
        # "mini" prefix -- confirm against the index contents.
        href = (base_url + match['id']).replace('/mini', '/')

        # Question and body are user-generated text, so escape them before
        # embedding them in markup.
        fragments.append('<li><a target="_blank" href="' + html.escape(href, quote=True) + '">')
        fragments.append(html.escape(metadata['question'], quote=False))
        fragments.append('</a>')
        if 'body' in metadata:
            fragments.append('<br/>' + html.escape(metadata['body'], quote=False))
        fragments.append('</li>')
    return ''.join(fragments)
def query(question):
    """Look up a question in the "gptnyc" Pinecone index using two embeddings.

    The question is embedded once with Cohere and once with the locally
    loaded model; each embedding is used for a separate top-2 query.

    Args:
        question: Text to embed and search for.

    Returns:
        An HTML string with a "Cohere" list (query without a namespace) and
        an "SGPT" list (query against the "mini" namespace).
    """
    # ---- Cohere search ----
    cohere_response = co.embed(model='large', texts=[question])
    index = pinecone.Index("gptnyc")
    cohere_hits = index.query(
        top_k=2,
        include_metadata=True,
        vector=cohere_response.embeddings[0],
    )

    # ---- SGPT search ----
    encoded = tokenizer([question], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**encoded, output_hidden_states=True, return_dict=True)
    hidden = model_output.last_hidden_state

    # Weights 1..N along dimension 1, broadcast to the shape of `hidden`, so
    # later positions contribute more to the pooled vector.
    # NOTE(review): assumes `hidden` is (batch, sequence, features) -- confirm.
    positions = torch.arange(start=1, end=hidden.shape[1] + 1)
    position_weights = positions.unsqueeze(0).unsqueeze(-1).expand(hidden.size())
    position_weights = position_weights.float().to(hidden.device)

    # Attention mask broadcast to the same shape; masked positions add nothing.
    mask = encoded["attention_mask"].unsqueeze(-1).expand(hidden.size()).float()

    # Weighted mean over dimension 1.
    weighted_sum = torch.sum(hidden * mask * position_weights, dim=1)
    weight_total = torch.sum(mask * position_weights, dim=1)
    pooled = weighted_sum / weight_total

    sgpt_hits = index.query(
        top_k=2,
        include_metadata=True,
        namespace="mini",
        # The module-level zero list is appended to the pooled vector.
        vector=pooled[0].tolist() + zos,
    )

    cohere_html = list_me(cohere_hits['matches'])
    sgpt_html = list_me(sgpt_hits['matches'])
    return f'<h3>Cohere</h3><ul>{cohere_html}</ul><h3>SGPT</h3><ul>{sgpt_html}</ul>'
# Gradio UI: one text input, with query()'s return value rendered as HTML.
iface = gr.Interface(fn=query, inputs="text", outputs="html")

# Start serving the interface.
iface.launch()