File size: 2,684 Bytes
0b9c508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f94144
0b9c508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import Generator
import json  # Make sure this import stays at the top of the file
import nltk
import os


# Register an extra NLTK data directory only when NLTK_DATA is actually set:
# os.getenv() returns None for an unset variable, and a None entry is not a
# usable directory in nltk.data.path.
_nltk_data_dir = os.getenv('NLTK_DATA')
if _nltk_data_dir:
    nltk.data.path.append(_nltk_data_dir)

# Application object; the route handler defined later in this file registers
# itself on it through the @app.post decorator.
app = FastAPI()

# Shared Hugging Face inference client bound to the
# mistralai/Mistral-7B-Instruct-v0.2 model; generate_stream() calls it.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")

# NOTE(review): disabled summarization pipeline; `pipeline` is not among the
# imports visible in this file, so this line cannot simply be re-enabled.
# summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")


class Item(BaseModel):
    # Request body accepted by the /generate/ endpoint.
    # Latest user message; generate_stream() joins it after system_prompt.
    prompt: str
    # Earlier conversation turns; format_prompt() reads entry["role"] and
    # entry["content"] from each element, so dict-like entries are expected.
    history: list
    # Instruction text placed in front of `prompt` before formatting.
    system_prompt: str
    # Sampling settings below are forwarded to client.text_generation().
    temperature: float = 0.8
    # Requested completion length; generate_stream() caps it against its
    # 32768-token budget.
    max_new_tokens: int = 12000
    top_p: float = 0.15
    repetition_penalty: float = 1.0

def format_prompt(current_prompt, history):
    """Build a Mistral-Instruct prompt from the chat history and a new message.

    The layout follows the template documented for Mistral-7B-Instruct:
    ``<s>[INST] user [/INST] assistant</s>[INST] user [/INST]``. User turns
    are wrapped in ``[INST]`` tags, each assistant reply is closed with
    ``</s>``, and the prompt ends on an open instruction so the model writes
    the next reply. ``[USER]``/``[ASSISTANT]`` tags are not part of that
    template, and a trailing ``</s>`` after the last user turn would mark the
    sequence as finished before generation starts.

    Args:
        current_prompt: Text of the new user message.
        history: Iterable of mappings with ``"role"`` and ``"content"`` keys.
            Entries whose role is neither ``"user"`` nor ``"assistant"`` are
            ignored.

    Returns:
        The complete prompt string.

    Raises:
        KeyError: If a history entry has no ``"role"`` key, or a user or
            assistant entry has no ``"content"`` key.
    """
    # Collect the pieces and join once instead of growing a string with +=.
    parts = ["<s>"]
    for entry in history:
        role = entry["role"]
        if role == "user":
            parts.append(f"[INST] {entry['content']} [/INST]")
        elif role == "assistant":
            # The end-of-sequence marker closes every completed model reply.
            parts.append(f" {entry['content']}</s>")
    # Leave the final instruction open so the model produces the answer.
    parts.append(f"[INST] {current_prompt} [/INST]")
    return "".join(parts)


def generate_stream(item: Item) -> Generator[bytes, None, None]:
    """Stream generated tokens for *item* as newline-delimited JSON.

    The system prompt and user prompt are joined, formatted together with the
    history by ``format_prompt``, and sent to the shared inference client.
    The requested ``max_new_tokens`` is reduced so that the estimated prompt
    size plus the completion stays within a 32768-token budget, with a floor
    of one new token.

    Args:
        item: Validated request carrying the prompt, history and sampling
            settings.

    Yields:
        One UTF-8 encoded JSON object per generated token, terminated by a
        newline. Each object has the keys ``"text"`` (the token text, empty
        for special tokens) and ``"complete"`` (true once the backend reports
        the full generated text).
    """
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)

    # Word counts only approximate the model's own tokenization; the value is
    # used solely to shrink the completion budget.
    input_token_count = _estimate_token_count(formatted_prompt)

    # Keep prompt + completion inside the budget, but always request at least
    # one new token.
    max_tokens_allowed = 32768
    max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))

    generate_kwargs = {
        # The text-generation backend rejects non-positive temperatures, so
        # clamp to a small positive value instead of failing mid-stream.
        "temperature": max(item.temperature, 1e-2),
        "max_new_tokens": max_new_tokens_adjusted,
        "top_p": item.top_p,
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    # stream=True with details=True yields one object per token, each
    # exposing `.token` and, on the final one, `.generated_text`.
    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
        token = response.token
        chunk = {
            # Special tokens (such as the end-of-sequence marker) are control
            # symbols, not reply text; the chunk is still sent so the client
            # receives the completion flag.
            "text": "" if getattr(token, "special", False) else token.text,
            "complete": response.generated_text is not None,
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"


def _estimate_token_count(text: str) -> int:
    """Return a rough token count for *text*.

    Uses NLTK word tokenization when its tokenizer data is installed and
    falls back to whitespace splitting otherwise, so a missing NLTK resource
    does not abort a response that has already started streaming.
    """
    try:
        return len(nltk.word_tokenize(text))
    except LookupError:
        # NLTK raises LookupError when a required data resource is not found.
        return len(text.split())

@app.post("/generate/")
async def generate_text(item: Item):
    # Build the lazy chunk generator first, then wrap it in a streaming
    # response so each NDJSON line is sent as it is produced.
    ndjson_chunks = generate_stream(item)
    return StreamingResponse(ndjson_chunks, media_type="application/x-ndjson")

# Script entry point: when this file is executed directly, serve the app with
# uvicorn on all network interfaces (0.0.0.0), port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)