from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from fastapi.responses import StreamingResponse
import random
import uvicorn

app = FastAPI()

# Inference client pointed at the hosted Mixtral instruct model
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
# Request schema shared by all endpoints
class Item(BaseModel):
    prompt: str
    history: list = []
    idioma: str = "en"  # language flag: "en" or "br"
    temperature: float = 0.7
    max_new_tokens: int = 900
    top_p: float = 0.8
    repetition_penalty: float = 1.2
def generate(item: Item, max_tokens: int, file_name: str, temperature: float,
             top_p: float, repetition_penalty: float):
    print("request")

    # Pick the system-prompt file; the "Scripter" persona has a separate
    # Portuguese ("br") variant.
    if file_name == "Scripter":
        if item.idioma == "br":
            with open("Scripter2.txt", 'r') as file:
                system_prompt = file.read()
        else:
            with open("Scripter.txt", 'r') as file:
                system_prompt = file.read()
    else:
        with open(file_name, 'r') as file:
            system_prompt = file.read()

    # Ensure every entry in item.history is a string before joining
    history_str = "\n".join(str(entry) for entry in item.history)

    full_prompt = f"<s>[SYSTEM]{system_prompt}{history_str}\n{item.prompt} [MIA]"

    stream = client.text_generation(
        full_prompt,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(1, 1000),
        stream=True,
        details=True,
        return_full_text=False,
    )

    # Stream generated tokens back to the caller as UTF-8 bytes
    for response in stream:
        yield response.token.text.encode('utf-8')
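# For illustration (an assumed example, not from the source): with a system
# prompt of "You are Mia.", an empty history, and prompt "Hi", full_prompt
# above becomes:
#   <s>[SYSTEM]You are Mia.\nHi [MIA]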
# NOTE: the route decorators were missing from the source, and three handlers
# were all named generate1, which would shadow one another. The @app.post
# paths below are inferred from the surrounding comments and file names (they
# are assumptions, not confirmed routes), and the handlers are renamed so
# each route keeps its own function.

# Endpoint for 'init_en/' with tuned parameters
@app.post("/init_en/")
async def generate_init_en(item: Item):
    # Values specific to the 'init_en' endpoint
    temperature = 0.5
    max_tokens = 170
    top_p = 0.9
    repetition_penalty = 1.2
    return StreamingResponse(generate(item, max_tokens=max_tokens, file_name='init_en.txt',
                                      temperature=temperature, top_p=top_p,
                                      repetition_penalty=repetition_penalty),
                             media_type="application/octet-stream")

# Endpoint for 'init_br/' with tuned parameters
@app.post("/init_br/")
async def generate_init_br(item: Item):
    # Values specific to the 'init_br' endpoint
    temperature = 0.5
    max_tokens = 170
    top_p = 0.9
    repetition_penalty = 1.2
    return StreamingResponse(generate(item, max_tokens=max_tokens, file_name='init_br.txt',
                                      temperature=temperature, top_p=top_p,
                                      repetition_penalty=repetition_penalty),
                             media_type="application/octet-stream")

# Chat endpoints backed by the Mia persona (route paths assumed)
@app.post("/mia/")
async def generate1(item: Item):
    return StreamingResponse(generate(item, max_tokens=350, file_name='Mia.txt',
                                      temperature=item.temperature, top_p=item.top_p,
                                      repetition_penalty=item.repetition_penalty),
                             media_type="application/octet-stream")

@app.post("/mia2/")
async def generate2(item: Item):
    return StreamingResponse(generate(item, max_tokens=350, file_name='Mia.txt',
                                      temperature=item.temperature, top_p=item.top_p,
                                      repetition_penalty=item.repetition_penalty),
                             media_type="application/octet-stream")

# Script-writing endpoint; "Scripter" is resolved inside generate() to
# Scripter.txt or Scripter2.txt depending on item.idioma (route path assumed)
@app.post("/scripter/")
async def generate3_br(item: Item):
    return StreamingResponse(generate(item, max_tokens=600, file_name='Scripter',
                                      temperature=item.temperature, top_p=item.top_p,
                                      repetition_penalty=item.repetition_penalty),
                             media_type="application/octet-stream")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
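# A minimal client sketch for consuming a streaming endpoint, kept as a
# comment so it never runs with the server. It assumes the server is running
# locally on port 8000 and that the "/init_en/" route exists as declared
# above (an inferred path, not confirmed by the source):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/init_en/",
#       json={"prompt": "Hello!", "history": [], "idioma": "en"},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None):
#       print(chunk.decode("utf-8"), end="", flush=True)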