Hansimov commited on
Commit
9f341cc
1 Parent(s): 9bc229a

:gem: [Feature] New MessageStreamer: Enable requests inference api with requests

Browse files
networks/__init__.py ADDED
File without changes
networks/message_streamer.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import requests
4
+ from messagers.message_outputer import OpenaiStreamOutputer
5
+ from utils.logger import logger
6
+ from utils.enver import enver
7
+ from huggingface_hub import InferenceClient
8
+
9
+
10
+ class MessageStreamer:
11
+ MODEL_MAP = {
12
+ "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
13
+ }
14
+
15
+ def __init__(self, model: str):
16
+ self.model = model
17
+ self.model_fullname = self.MODEL_MAP[model]
18
+
19
+ def parse_line(self, line):
20
+ line = line.decode("utf-8")
21
+ line = re.sub(r"data:\s*", "", line)
22
+ data = json.loads(line)
23
+ content = data["token"]["text"]
24
+ return content
25
+
26
+ def chat(
27
+ self,
28
+ prompt: str = None,
29
+ temperature: float = 0.01,
30
+ max_new_tokens: int = 32000,
31
+ stream: bool = True,
32
+ yield_output: bool = False,
33
+ ):
34
+ # https://huggingface.co/docs/text-generation-inference/conceptual/streaming#streaming-with-curl
35
+ self.request_url = (
36
+ f"https://api-inference.huggingface.co/models/{self.model_fullname}"
37
+ )
38
+ self.message_outputer = OpenaiStreamOutputer()
39
+ self.request_headers = {
40
+ "Content-Type": "application/json",
41
+ }
42
+ # huggingface_hub/inference/_client.py: class InferenceClient > def text_generation()
43
+ self.request_body = {
44
+ "inputs": prompt,
45
+ "parameters": {
46
+ "temperature": temperature,
47
+ "max_new_tokens": max_new_tokens,
48
+ "return_full_text": False,
49
+ },
50
+ "stream": stream,
51
+ }
52
+ print(self.request_url)
53
+ enver.set_envs(proxies=True)
54
+ stream = requests.post(
55
+ self.request_url,
56
+ headers=self.request_headers,
57
+ json=self.request_body,
58
+ proxies=enver.requests_proxies,
59
+ stream=stream,
60
+ )
61
+ print(stream.status_code)
62
+ for line in stream.iter_lines():
63
+ if not line:
64
+ continue
65
+
66
+ content = self.parse_line(line)
67
+
68
+ if content.strip() == "</s>":
69
+ content_type = "Finished"
70
+ logger.mesg("\n[Finished]")
71
+ else:
72
+ content_type = "Completions"
73
+ logger.mesg(content, end="")
74
+
75
+ if yield_output:
76
+ output = self.message_outputer.output(
77
+ content=content, content_type=content_type
78
+ )
79
+ yield output