# maya-persistence/src/llm/hostedLLM.py
from typing import Any, Dict, List, Mapping, Optional

import json
import os

import requests
from pydantic import Extra  # root_validator is also needed if the validator below is re-enabled
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class HostedLLM(LLM):
    """
    LLM hosted in a Hugging Face Space behind FastAPI. The interface is a
    plain REST call authorized with a Hugging Face token.

    Attributes:
        url: URL of the endpoint.
        http_method: HTTP method to invoke ("get" or "post").
        model_name: name of the model being hosted.
        temperature: sampling temperature, between 0 and 1.
        max_tokens: maximum number of tokens to generate, 512 by default.
        api_token: token passed for bearer authorization. Defaults to the
            HUGGINGFACE_API environment variable.
        verbose: enable extra logging.
    """
    url: str = ""
    http_method: Optional[str] = "post"
    model_name: Optional[str] = "bard"
    # .get avoids a KeyError at import time when the variable is unset
    api_token: Optional[str] = os.environ.get("HUGGINGFACE_API", "")
    temperature: float = 0.7
    max_tokens: int = 512
    verbose: Optional[bool] = False

    class Config:
        extra = Extra.forbid
    # Disabled validator: normalizes http_method and falls back to the
    # HUGGINGFACE_API environment variable for the token. Re-enable with
    # pydantic v1's root_validator (the API consistent with Extra above).
    # @root_validator()
    # def validate_environment(cls, values: Dict) -> Dict:
    #     if values["http_method"].strip().lower() == "get":
    #         values["http_method"] = "get"
    #     else:
    #         values["http_method"] = "post"
    #     if values["api_token"] == "":
    #         values["api_token"] = os.environ.get("HUGGINGFACE_API", "")
    #     return values
    @property
    def _llm_type(self) -> str:
        return "text2text-generation"
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        # stop sequences are not supported by the hosted endpoint and are ignored
        if run_manager:
            run_manager.on_text(prompt)
        payload = {"prompt": prompt}
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        if self.http_method == "post":
            response = requests.post(self.url, json=payload, headers=headers)
        else:
            # unusual, but the endpoint accepts a JSON body on GET as well
            response = requests.get(self.url, json=payload, headers=headers)
        response.raise_for_status()
        # The LLM base class fires on_llm_end with the final LLMResult itself,
        # so _call only needs to return the generated text.
        return response.json()["content"]
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "hosted"}
    def extractJson(self, val: str) -> Any:
        """Helper function to extract JSON from this LLM's output."""
        # The JSON is assumed to be the first item inside a ``` fenced block,
        # although the LLM will sometimes return the JSON directly.
        try:
            stripped = val.replace("\n", "").replace("\r", "")
            return json.loads(stripped)
        except json.JSONDecodeError:
            fenced = val.replace("```json", "```").split("```")[1]
            stripped = fenced.replace("\n", "").replace("\r", "")
            return json.loads(stripped)
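
    # Example: extractJson('Sure! ```json\n{"name": "maya"}\n```')
    # returns {"name": "maya"}, and a bare '{"name": "maya"}' parses directly.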
    def extractPython(self, val: str) -> Any:
        """Helper function to extract Python code from this LLM's output."""
        # The code is assumed to be the first item inside a ``` fenced block;
        # an IndexError here means the output contained no fenced block.
        return val.replace("```python", "```").split("```")[1]
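

# A minimal usage sketch. The Space URL below is a placeholder (an assumption,
# not the project's actual endpoint), and HUGGINGFACE_API must be set in the
# environment for the bearer authorization to succeed.
if __name__ == "__main__":
    llm = HostedLLM(
        url="https://example-user-example-space.hf.space/generate",  # placeholder
        model_name="bard",
        temperature=0.7,
    )
    # Older LangChain releases allow calling the LLM directly;
    # newer ones prefer llm.invoke(prompt).
    answer = llm("Return a JSON object with a single key 'greeting'.")
    print(llm.extractJson(answer))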