Spaces:
Sleeping
Sleeping
File size: 2,250 Bytes
f84e083 caa64e7 f84e083 9441c54 0f34bf3 bc5e3f5 f84e083 f2b775d f84e083 e40242b a0ed03b f84e083 f2b775d bc5e3f5 f84e083 9441c54 f84e083 9441c54 bc5e3f5 d0c61b6 215f4a9 bc5e3f5 215f4a9 d0c61b6 f84e083 9441c54 d0c61b6 9441c54 bc5e3f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import Generator
import json # Asegúrate de que esta línea esté al principio del archivo
import torch
# ASGI application object; the /generate/ route in this file is registered on it.
app = FastAPI()
# Module-level Hugging Face inference client, shared by every request handled
# in this file.
# NOTE(review): "google/gemma-7b" looks like the base checkpoint, while
# format_prompt emits chat-turn markers - confirm whether the
# instruction-tuned variant was intended.
client = InferenceClient("google/gemma-7b")
class Item(BaseModel):
    """Request body for the text-generation endpoint.

    ``prompt``, ``history`` and ``system_prompt`` have no default value.
    The remaining fields are sampling settings that generate_stream passes
    to the inference client unchanged.
    """

    # Text of the new user message.
    prompt: str
    # Earlier conversation turns; format_prompt reads entry['role'] and
    # entry['content'] from each element.
    history: list
    # Joined in front of `prompt` with ", " when the final user turn is built.
    system_prompt: str
    temperature: float = 0.8
    max_new_tokens: int = 8000
    top_p: float = 0.15
    repetition_penalty: float = 1.0
def format_prompt(message, history):
    """Build a Gemma chat prompt from earlier turns plus the new user message.

    Every turn is rendered as a ``<start_of_turn>`` marker, the role, a
    newline, the content and an ``<end_of_turn>`` marker followed by a
    newline. The newline after ``<end_of_turn>`` separates consecutive
    turns, as in Gemma's documented chat format.

    Args:
        message: Text of the new user turn.
        history: Iterable of entries indexable by ``'role'`` and
            ``'content'``, or a falsy value (``None`` or empty) when there
            are no earlier turns.

    Returns:
        The prompt string: ``<bos>``, the rendered history, the new user
        turn, and an opened model turn for the model to complete.

    Raises:
        KeyError: If a dict history entry lacks ``'role'`` or ``'content'``.
    """
    parts = ["<bos>"]
    for entry in history or ():
        # Any role other than "user" is rendered as a model turn.
        role = "user" if entry["role"] == "user" else "model"
        parts.append(f"<start_of_turn>{role}\n{entry['content']}<end_of_turn>\n")
    # Close the new user turn and open the model turn to be generated.
    parts.append(f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n")
    return "".join(parts)
# NOTE: format_prompt emits Gemma-style turn markers; revisit it if the client is pointed at a model with a different prompt format.
def generate_stream(item: Item) -> Generator[bytes, None, None]:
    """Stream the model's reply to *item* as newline-delimited JSON.

    The system prompt and user prompt are joined with ", " into a single
    user message, formatted together with the history by format_prompt,
    and sent to the module-level inference client in streaming mode.

    Args:
        item: Request carrying the prompt, history, system prompt and
            sampling settings.

    Yields:
        One UTF-8 encoded JSON object per streamed token, followed by a
        newline. ``"text"`` is the token text (empty for special tokens)
        and ``"complete"`` is true when the response carries
        ``generated_text``.
    """
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    generate_kwargs = {
        "temperature": item.temperature,
        "max_new_tokens": item.max_new_tokens,
        "top_p": item.top_p,
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        # Fixed seed, so sampling is reproducible for identical requests.
        # NOTE(review): confirm that is intended; adjust or omit otherwise.
        "seed": 42,
    }
    # Stream the response from the InferenceClient.
    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
        token = response.token
        chunk = {
            # Special tokens (e.g. the end-of-sequence marker) are control
            # symbols, not reply text; send an empty string so the client
            # still receives the chunk and its "complete" flag.
            "text": "" if token.special else token.text,
            "complete": response.generated_text is not None,
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"
@app.post("/generate/")
async def generate_text(item: Item):
    """Handle POST /generate/ by streaming generated tokens as NDJSON.

    Args:
        item: Validated request body.

    Returns:
        A StreamingResponse wrapping the generate_stream generator, with
        media type ``application/x-ndjson``.
    """
    token_stream = generate_stream(item)
    return StreamingResponse(token_stream, media_type="application/x-ndjson")
if __name__ == "__main__":
    # When executed as a script, serve the app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)