import random

import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel

app = FastAPI()
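# Client for the hosted Mixtral-8x7B-Instruct model on the Hugging Face Inference API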
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
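
# Request body shared by all endpoints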
class Item(BaseModel):
    prompt: str
    history: list = []
    idioma: str = "en"  # language code: "en" or "br"
    temperature: float = 0.7
    max_new_tokens: int = 900
    top_p: float = 0.8
    repetition_penalty: float = 1.2
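
# Shared streaming generator: loads a system prompt from disk, assembles the
# full prompt, and yields generated tokens as they arrive.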
def generate(item: Item, max_tokens: int, file_name: str, temperature: float, top_p: float, repetition_penalty: float):
    print("request")
    if file_name == "Scripter":
        # The "Scripter" persona has a separate prompt file per language
        if item.idioma == "br":
            with open("Scripter2.txt", "r", encoding="utf-8") as file:
                system_prompt = file.read()
        else:
            with open("Scripter.txt", "r", encoding="utf-8") as file:
                system_prompt = file.read()
    else:
        with open(file_name, "r", encoding="utf-8") as file:
            system_prompt = file.read()
    # Ensure every entry in item.history is a string
    history_str = "\n".join(str(entry) for entry in item.history)
    full_prompt = f"<s>[SYSTEM]{system_prompt}{history_str}\n{item.prompt} [MIA]"
    stream = client.text_generation(
        full_prompt,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(1, 1000),
        stream=True,
        details=True,
        return_full_text=False,
    )
    for response in stream:
        yield response.token.text.encode("utf-8")

# Endpoint '/initen/' with fixed parameters for the English intro
@app.post("/initen/")
async def init_en(item: Item):
    temperature = 0.5
    max_tokens = 170
    top_p = 0.9
    repetition_penalty = 1.2
    return StreamingResponse(
        generate(item, max_tokens=max_tokens, file_name="init_en.txt",
                 temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty),
        media_type="application/octet-stream",
    )

# Endpoint '/initbr/' with fixed parameters for the Portuguese intro
@app.post("/initbr/")
async def init_br(item: Item):
    temperature = 0.5
    max_tokens = 170
    top_p = 0.9
    repetition_penalty = 1.2
    return StreamingResponse(
        generate(item, max_tokens=max_tokens, file_name="init_br.txt",
                 temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty),
        media_type="application/octet-stream",
    )

@app.post("/generate1/")
async def generate1(item: Item):
return StreamingResponse(generate(item, max_tokens=350, file_name='Mia.txt',
temperature=item.temperature, top_p=item.top_p,
repetition_penalty=item.repetition_penalty),
media_type="application/octet-stream")
@app.post("/generate2/")
async def generate2(item: Item):
return StreamingResponse(generate(item, max_tokens=350, file_name='Mia.txt',
temperature=item.temperature, top_p=item.top_p,
repetition_penalty=item.repetition_penalty),
media_type="application/octet-stream")
@app.post("/generate3/")
async def generate3_br(item: Item):
return StreamingResponse(generate(item, max_tokens=600, file_name='Scripter',
temperature=item.temperature, top_p=item.top_p,
repetition_penalty=item.repetition_penalty),
media_type="application/octet-stream")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
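
# Example request against a local run (hypothetical host/port, sketch only):
#   curl -N -X POST http://localhost:8000/generate1/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello!", "history": [], "idioma": "en"}'
# The response body is a stream of UTF-8 token chunks.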