# Hosted LLM client: wraps a model deployed on a Hugging Face Space behind a
# FastAPI REST endpoint, exposed through the LangChain LLM interface.
import json
import os
from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import Extra, Field  # root_validator / model_validator currently unused
class HostedLLM(LLM):
    """Hosted LLM running in a Hugging Face Space behind a FastAPI endpoint.

    The model is invoked through a single REST call carrying a JSON body
    ``{"prompt": ...}`` and a Hugging Face bearer token for authorization.

    Attributes:
        url: URL of the hosted endpoint.
        http_method: HTTP verb to invoke, "get" or "post" (case-insensitive;
            anything other than "get" is treated as "post").
        model_name: Name of the model being hosted (identification only).
        temperature: Sampling temperature between 0 and 1.
        max_tokens: Amount of output to generate, 512 by default.
        api_token: Token for the bearer Authorization header. Defaults to the
            ``HUGGINGFACE_API`` environment variable (empty string if unset).
        request_timeout: Seconds to wait for the HTTP response before raising.
        verbose: Enable extra logging.
    """

    url: str = ""
    http_method: Optional[str] = "post"
    model_name: Optional[str] = "bard"
    # .get() avoids a KeyError at import time when the variable is unset
    # (matches the fallback the disabled validator below intended).
    api_token: Optional[str] = os.environ.get("HUGGINGFACE_API", "")
    temperature: float = 0.7
    max_tokens: int = 512
    # requests calls without a timeout can hang indefinitely; callers may
    # override this per instance.
    request_timeout: float = 120.0
    verbose: Optional[bool] = False

    class Config:
        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        # LangChain's LLM base declares this as a property; without the
        # decorator the framework would receive a bound method, not a str.
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Send *prompt* to the hosted endpoint and return the reply.

        Args:
            prompt: The user prompt to forward to the model.
            stop: Stop sequences (accepted for interface compatibility;
                not forwarded to the endpoint).
            run_manager: Optional callback manager notified before and
                after the call.

        Returns:
            The ``content`` field of the endpoint's JSON response.

        Raises:
            requests.HTTPError: If the endpoint returns an error status.
            requests.Timeout: If no response arrives within request_timeout.
        """
        if run_manager:
            run_manager.on_text([prompt])
        payload = {"prompt": prompt}
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        # Case-insensitive dispatch: only an explicit "get" performs a GET,
        # everything else (including "POST", "Post", None) falls back to POST —
        # this mirrors the intent of the disabled validator.
        if (self.http_method or "").strip().lower() == "get":
            response = requests.get(
                self.url, json=payload, headers=headers, timeout=self.request_timeout
            )
        else:
            response = requests.post(
                self.url, json=payload, headers=headers, timeout=self.request_timeout
            )
        # Surface HTTP errors explicitly instead of failing later on JSON parse.
        response.raise_for_status()
        val = response.json()["content"]
        if run_manager:
            run_manager.on_llm_end(val)
        return val

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "hosted"}

    def extractJson(self, val: str) -> Any:
        """Extract and parse JSON from this LLM's output.

        The model may return raw JSON directly, or JSON wrapped in a
        ``` / ```json fenced block; both forms are handled. The first
        fenced block is assumed to hold the JSON.

        Raises:
            json.JSONDecodeError: If neither form parses.
            IndexError: If no fenced block is present in the fallback path.
        """
        try:
            # Fast path: the whole reply is already valid JSON.
            return json.loads(val.replace("\n", "").replace("\r", ""))
        except json.JSONDecodeError:
            # Fallback: take the first fenced block and parse its contents.
            fenced = val.replace("```json", "```").split("```")[1]
            return json.loads(fenced.replace("\n", "").replace("\r", ""))

    def extractPython(self, val: str) -> Any:
        """Extract the first ```python fenced code block from this LLM's output."""
        return val.replace("```python", "```").split("```")[1]