Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,2 +1,45 @@
|
|
1 |
-
import
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from typing import Any, Dict, List, Optional

import gradio as gr
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM
from pydantic import BaseModel, Field
6 |
+
|
# Hugging Face API token for the Inference API, read from the environment so
# the secret never appears in source control. May be None when the variable
# is unset (BUGFIX: `os` was used here without ever being imported).
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
8 |
+
|
9 |
+
|
class KwArgsModel(BaseModel):
    """Pydantic mixin that carries free-form keyword arguments.

    Used below to attach text-generation parameters (temperature, top_p, ...)
    to the LangChain LLM wrapper as validated model state.
    """

    # Arbitrary generation parameters; default_factory keeps a fresh empty
    # dict per instance instead of a shared mutable default.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
12 |
+
|
13 |
+
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain-compatible LLM backed by ``huggingface_hub.InferenceClient``.

    Combines the LangChain ``LLM`` interface with ``KwArgsModel`` so arbitrary
    text-generation parameters travel with the instance and are forwarded to
    every ``text_generation`` call.
    """

    model_name: str  # Hub model id, e.g. "mistralai/Mistral-7B-Instruct-v0.1"
    # BUGFIX: __init__ forwards hf_token to the pydantic constructor, but the
    # field was never declared, so the value was rejected/dropped by
    # validation. Declaring it makes the existing call valid.
    hf_token: str
    inference_client: InferenceClient  # low-level client built in __init__

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Build the underlying ``InferenceClient`` and initialise the model.

        Args:
            model_name: Hugging Face Hub model identifier.
            hf_token: API token used to authenticate with the Inference API.
            kwargs: Optional generation parameters. ``None`` is normalised to
                ``{}`` — BUGFIX: forwarding ``None`` verbatim fails the
                ``Dict[str, Any]`` field validation on ``KwArgsModel``.
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            hf_token=hf_token,
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None
    ) -> str:
        """Generate a completion for *prompt* and return it as one string.

        Args:
            prompt: The text prompt sent to the model.
            stop: Not supported by this wrapper.

        Raises:
            ValueError: If *stop* sequences are supplied.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream token chunks and concatenate; self.kwargs carries the
        # generation settings (temperature, top_p, ...).
        response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
        response = ''.join(response_gen)
        return response

    @property
    def _llm_type(self) -> str:
        """Type tag LangChain uses for serialisation/telemetry."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters that identify this LLM configuration to LangChain."""
        return {"model_name": self.model_name}
44 |
+
|
45 |
+
kwargs = {"max_new_tokens":256, "temperature":0.9, "top_p":0.6, "repetition_penalty":1.3, "do_sample":True}
|