from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
from fastapi import Form, Depends, HTTPException, status
from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import torch
import os
import time
import re
import json

app = FastAPI()

html = """
<!DOCTYPE html>
<html>
    <head>
        <title>Chat</title>
    </head>
    <body>
        <h1>WebSocket Chat</h1>
        <form action="" onsubmit="sendMessage(event)">
            <input type="text" id="messageText" autocomplete="off"/>
            <button>Send</button>
        </form>
        <ul id='messages'>
        </ul>
        <script>
            // var ws = new WebSocket("ws://localhost:8000/api/ws");
            var ws = new WebSocket("wss://cahya-indonesian-whisperer.hf.space/api/ws");
            ws.onmessage = function(event) {
                var messages = document.getElementById('messages')
                var message = document.createElement('li')
                var content = document.createTextNode(event.data)
                message.appendChild(content)
                messages.appendChild(message)
            };
            function sendMessage(event) {
                var input = document.getElementById("messageText")
                ws.send(input.value)
                input.value = ''
                event.preventDefault()
            }
        </script>
    </body>
</html>
"""


@app.get("/")
async def get():
    return HTMLResponse(html)


# Note: this endpoint exposes all environment variables, including any secrets such as
# HF_AUTH_TOKEN, to anyone who can reach the app.
@app.get("/api/env")
async def env():
    environment_variables = "<h3>Environment Variables</h3>"
    for name, value in os.environ.items():
        environment_variables += f"{name}: {value}<br>"
    return HTMLResponse(environment_variables)


# Simple echo endpoint used by the chat page above: every received message is sent straight back.
@app.websocket("/api/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    while True:
        data = await websocket.receive_text()
        await websocket.send_text(f"Message text was: {data}")
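# A minimal Python client sketch for the echo endpoint above (assumes the third-party
# `websockets` package and a locally running server; both are assumptions, not part of this app):
#   import asyncio, websockets
#   async def main():
#       async with websockets.connect("ws://localhost:8000/api/ws") as ws:
#           await ws.send("hello")
#           print(await ws.recv())  # -> "Message text was: hello"
#   asyncio.run(main())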


@app.post("/api/indochat/v1")
async def indochat(**kwargs):
    # Delegate to the generic text generator with a fixed model name
    return await text_generate(model_name="indochat-tiny", **kwargs)


@app.post("/api/text-generator/v1")
async def text_generate(
        model_name: str = Form(default="", description="The model name"),
        text: str = Form(default="", description="The prompt"),
        decoding_method: str = Form(default="Sampling", description="Decoding method"),
        min_length: int = Form(default=50, description="Minimum length of the generated text"),
        max_length: int = Form(default=250, description="Maximum length of the generated text"),
        num_beams: int = Form(default=5, description="Number of beams"),
        top_k: int = Form(default=30, description="The number of highest probability vocabulary tokens to keep "
                                                  "for top-k filtering"),
        top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with "
                                                      "probabilities that add up to top_p or higher are kept "
                                                      "for generation"),
        temperature: float = Form(default=0.5, description="The temperature of the softmax distribution"),
        penalty_alpha: float = Form(default=0.5, description="Penalty alpha for contrastive search"),
        repetition_penalty: float = Form(default=1.2, description="Repetition penalty"),
        seed: int = Form(default=-1, description="Random seed; negative values leave the seed unset"),
        max_time: float = Form(default=60.0, description="Maximum time in seconds to generate the text")
):
    if model_name not in text_generator:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
                            detail=f"Unknown model name: {model_name}")
    if seed >= 0:
        set_seed(seed)
    if decoding_method == "Beam Search":
        do_sample = False
        penalty_alpha = 0
    elif decoding_method == "Sampling":
        do_sample = True
        penalty_alpha = 0
        num_beams = 1
    else:
        # Any other decoding method: contrastive search via penalty_alpha and top_k
        do_sample = False
        num_beams = 1
    if repetition_penalty == 0.0:
        # Derive the repetition penalty from the temperature: the lower the temperature,
        # the stronger the penalty
        min_penalty = 1.05
        max_penalty = 1.5
        repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
    prompt = f"User: {text}\nAssistant: "
    model = text_generator[model_name]["model"]
    tokenizer = text_generator[model_name]["tokenizer"]
    # Move the prompt to the same device as the model so this also works on CPU-only machines
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
    model.eval()
    print("Generating text...")
    print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
          f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")
    time_start = time.time()
    sample_outputs = model.generate(input_ids,
                                    penalty_alpha=penalty_alpha,
                                    do_sample=do_sample,
                                    num_beams=num_beams,
                                    min_length=min_length,
                                    max_length=max_length,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    repetition_penalty=repetition_penalty,
                                    num_return_sequences=1,
                                    max_time=max_time)
    result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    time_end = time.time()
    time_diff = time_end - time_start
    print(f"result:\n{result}")
    # Strip the prompt and cut the completion off at the next simulated "User:" turn, if any
    generated_text = result[len(prompt)+1:]
    user_index = generated_text.find("User:")
    if user_index >= 0:
        generated_text = generated_text[:user_index]
    return {"generated_text": generated_text, "processing_time": time_diff}
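# Example call to the generator endpoint (a sketch, not part of the app: the host name and the
# prompt are placeholders, and "indochat-tiny" must be a model key defined in config.json):
#   curl -X POST http://localhost:8000/api/text-generator/v1 \
#        -F model_name=indochat-tiny \
#        -F "text=Siapa presiden pertama Indonesia?" \
#        -F decoding_method=Sampling -F max_length=250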


def get_text_generator(model_name: str, load_in_8bit: bool = False, device: str = "cpu"):
    hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
    # Avoid printing the token itself; only log whether it is set
    print(f"hf_auth_token is set: {bool(hf_auth_token)}")
    print(f"Loading model with device: {device}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
    # load_in_8bit=True requires the bitsandbytes and accelerate packages and a CUDA GPU;
    # device_map="auto" lets accelerate place the weights, so no explicit model.to(device) is needed
    model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
                                                 load_in_8bit=load_in_8bit, device_map="auto",
                                                 use_auth_token=hf_auth_token)
    # model.to(device)
    print("Model loaded")
    return model, tokenizer


def get_config():
    with open("config.json", "r") as f:
        return json.load(f)


config = get_config()
device = "cuda" if torch.cuda.is_available() else "cpu"
text_generator = {}
for model_name in config["text-generator"]:
    model, tokenizer = get_text_generator(model_name=config["text-generator"][model_name]["name"],
                                          load_in_8bit=config["text-generator"][model_name]["load_in_8bit"],
                                          device=device)
    text_generator[model_name] = {
        "model": model,
        "tokenizer": tokenizer
    }
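# Sketch of the config.json structure this loop expects (the model key and repository name below
# are placeholders inferred from how the config is read, not the actual file contents):
#   {
#       "text-generator": {
#           "indochat-tiny": {
#               "name": "cahya/indochat-tiny",
#               "load_in_8bit": true
#           }
#       }
#   }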