"""Gradio app answering questions by retrieval-augmented OpenAI completion.

The question is embedded through the Hugging Face Inference API, the most
similar passages of the ``vjain/biology_AP_embeddings`` dataset are retrieved
by cosine similarity, and they are sent as context to an OpenAI completion.
"""
import csv
import json
import os
import re

import gradio as gr
import numpy as np
import openai
import pandas as pd
import requests
from datasets import load_dataset
from openai.embeddings_utils import get_embedding
from openai.embeddings_utils import cosine_similarity
# Imported after the openai helper on purpose: this sklearn version (which
# takes 2-D inputs) shadows openai's and is the one used below.
from sklearn.metrics.pairwise import cosine_similarity

# Credentials come from the environment (e.g. Space secrets). The keys that
# used to be hard-coded in this file were committed in plain text and should
# be treated as compromised and revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
hf_token = os.environ.get("HF_TOKEN")

model_id = "sentence-transformers/all-MiniLM-L6-v2"
# Number of most-similar passages passed to the model as context.
TOP_K = 20


def generate_embeddings(texts, model_id, hf_token, timeout=120):
    """Embed *texts* with the Hugging Face Inference API feature-extraction pipeline.

    Args:
        texts: A string or list of strings to embed.
        model_id: Hugging Face Hub id of the embedding model.
        hf_token: Hugging Face API token; when falsy the request is sent
            without an Authorization header.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON body of the response (expected to be the embedding
        vector(s) for *texts*).

    Raises:
        requests.HTTPError: If the API answers with a non-2xx status.
    """
    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
    response = requests.post(
        api_url,
        headers=headers,
        json={"inputs": texts, "options": {"wait_for_model": True}},
        timeout=timeout,
    )
    # Without this an error payload such as {"error": ...} would be returned
    # as if it were an embedding and only fail later, far from the cause.
    response.raise_for_status()
    return response.json()


Bio_embeddings = load_dataset('vjain/biology_AP_embeddings')
df = pd.DataFrame(Bio_embeddings['train'])
# Stack all passage embeddings once so each query is scored with a single
# cosine_similarity call instead of one call per row.
# NOTE(review): assumes every row's "embedding" has the same length as the
# vectors produced by model_id — confirm how the dataset was built.
embedding_matrix = np.array(df["embedding"].tolist(), dtype=float)


def reply(input):
    """Answer a user question using the most similar dataset passages as context.

    Args:
        input: The question text typed into the Gradio textbox. (The name
            shadows the builtin but is kept for interface compatibility.)

    Returns:
        The completion text with surrounding spaces/newlines stripped.

    Side effects:
        Writes the retrieved passages and their scores to ``sorted.csv``.
    """
    question = input
    query_vector = generate_embeddings(question, model_id, hf_token)
    # One similarity score per dataset row, same values as scoring row by row.
    scores = cosine_similarity(embedding_matrix, [query_vector])[:, 0]
    # Build a new frame instead of adding a column to the shared global `df`,
    # which concurrent Gradio requests would otherwise race on.
    data = (
        df.assign(similiarities=scores)
        .sort_values("similiarities", ascending=False)
        .head(TOP_K)
    )
    data.to_csv("sorted.csv")  # debug snapshot of the retrieved passages
    context = "\n".join(data["text"])
    prompt = f"""
Answer the following question. If you don't know the answer for certain, say I don't know.
Context: {context}
Q: {question}
"""
    # NOTE(review): text-davinci-003 and the legacy Completion endpoint are
    # deprecated upstream — confirm this model is still available.
    completion = openai.Completion.create(
        prompt=prompt,
        temperature=1,
        max_tokens=500,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        model="text-davinci-003",
    )
    return completion["choices"][0]["text"].strip(" \n")


input_text = gr.inputs.Textbox(label="Enter your Biology questions here")
text_output = gr.outputs.Textbox(label="Answer")
ui = gr.Interface(
    fn=reply,
    inputs=input_text,
    outputs=[text_output],
    theme="compact",
    layout="vertical",
    inputs_layout="stacked",
    outputs_layout="stacked",
    allow_flagging=False,
)
ui.launch()