medicare-fastapi / load_meta_data.py
Mohammed-Altaf's picture
added writable cache working directory
5189bbe
raw
history blame
No virus
1.84 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pydantic import BaseModel
from fastapi import HTTPException
class UserQuery(BaseModel):
    """Pydantic model with a single required string field, ``user_query``."""

    # The query text; declared as a required ``str`` field.
    user_query: str
class ChatBot:
    """Hold a causal language model and its tokenizer and produce cleaned replies.

    Both attributes start as ``None``; call :meth:`load_from_hub` before
    :meth:`get_response`.
    """

    # Markers looked for at the start of each line of decoded model output.
    _HUMAN_TAG = "[|Human|]"
    _AI_TAG = "[|AI|]"

    def __init__(self):
        self.tokenizer = None
        self.model = None

    def load_from_hub(self, model_id: str):
        """Load the tokenizer and the causal LM identified by ``model_id``.

        Args:
            model_id: Identifier passed unchanged to both ``from_pretrained`` calls.
        """
        # ``device_map`` is a model-loading option; the tokenizer has no
        # weights to place, so it is only passed to the model.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id,
                                                          device_map='auto')

    def get_response(self, text: "UserQuery | str", *,
                     max_new_tokens: int = 100) -> str:
        """Generate a reply for ``text`` and return it after cleaning.

        Args:
            text: Either a plain prompt string or an object exposing the
                prompt as a ``user_query`` attribute (e.g. ``UserQuery``).
            max_new_tokens: Forwarded to ``model.generate``; defaults to the
                previously hard-coded value of 100.

        Returns:
            The decoded output passed through :meth:`get_clean_response`.

        Raises:
            HTTPException: status 400 when the model or tokenizer is not loaded.
        """
        # Compare against None explicitly: truthiness of a tokenizer would
        # call its ``__len__`` and is not what "loaded" means here.
        if self.model is None or self.tokenizer is None:
            raise HTTPException(status_code=400, detail="Model is not loaded")

        # Accept both the request model and a bare string.
        prompt = getattr(text, "user_query", text)

        inputs = self.tokenizer(prompt, return_tensors='pt')
        # The model is loaded with device_map='auto', so it may not live on
        # the CPU; the input tensors must be on the same device.
        inputs = inputs.to(self.model.device)

        outputs = self.model.generate(**inputs,
                                      max_new_tokens=max_new_tokens,
                                      # add extra parameters if models runs successfully
                                      )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.get_clean_response(decoded)

    def get_clean_response(self, response):
        """Strip turn markers from decoded output and return the kept lines.

        The text is split on newlines and each line is handled as follows:

        * a line starting with ``[|Human|]`` is dropped, and marks that a
          human turn has been seen;
        * a line starting with ``[|AI|]`` is kept with the marker (and one
          following space, if present) removed;
        * any other line is kept only if a human turn has been seen earlier.

        Args:
            response: A string, or a list whose first element is the string
                to clean. An empty list yields an empty result.

        Returns:
            The kept lines, each terminated by a newline.
        """
        if isinstance(response, list):
            if not response:
                return ''
            response = response[0]

        kept = []
        seen_human = False  # whether a [|Human|] line has appeared so far
        for line in response.split("\n"):
            if line.startswith(self._HUMAN_TAG):
                seen_human = True
            elif line.startswith(self._AI_TAG):
                reply = line[len(self._AI_TAG):]
                # Drop only the single separator space so inner spacing of
                # the reply is preserved, and text glued to the marker
                # (no space at all) is not lost.
                if reply.startswith(' '):
                    reply = reply[1:]
                kept.append(reply)
            elif seen_human:
                kept.append(line)
        return ''.join(f"{line}\n" for line in kept)