Marroco93 committed on
Commit
5b8435c
1 Parent(s): ce8dee8
Files changed (2) hide show
  1. copy_main.py +57 -0
  2. main.py +10 -7
copy_main.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import StreamingResponse
3
+ from pydantic import BaseModel
4
+ from huggingface_hub import InferenceClient
5
+ import uvicorn
6
+ from typing import Generator
7
+ import json # Asegúrate de que esta línea esté al principio del archivo
8
+
9
+ app = FastAPI()
10
+
11
+ # Initialize the InferenceClient with your model
12
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
13
+
14
+ class Item(BaseModel):
15
+ prompt: str
16
+ history: list
17
+ system_prompt: str
18
+ temperature: float = 0.8
19
+ max_new_tokens: int = 9000
20
+ top_p: float = 0.15
21
+ repetition_penalty: float = 1.0
22
+
23
+ def format_prompt(message, history):
24
+ prompt = "<s>"
25
+ for user_prompt, bot_response in history:
26
+ prompt += f"[INST] {user_prompt} [/INST]"
27
+ prompt += f" {bot_response}</s> "
28
+ prompt += f"[INST] {message} [/INST]"
29
+ return prompt
30
+
31
+ def generate_stream(item: Item) -> Generator[bytes, None, None]:
32
+ formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
33
+ generate_kwargs = {
34
+ "temperature": item.temperature,
35
+ "max_new_tokens": item.max_new_tokens,
36
+ "top_p": item.top_p,
37
+ "repetition_penalty": item.repetition_penalty,
38
+ "do_sample": True,
39
+ "seed": 42, # Adjust or omit the seed as needed
40
+ }
41
+
42
+ # Stream the response from the InferenceClient
43
+ for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
44
+ # This assumes 'details=True' gives you a structure where you can access the text like this
45
+ chunk = {
46
+ "text": response.token.text,
47
+ "complete": response.generated_text is not None # Adjust based on how you detect completion
48
+ }
49
+ yield json.dumps(chunk).encode("utf-8") + b"\n"
50
+
51
+ @app.post("/generate/")
52
+ async def generate_text(item: Item):
53
+ # Stream response back to the client
54
+ return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
55
+
56
+ if __name__ == "__main__":
57
+ uvicorn.run(app, host="0.0.0.0", port=8000)
main.py CHANGED
@@ -20,13 +20,16 @@ class Item(BaseModel):
20
  top_p: float = 0.15
21
  repetition_penalty: float = 1.0
22
 
23
- def format_prompt(message, history):
24
- prompt = "<s>"
25
- for user_prompt, bot_response in history:
26
- prompt += f"[INST] {user_prompt} [/INST]"
27
- prompt += f" {bot_response}</s> "
28
- prompt += f"[INST] {message} [/INST]"
29
- return prompt
 
 
 
30
 
31
  def generate_stream(item: Item) -> Generator[bytes, None, None]:
32
  formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
 
20
  top_p: float = 0.15
21
  repetition_penalty: float = 1.0
22
 
23
+ def format_prompt(current_prompt, history):
24
+ formatted_history = "<s>"
25
+ for entry in history:
26
+ if entry["role"] == "user":
27
+ formatted_history += f"[USER] {entry['content']} [/USER]"
28
+ elif entry["role"] == "assistant":
29
+ formatted_history += f"[ASSISTANT] {entry['content']} [/ASSISTANT]"
30
+ formatted_history += f"[USER] {current_prompt} [/USER]</s>"
31
+ return formatted_history
32
+
33
 
34
  def generate_stream(item: Item) -> Generator[bytes, None, None]:
35
  formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)