from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import Generator
import json  # used to serialize streamed chunks as newline-delimited JSON


app = FastAPI()

# Initialize the InferenceClient with the Gemma-7b model.
# Note: google/gemma-7b is a gated model on the Hugging Face Hub, so an access
# token may be required, e.g. InferenceClient("google/gemma-7b", token=...).
client = InferenceClient("google/gemma-7b")

class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str
    temperature: float = 0.8
    max_new_tokens: int = 8000
    top_p: float = 0.15
    repetition_penalty: float = 1.0
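
# Example request body for /generate/ (illustrative values, not taken from the code above):
#   {
#     "prompt": "What is FastAPI?",
#     "history": [
#       {"role": "user", "content": "Hi"},
#       {"role": "model", "content": "Hello! How can I help?"}
#     ],
#     "system_prompt": "You are a helpful assistant."
#   }
# The fields with defaults (temperature, max_new_tokens, top_p, repetition_penalty) may be omitted.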

def format_prompt(message, history):
    prompt = "<bos>"
    # Add history to the prompt if there's any
    if history:
        for entry in history:
            role = "user" if entry['role'] == "user" else "model"
            prompt += f"<start_of_turn>{role}\n{entry['content']}<end_of_turn>\n"
    # Add the current message and open the model turn
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt
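
# For reference, format_prompt produces Gemma's turn-based chat layout. A sketch with
# illustrative history (when called from generate_stream below, `message` is the system
# prompt and the user prompt joined by a comma):
#   <bos><start_of_turn>user
#   Hi<end_of_turn>
#   <start_of_turn>model
#   Hello! How can I help?<end_of_turn>
#   <start_of_turn>user
#   You are a helpful assistant., What is FastAPI?<end_of_turn>
#   <start_of_turn>model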


def generate_stream(item: Item) -> Generator[bytes, None, None]:
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    generate_kwargs = {
        "temperature": item.temperature,
        "max_new_tokens": item.max_new_tokens,
        "top_p": item.top_p,
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,  # Adjust or omit the seed as needed
    }

    # Stream the response from the InferenceClient
    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
        # With details=True, each streamed item carries the token text; generated_text
        # is only populated on the final item, which marks the response as complete.
        chunk = {
            "text": response.token.text,
            "complete": response.generated_text is not None
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"
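
# Each yielded line is a single JSON object (newline-delimited JSON). Illustrative output:
#   {"text": "Hel", "complete": false}
#   {"text": "lo!", "complete": true}
# "complete" becomes true on the final chunk, once generated_text is populated.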

@app.post("/generate/")
async def generate_text(item: Item):
    return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
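

# Example client (a sketch, not wired into the app): shows how the NDJSON stream from
# /generate/ could be consumed. It assumes the server is running on localhost:8000 and
# that the `requests` package is installed; neither is required by the server itself.
def example_client():
    import requests  # assumed to be available in the client's environment

    payload = {
        "prompt": "Explain streaming responses in one sentence.",
        "history": [],
        "system_prompt": "You are a concise assistant.",
    }
    with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as response:
        for line in response.iter_lines():
            if line:
                chunk = json.loads(line)
                print(chunk["text"], end="", flush=True)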

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)