File size: 3,407 Bytes
ebd06cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from typing import Any, List, Mapping, Optional, Dict 
from pydantic import Extra, Field #, root_validator, model_validator
import os,json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import requests


class HostedLLM(LLM):
    """
    LangChain LLM wrapper for a model hosted behind a FastAPI endpoint
    (e.g. in a Hugging Face Space), called over REST with a bearer token.

    The request body is ``{"prompt": <prompt>}`` and the JSON response is
    expected to carry the generated text under the ``"content"`` key.

    Attributes:
        url: URL of the endpoint. Must be set before the model is called.
        http_method: HTTP method to invoke, ``"get"`` or ``"post"``
            (case-insensitive, surrounding whitespace ignored; any value other
            than "get" is sent as POST).
        model_name: name of the hosted model; reported in the identifying params.
        api_token: token sent as ``Authorization: Bearer <token>``. Defaults to
            the ``HUGGINGFACE_API`` environment variable, read when an instance
            is created; if still empty at call time the variable is re-read.
        temperature: sampling temperature between 0 and 1. NOTE: not currently
            included in the request payload.
        max_tokens: amount of output to generate, 512 by default. NOTE: not
            currently included in the request payload.
        request_timeout: seconds to wait for the endpoint; ``None`` (default)
            waits indefinitely.
        verbose: for extra logging.
    """
    url: str = ""
    http_method: Optional[str] = "post"
    model_name: Optional[str] = "bard"
    # Resolved per instance instead of at import time, so importing this module
    # no longer raises KeyError when HUGGINGFACE_API is unset.
    api_token: Optional[str] = Field(
        default_factory=lambda: os.environ.get("HUGGINGFACE_API")
    )
    temperature: float = 0.7
    max_tokens: int = 512
    request_timeout: Optional[float] = None
    verbose: Optional[bool] = False

    class Config:
        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return the LLM type identifier used by LangChain."""
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the hosted endpoint and return the generated text.

        Args:
            prompt: text sent as the ``"prompt"`` field of the JSON body.
            stop: optional stop sequences; the returned text is cut at the
                earliest occurrence of any of them (applied client-side).
            run_manager: LangChain callback manager; receives the prompt via
                ``on_text``.
            **kwargs: accepted for compatibility with newer LangChain call
                signatures; ignored.

        Returns:
            The ``"content"`` field of the endpoint's JSON response.

        Raises:
            ValueError: if ``url`` is empty.
            requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
            KeyError: if the response JSON has no ``"content"`` field.
        """
        if not self.url:
            raise ValueError("HostedLLM.url must be set before calling the model")
        if run_manager:
            # on_text expects a string, not a list.
            run_manager.on_text(prompt)

        headers = {"Content-Type": "application/json"}
        token = self.api_token or os.environ.get("HUGGINGFACE_API")
        if token:
            # Avoid sending a literal "Bearer None" when no token is configured.
            headers["Authorization"] = f"Bearer {token}"

        # Only an explicit "get" (any case) issues a GET; everything else POSTs.
        method = (self.http_method or "post").strip().lower()
        response = requests.request(
            "GET" if method == "get" else "POST",
            self.url,
            json={"prompt": prompt},
            headers=headers,
            timeout=self.request_timeout,
        )
        # Fail with the HTTP status rather than a confusing parse/key error.
        response.raise_for_status()
        text = response.json()["content"]
        if stop and isinstance(text, str):
            text = self._truncate_at_stop(text, stop)
        # on_llm_end is intentionally not called here: the LangChain base class
        # reports completion (with an LLMResult) after _call returns, so calling
        # it with a plain string produced a duplicate, wrongly-typed event.
        return text

    @staticmethod
    def _truncate_at_stop(text: str, stop: List[str]) -> str:
        """Return ``text`` cut at the earliest occurrence of any stop sequence."""
        positions = [text.find(seq) for seq in stop if seq]
        hits = [pos for pos in positions if pos >= 0]
        return text[: min(hits)] if hits else text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "hosted"}

    def extractJson(self, val: str) -> Any:
        """Helper function to extract JSON from this LLM's output.

        First tries to parse the whole output as JSON (the model sometimes
        returns bare JSON). Otherwise parses the first block delimited by
        triple backticks (an opening ```json fence is also accepted).
        Newlines and carriage returns are stripped before parsing.

        Raises:
            IndexError: if the output is not JSON and contains no ``` fence.
            json.JSONDecodeError: if the extracted text is not valid JSON.
        """
        try:
            return json.loads(val.replace("\n", "").replace("\r", ""))
        except json.JSONDecodeError:
            # Not bare JSON: assume it is the first ```-fenced item.
            fenced = val.replace("```json", "```").split("```")[1]
            return json.loads(fenced.replace("\n", "").replace("\r", ""))

    def extractPython(self, val: str) -> Any:
        """Helper function to extract Python code from this LLM's output.

        Returns the raw contents of the first block delimited by triple
        backticks (an opening ```python fence is also accepted), including any
        surrounding newlines.

        Raises:
            IndexError: if the output contains no ``` fence.
        """
        return val.replace("```python", "```").split("```")[1]