from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import Generator
import json
import nltk
import os
from transformers import pipeline
# Add the NLTK data directory to the search path if it is configured
# (nltk.word_tokenize below expects the 'punkt' tokenizer data to be available there)
nltk_data_dir = os.getenv("NLTK_DATA")
if nltk_data_dir:
    nltk.data.path.append(nltk_data_dir)
# Initialize the FastAPI app
app = FastAPI()
# Initialize the InferenceClient with your model
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
# Initialize the summarization pipeline
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str
    temperature: float = 0.8
    max_new_tokens: int = 12000
    top_p: float = 0.15
    repetition_penalty: float = 1.0
def summarize_history(history):
    # Concatenate all user-authored history entries into a single string
    full_history = " ".join(entry['content'] for entry in history if entry['role'] == 'user')
    if not full_history:
        return ""
    # Summarize the concatenated history, truncating input that exceeds the model's limit
    summarized_history = summarizer(full_history, max_length=1024, truncation=True)
    return summarized_history[0]['summary_text']
def format_prompt(current_prompt, history):
    formatted_history = "<s>"
    formatted_history += f"[HISTORY] {history} [/HISTORY]"
    formatted_history += f"[USER] {current_prompt} [/USER]</s>"
    return formatted_history
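# For illustration, a call such as
#   format_prompt("What is FastAPI?", "User asked about Python.")
# produces:
#   "<s>[HISTORY] User asked about Python. [/HISTORY][USER] What is FastAPI? [/USER]</s>"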
def generate_stream(item: Item) -> Generator[bytes, None, None]:
    # Compress the conversation history so the prompt stays within the context window
    summarized_history = summarize_history(item.history)
    formatted_prompt = format_prompt(item.prompt, summarized_history)
    # Approximate the prompt length with NLTK word tokens (not model tokens)
    input_token_count = len(nltk.word_tokenize(formatted_prompt))
    max_tokens_allowed = 32768
    max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))
    generate_kwargs = {
        "temperature": item.temperature,
        "max_new_tokens": max_new_tokens_adjusted,
        "top_p": item.top_p,
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }
    # Stream tokens from the Inference API and emit each one as newline-delimited JSON;
    # the final chunk carries generated_text, which marks the stream as complete
    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
        chunk = {
            "text": response.token.text,
            "complete": response.generated_text is not None
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"
@app.post("/generate/")
async def generate_text(item: Item):
    return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
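# Example client (a minimal sketch, not part of this Space): the snippet below assumes
# the server is running locally on port 8000 and uses the `requests` library to read
# the NDJSON stream line by line. Field names match the Item model above.
#
#   import requests, json
#
#   payload = {
#       "prompt": "Explain streaming responses in one sentence.",
#       "history": [{"role": "user", "content": "Hello"}],
#       "system_prompt": "",
#   }
#   with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
#       for line in resp.iter_lines():
#           if line:
#               chunk = json.loads(line)
#               print(chunk["text"], end="", flush=True)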