"""FastAPI service exposing a GPT-2 text-generation ("indochat") endpoint.

Routes:
    GET  /                  static chat page
    GET  /api/env           dump of the process environment variables
    WS   /api/ws            echo websocket
    POST /api/indochat/v1   text generation from a form-encoded prompt

The model named in ``config.json`` is loaded once, at import time.
"""
import json
import os
import re
import time
from html import escape as html_escape

import torch
from fastapi import (
    Depends,
    FastAPI,
    Form,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.responses import HTMLResponse
from transformers import (
    AutoConfig,
    AutoTokenizer,
    GPT2LMHeadModel,
    pipeline,
    set_seed,
)

app = FastAPI()

# NOTE(review): this page body looks like HTML whose tags were lost somewhere
# upstream (only the text nodes remain) -- confirm against the original file.
html = """ Chat

WebSocket Chat

"""


@app.get("/")
async def get():
    """Serve the static chat page held in the module-level ``html`` string."""
    return HTMLResponse(html)


@app.get("/api/env")
async def env():
    """Return all process environment variables as an HTML response.

    NOTE(review): this endpoint is unauthenticated and exposes every
    environment variable, including ``HF_AUTH_TOKEN`` when it is set (it is
    read in ``get_text_generator``). Consider removing or protecting it.
    """
    # NOTE(review): the heading and per-line separator appear to have lost
    # their HTML tags; the remaining text and newlines are kept as found.
    parts = ["\n\nEnvironment Variables\n\n"]
    # Names and values are escaped because they are interpolated into HTML.
    parts.extend(
        f"{html_escape(name)}: {html_escape(value)}\n"
        for name, value in os.environ.items()
    )
    return HTMLResponse("".join(parts))


@app.websocket("/api/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Echo every text message received on the websocket back to the client."""
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message text was: {data}")
    except WebSocketDisconnect:
        # The client closed the connection: end the handler quietly instead
        # of surfacing an unhandled exception for every disconnect.
        return


def _auto_repetition_penalty(temperature: float) -> float:
    """Derive a repetition penalty from ``temperature``.

    Interpolates between 1.05 (at temperature 1.0) and 1.5 (at temperature
    0.0), never going below 0.8.
    """
    min_penalty = 1.05
    max_penalty = 1.5
    scaled = min_penalty + (1.0 - temperature) * (max_penalty - min_penalty)
    return max(scaled, 0.8)


@app.post("/api/indochat/v1")
async def indochat(
    text: str = Form(default="", description="The Prompt"),
    max_length: int = Form(default=250, description="Maximal length of the generated text"),
    do_sample: bool = Form(default=True, description="Whether to use sampling; use greedy decoding otherwise"),
    top_k: int = Form(default=50, description="The number of highest probability vocabulary tokens to keep "
                                              "for top-k-filtering"),
    top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with "
                                                  "probabilities that add up to top_p or higher are kept "
                                                  "for generation"),
    temperature: float = Form(default=1.0, description="The Temperature of the softmax distribution"),
    penalty_alpha: float = Form(default=0.6, description="Penalty alpha"),
    repetition_penalty: float = Form(default=1.0, description="Repetition penalty"),
    seed: int = Form(default=42, description="Random Seed"),
    max_time: float = Form(default=60.0, description="Maximal time in seconds to generate the text"),
):
    """Generate an assistant reply for ``text`` with the loaded model.

    The prompt sent to the model is ``"User: {text}\\nAssistant: "``. A
    ``repetition_penalty`` of exactly 0.0 means "choose one automatically
    from the temperature".

    Returns:
        A dict with ``generated_text`` (the decoded output with the first
        ``len(prompt)`` characters removed) and ``processing_time`` (seconds
        spent generating and decoding).
    """
    set_seed(seed)
    if repetition_penalty == 0.0:
        repetition_penalty = _auto_repetition_penalty(temperature)

    prompt = f"User: {text}\nAssistant: "
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    model.eval()

    print("Generating text...")
    print(f"max_length: {max_length}, do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, "
          f"temperature: {temperature}, repetition_penalty: {repetition_penalty}, penalty_alpha: {penalty_alpha}")

    # NOTE(review): generate() is called synchronously inside an async
    # handler, so it presumably blocks the event loop for up to `max_time`
    # seconds -- consider offloading it to a worker thread.
    time_start = time.perf_counter()
    sample_outputs = model.generate(
        input_ids,
        penalty_alpha=penalty_alpha,
        do_sample=do_sample,
        # Clamp the 200-token floor so that a caller-supplied max_length
        # below 200 does not produce contradictory length constraints.
        min_length=min(200, max_length),
        max_length=max_length,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_return_sequences=1,
        max_time=max_time,
    )
    result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    time_diff = time.perf_counter() - time_start

    print(f"result:\n{result}")
    generated_text = result[len(prompt):]
    return {"generated_text": generated_text, "processing_time": time_diff}


def get_text_generator(model_name: str, device: str = "cpu"):
    """Load the tokenizer and GPT-2 LM-head model and move the model to ``device``.

    The Hugging Face auth token is taken from the ``HF_AUTH_TOKEN``
    environment variable; ``False`` is passed when it is unset.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
    # Log only whether a token is configured; the token itself is a secret
    # and must not end up in stdout/log files.
    print(f"hf_auth_token set: {bool(hf_auth_token)}")
    print(f"Loading model with device: {device}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
    model = GPT2LMHeadModel.from_pretrained(
        model_name,
        pad_token_id=tokenizer.eos_token_id,
        use_auth_token=hf_auth_token,
    )
    model.to(device)
    print("Model loaded")
    return model, tokenizer


def get_config():
    """Read and return the parsed contents of ``config.json`` (relative to the CWD)."""
    # Context manager so the file handle is closed deterministically.
    with open("config.json", "r", encoding="utf-8") as config_file:
        return json.load(config_file)


config = get_config()
device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer = get_text_generator(model_name=config["model_name"], device=device)