from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import Generator
import json
import nltk
import os
from transformers import pipeline

# Point NLTK at a custom data directory, if one is configured;
# appending None (when NLTK_DATA is unset) would break NLTK's data lookup
nltk_data_dir = os.getenv("NLTK_DATA")
if nltk_data_dir:
    nltk.data.path.append(nltk_data_dir)

# word_tokenize needs the "punkt" tokenizer models; fetch them if missing
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

# Initialize the FastAPI app
app = FastAPI()

# Initialize the InferenceClient with your model
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")

# Initialize the summarization pipeline
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str  # accepted for API compatibility; not yet used when building the prompt
    temperature: float = 0.8
    max_new_tokens: int = 12000
    top_p: float = 0.15
    repetition_penalty: float = 1.0

def summarize_history(history):
    # Concatenate the user turns of the conversation into a single string
    full_history = " ".join(entry["content"] for entry in history if entry["role"] == "user")
    # An empty conversation would make the summarizer fail, so short-circuit
    if not full_history.strip():
        return ""
    # Compress the history so it takes up less room in the prompt
    summarized_history = summarizer(full_history, max_length=1024, truncation=True)
    return summarized_history[0]["summary_text"]

def format_prompt(current_prompt, history):
    # Note: [HISTORY]/[USER] are a custom tagging convention, not Mistral's
    # native [INST] ... [/INST] chat template
    formatted_prompt = "<s>"
    formatted_prompt += f"[HISTORY] {history} [/HISTORY]"
    formatted_prompt += f"[USER] {current_prompt} [/USER]</s>"
    return formatted_prompt

def generate_stream(item: Item) -> Generator[bytes, None, None]:
    summarized_history = summarize_history(item.history)
    formatted_prompt = format_prompt(item.prompt, summarized_history)
    # NLTK word tokens only approximate the model's subword tokenizer,
    # so treat this count as a rough estimate rather than an exact budget
    input_token_count = len(nltk.word_tokenize(formatted_prompt))

    # Mistral-7B-Instruct-v0.2 has a 32k context window; clamp the requested
    # generation length so prompt + completion stays within it
    max_tokens_allowed = 32768
    max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))

    generate_kwargs = {
        "temperature": item.temperature,
        "max_new_tokens": max_new_tokens_adjusted,
        "top_p": item.top_p,
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,  # fixed seed makes sampled outputs reproducible
    }

    # With details=True, the final streamed chunk carries the complete
    # generated_text; use its presence as the end-of-stream signal
    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
        chunk = {
            "text": response.token.text,
            "complete": response.generated_text is not None
        }
        # Emit newline-delimited JSON (NDJSON), one chunk per line
        yield json.dumps(chunk).encode("utf-8") + b"\n"

@app.post("/generate/")
async def generate_text(item: Item):
    return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
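
# Example client (a minimal sketch; assumes the server is running locally on
# port 8000 and that the `requests` package is installed):
#
#   import requests, json
#
#   payload = {
#       "prompt": "Explain NDJSON in one sentence.",
#       "history": [{"role": "user", "content": "Hi there!"}],
#       "system_prompt": "You are a helpful assistant.",
#   }
#   with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as r:
#       for line in r.iter_lines():
#           if line:
#               chunk = json.loads(line)
#               print(chunk["text"], end="", flush=True)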