Marroco93 commited on
Commit
ce8dee8
·
1 Parent(s): a0ed03b

mistralai again

Browse files
Files changed (1) hide show
  1. main.py +12 -21
main.py CHANGED
@@ -5,39 +5,29 @@ from huggingface_hub import InferenceClient
5
  import uvicorn
6
  from typing import Generator
7
  import json # Asegúrate de que esta línea esté al principio del archivo
8
- import torch
9
-
10
 
11
  app = FastAPI()
12
 
13
- # Initialize the InferenceClient with the Gemma-7b model
14
- client = InferenceClient("google/gemma-7b")
15
 
16
  class Item(BaseModel):
17
  prompt: str
18
  history: list
19
  system_prompt: str
20
  temperature: float = 0.8
21
- max_new_tokens: int = 8000
22
  top_p: float = 0.15
23
  repetition_penalty: float = 1.0
24
 
25
  def format_prompt(message, history):
26
- prompt = "<bos>"
27
- # Add history to the prompt if there's any
28
- if history:
29
- for entry in history:
30
- role = "user" if entry['role'] == "user" else "model"
31
- prompt += f"<start_of_turn>{role}\n{entry['content']}<end_of_turn>"
32
- # Add the current message
33
- prompt += f"<start_of_turn>user\n{message}<end_of_turn><start_of_turn>model\n"
34
  return prompt
35
 
36
-
37
-
38
-
39
- # No changes needed in the format_prompt function unless the new model requires different prompt formatting
40
-
41
  def generate_stream(item: Item) -> Generator[bytes, None, None]:
42
  formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
43
  generate_kwargs = {
@@ -51,16 +41,17 @@ def generate_stream(item: Item) -> Generator[bytes, None, None]:
51
 
52
  # Stream the response from the InferenceClient
53
  for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
54
- # Check if the 'details' flag and response structure are the same for the new model
55
  chunk = {
56
  "text": response.token.text,
57
- "complete": response.generated_text is not None
58
  }
59
  yield json.dumps(chunk).encode("utf-8") + b"\n"
60
 
61
  @app.post("/generate/")
62
  async def generate_text(item: Item):
 
63
  return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
64
 
65
  if __name__ == "__main__":
66
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
5
  import uvicorn
6
  from typing import Generator
7
  import json # Asegúrate de que esta línea esté al principio del archivo
 
 
8
 
9
  app = FastAPI()
10
 
11
+ # Initialize the InferenceClient with your model
12
+ client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
13
 
14
  class Item(BaseModel):
15
  prompt: str
16
  history: list
17
  system_prompt: str
18
  temperature: float = 0.8
19
+ max_new_tokens: int = 9000
20
  top_p: float = 0.15
21
  repetition_penalty: float = 1.0
22
 
23
  def format_prompt(message, history):
24
+ prompt = "<s>"
25
+ for user_prompt, bot_response in history:
26
+ prompt += f"[INST] {user_prompt} [/INST]"
27
+ prompt += f" {bot_response}</s> "
28
+ prompt += f"[INST] {message} [/INST]"
 
 
 
29
  return prompt
30
 
 
 
 
 
 
31
  def generate_stream(item: Item) -> Generator[bytes, None, None]:
32
  formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
33
  generate_kwargs = {
 
41
 
42
  # Stream the response from the InferenceClient
43
  for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
44
+ # This assumes 'details=True' gives you a structure where you can access the text like this
45
  chunk = {
46
  "text": response.token.text,
47
+ "complete": response.generated_text is not None # Adjust based on how you detect completion
48
  }
49
  yield json.dumps(chunk).encode("utf-8") + b"\n"
50
 
51
  @app.post("/generate/")
52
  async def generate_text(item: Item):
53
+ # Stream response back to the client
54
  return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
55
 
56
  if __name__ == "__main__":
57
+ uvicorn.run(app, host="0.0.0.0", port=8000)