# protein-retrieval-base/models/biomed_model.py
# Uploaded by lindsay-qu ("Upload 86 files", commit e0f406c).
from .base_model import BaseModel
import openai
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
class BiomedModel(BaseModel):
    """Model pairing an OpenAI chat model with a SentenceTransformer embedder.

    Text generation goes through the OpenAI chat-completion endpoint.
    Embeddings come from a SentenceTransformer model loaded at construction
    time (default checkpoint: "pritamdeka/S-PubMedBert-MS-MARCO").
    """

    def __init__(self,
                 generation_model="gpt-4",
                 embedding_model="pritamdeka/S-PubMedBert-MS-MARCO",
                 temperature=0,
                 ) -> None:
        """Configure the generation backend and load the embedding model.

        Args:
            generation_model: OpenAI chat model name, passed as ``model`` to
                every chat-completion request.
            embedding_model: Model name or path handed to
                ``SentenceTransformer``; the model is loaded immediately.
            temperature: Sampling temperature forwarded to every chat request.
        """
        self.generation_model = generation_model
        # Loaded eagerly. This may download weights on first use, which makes
        # it the expensive part of construction.
        self.embedding_model = SentenceTransformer(embedding_model)
        self.temperature = temperature

    def respond(self, messages: list) -> str:
        """Return the assistant reply for a chat conversation.

        Args:
            messages: Chat history in OpenAI chat format, i.e. a list of
                ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The ``content`` of the first returned choice's message.

        Note:
            Uses the legacy ``openai.ChatCompletion`` interface, which exists
            only in the ``openai`` package before version 1.0.
        """
        completion = openai.ChatCompletion.create(
            messages=messages,
            model=self.generation_model,
            temperature=self.temperature,
        )
        return completion.choices[0]['message']['content']

    def embedding(self, texts: list) -> list:
        """Embed ``texts`` with the SentenceTransformer model.

        Args:
            texts: Strings to embed.

        Returns:
            If ``texts`` has exactly one element, that element's embedding is
            returned as a flat list of floats (NOT wrapped in an outer list).
            Otherwise, a list holding one embedding (a list of floats) per
            input string.
        """
        if len(texts) == 1:
            # Encoding the bare string makes encode() return a single 1-D
            # vector. Callers may rely on this unwrapped shape, so keep it.
            return self.embedding_model.encode(texts[0]).tolist()
        # Batch path: show a progress bar, since large inputs can be slow.
        vectors = self.embedding_model.encode(texts, show_progress_bar=True)
        return [vector.tolist() for vector in vectors]