import argparse
import os
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Dict, List, Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
from transformers import AutoModelForCausalLM, AutoTokenizer

from custom_llm_inference import get_highlights_inner, get_next_token_predictions_inner
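
# Holds the loaded tokenizer and model; populated at startup by models_lifespan below.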
ml_models = {}

parser = argparse.ArgumentParser()
parser.add_argument("--gpu", action="store_true", help="Enable GPU usage")
args = parser.parse_args()

USE_GPU = args.gpu

if not USE_GPU:
    print("Running without GPU. To enable GPU, run with the --gpu flag.")
@asynccontextmanager
async def models_lifespan(app: FastAPI):
    #model_name = 'google/gemma-1.1-7b-it'
    #model_name = 'google/gemma-1.1-2b-it'
    model_name = 'google/gemma-2-9b-it'

    dtype = torch.bfloat16 if USE_GPU else torch.float16

    ml_models["llm"] = llm = {
        'tokenizer': AutoTokenizer.from_pretrained(model_name),
        'model': AutoModelForCausalLM.from_pretrained(model_name, device_map="auto" if USE_GPU else "cpu", torch_dtype=dtype)
    }
    print("Loaded llm with device map:")
    print(llm['model'].hf_device_map)

    # Print timing info for each endpoint
    print("\nRunning endpoint tests...")

    test_doc = "This is a test document that needs to be revised for clarity and conciseness."
    test_prompt = "Make this more clear and concise."

    client = TestClient(app)

    start = time.time()
    response = client.get("/api/highlights",
        params={"doc": test_doc, "prompt": test_prompt})
    print(f"Highlights endpoint: {time.time() - start:.2f}s")

    start = time.time()
    response = client.get("/api/next_token",
        params={"original_doc": test_doc, "prompt": test_prompt, "doc_in_progress": "This is"})
    print(f"Next token endpoint: {time.time() - start:.2f}s")

    start = time.time()
    response = client.get("/api/gen_revisions",
        params={"doc": test_doc, "prompt": test_prompt, "n": 1})
    print(f"Gen revisions endpoint: {time.time() - start:.2f}s")

    yield

    # Release resources on exit
    ml_models.clear()


DEBUG = os.getenv("DEBUG") or False
PORT = int(os.getenv("PORT") or "19570")

app = FastAPI(lifespan=models_lifespan)

origins = [
    "*",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/api/highlights")
def get_highlights(doc: str, prompt: Optional[str] = None, updated_doc: Optional[str] = '', k: Optional[int] = 5):
    ''' Example of using this in JavaScript:
    let url = new URL('http://localhost:8000/api/highlights')
    url.searchParams.append('doc', 'This is a test document. It is a test document because it is a test document.')
    url.searchParams.append('prompt', 'Rewrite this document to be more concise.')
    url.searchParams.append('updated_doc', 'This is a test document.')
    let response = await fetch(url)
    '''

    llm = ml_models['llm']
    model = llm['model']
    tokenizer = llm['tokenizer']

    if prompt is None:
        prompt = "Rewrite this document to be more concise."

    highlights = get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k)

    return {'highlights': highlights}


@app.get("/api/next_token")
def get_next_token_predictions(original_doc: str,
                               prompt: str,
                               doc_in_progress: str,
                               k: Optional[int] = 5):

    model = ml_models['llm']['model']
    tokenizer = ml_models['llm']['tokenizer']
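    # The shared inference helper returns the decoded top-k candidate next tokens and their logits.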
    decoded_next_tokens, next_token_logits = get_next_token_predictions_inner(
        model, tokenizer, original_doc, prompt, doc_in_progress, k)

    return {
        'next_tokens': decoded_next_tokens
    }


@app.get("/api/gen_revisions")
def gen_revisions(
        prompt: str,
        doc: str,
        n: Optional[int] = 5):

    model = ml_models['llm']['model']
    tokenizer = ml_models['llm']['tokenizer']

    messages = [
        {
            "role": "user",
            "content": f"{prompt}\n\n{doc}",
        },
    ]
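    # Format the prompt and document as a single user turn via the model's chat template,
    # then sample n candidate revisions.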
    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)

    generations = model.generate(
        tokenized_chat, num_return_sequences=n,
        max_length=1024, do_sample=True, top_k=50, top_p=0.95, temperature=0.5,
        return_dict_in_generate=True, output_scores=True)

    generated_docs = tokenizer.batch_decode(generations.sequences, skip_special_tokens=True)
    #print(generations.scores)

    # Remove prompt text. see https://github.com/huggingface/transformers/blob/v4.46.2/src/transformers/pipelines/text_generation.py#L37
    prompt_length = len(
        tokenizer.decode(
            tokenized_chat[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ))

    return {
        'revised_docs': [dict(doc_text=doc[prompt_length:]) for doc in generated_docs]
    }


if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=PORT)
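
# A minimal client sketch, assuming the server is running locally on the default
# PORT above (19570) and the `requests` package is available; the parameters
# mirror the get_highlights signature:
#
#   import requests
#   resp = requests.get(
#       "http://localhost:19570/api/highlights",
#       params={"doc": "This is a test document.", "prompt": "Rewrite this document to be more concise."},
#   )
#   print(resp.json()["highlights"])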