msy127 committed on
Commit 2b86939 · Parent(s): dac0a0a

Update app.py

Files changed (1)
  1. app.py +48 -2
app.py CHANGED
@@ -1,2 +1,48 @@
- import sys
- print(sys.version)
+ import os
+ import gradio as gr
+ from pydantic import BaseModel, Field
+ from typing import Any, Optional, Dict, List
+ from huggingface_hub import InferenceClient
+ from langchain.llms.base import LLM
+
+ hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+
+
+ class KwArgsModel(BaseModel):
+     # Carries arbitrary text-generation kwargs as a validated pydantic field.
+     kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ class CustomInferenceClient(LLM, KwArgsModel):
+     model_name: str
+     inference_client: InferenceClient
+
+     def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
+         inference_client = InferenceClient(model=model_name, token=hf_token)
+         super().__init__(
+             model_name=model_name,
+             hf_token=hf_token,
+             kwargs=kwargs or {},  # pydantic rejects None for a Dict field
+             inference_client=inference_client
+         )
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None
+     ) -> str:
+         if stop is not None:
+             raise ValueError("stop kwargs are not permitted.")
+         # Stream chunks from the Inference API and join them into one string.
+         response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
+         response = ''.join(response_gen)
+         return response
+
+     @property
+     def _llm_type(self) -> str:
+         return "custom"
+
+     @property
+     def _identifying_params(self) -> dict:
+         return {"model_name": self.model_name}
+
+ kwargs = {"max_new_tokens": 256, "temperature": 0.9, "top_p": 0.6, "repetition_penalty": 1.3, "do_sample": True}