import logging
import os
import time

import pynvml
import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from examples_metadata import (
    bomber_shorten_example,
    bomber_format_example,
    shirt_format_example,
    pants_format_example,
    bag_metadata,
    dress_example,
    bomber_example,
)
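
# FastAPI service that proposes variations ("mods") for a fashion item by
# chaining three prompts through a Mistral-7B-Instruct text-generation
# pipeline: generate variations -> shorten each into a short tag -> format
# the tags as a bracketed list.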

logging.basicConfig(level=logging.DEBUG)

os.system("pip list")  # dump the installed packages for debugging

print("GPU CUDA driver version:")
pynvml.nvmlInit()
print(pynvml.nvmlSystemGetCudaDriverVersion())
cuda_version = torch.version.cuda
print("PyTorch is using CUDA version:", cuda_version)

if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    device = torch.device("cuda")
    torch.zeros(1).cuda()  # force eager initialization of the CUDA context
else:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")


def print_gpu_utilization():
    """Print the current memory usage of GPU 0, as reported by NVML."""
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used // 1024**2} MB.")


def print_summary(result):
    """Print runtime and throughput metrics from a training result."""
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()


def remove_before_word(text, word):
    """Return the part of `text` after the first occurrence of `word`.

    If `word` is absent, `text` is returned unchanged.
    E.g. remove_before_word("prompt [/INST] answer", "[/INST]") -> " answer".
    """
    if word not in text:
        return text
    # split(word, 1) keeps any later occurrences of `word` intact, instead of
    # silently deleting them as "".join(text.split(word)[1:]) would.
    return text.split(word, 1)[1]


def generateMods(
    generator,
    metadata,
    example1="",
    example2="",
):
    """Ask the model to propose five variations of the item described by `metadata`."""
    prompt = f"""
Given a fashion item's metadata, propose 5 possible variations of the item.

## Examples
{example1}
{example2}

#### Input
Metadata: {metadata};

#### Output
"""

    prompt_template = f"""<s> [INST]
{prompt}
[/INST]
"""
    print("before inference")
    print_gpu_utilization()
    with torch.no_grad():
        res = generator(prompt_template)
    return res[0]["generated_text"]
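
# Note: the "<s> [INST] ... [/INST]" wrapper used in the prompt templates is
# Mistral-Instruct's chat format. A library-supported alternative (a sketch,
# not what this file does) would be to let the tokenizer build it:
#   tokenizer.apply_chat_template([{"role": "user", "content": prompt}],
#                                 tokenize=False)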


def shortenMods(generator, res):
    """Ask the model to compress each proposed variation to under 4 words."""
    prompt = f"""
Given some possible variations of a fashion item, format the answer in this way:
for every variation, shorten it to under 4 words.
Use user-friendly terms.

Example:
{bomber_shorten_example}

Input:
{res}

Output:
"""

    prompt_template = f"""<s> [INST]
{prompt}
[/INST]
"""
    print("before inference")
    print_gpu_utilization()
    with torch.no_grad():
        res = generator(prompt_template)

    return res[0]["generated_text"]


def formatMods(generator, res):
    """Ask the model to rewrite the shortened variations as a bracketed list."""
    prompt = f"""
I have a list like:
- var 1
- var 2
- var 3...

Rewrite the list and put it in square brackets:
[var1, var2, var3, ...]

No code, just the list.
It must begin with "[" and end with "]".

Examples:
{bomber_format_example}
{shirt_format_example}
{pants_format_example}

You:
Input:
{res}

Output:
"""

    prompt_template = f"""<s> [INST]
{prompt}
[/INST]
"""
    print("before inference")
    print_gpu_utilization()
    with torch.no_grad():  # inference only, for consistency with the other helpers
        res = generator(prompt_template)

    return res[0]["generated_text"]
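

# A hedged sketch (hypothetical helper, not used by the endpoints below):
# because formatMods asks the model for a plain bracketed list, its output
# can be parsed into Python strings once the brackets are located.
def parse_tag_list(formatted):
    """Hypothetical: extract ["tag", ...] from a model response, else None."""
    start, end = formatted.find("["), formatted.rfind("]")
    if start == -1 or end <= start:
        return None  # the model did not follow the bracket format
    inner = formatted[start + 1 : end]
    return [tag.strip().strip("\"'") for tag in inner.split(",") if tag.strip()]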


def initModel(model_name_or_path, revision):
    """Load the model and tokenizer and wrap them in a text-generation pipeline."""
    # Inspect the device map that accelerate would infer; this is only printed
    # for debugging -- from_pretrained below uses device_map="auto" directly.
    config = AutoConfig.from_pretrained(model_name_or_path)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    device_map = infer_auto_device_map(model)
    print(device_map)

    print("Loading the model")
    start = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=False,
        device_map="auto",
        offload_folder="offload",
        offload_state_dict=True,
        revision=revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    print("type of model:", type(model))
    print("model loaded")
    print_gpu_utilization()

    generator = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
    )
    print("type of generator:", type(generator))

    end = time.time()
    print("Time spent for model loading:", end - start)
    print("model initialized, generator created")
    print_gpu_utilization()
    return generator


def generateTags():
    """Run the three-step prompt chain and return intermediate and final results.

    The pipeline returns the prompt together with the completion, so after
    each step everything up to the closing [/INST] tag is stripped off before
    the text is fed into the next prompt.
    """
    start = time.time()
    res = generateMods(generator, bag_metadata, dress_example, bomber_example)
    stripped_res = remove_before_word(res, "[/INST]")
    print("generation mods response:")
    print(res)
    shorten_res = shortenMods(generator, stripped_res)
    print("shortened response:")
    print(shorten_res)
    shorten_res = remove_before_word(shorten_res, "[/INST]")
    formatted_res = formatMods(generator, shorten_res)
    print("formatted response:")
    print(formatted_res)
    formatted_res = remove_before_word(formatted_res, "[/INST]")
    end = time.time()
    print("time spent for generating tags:", end - start)
    return {
        "response": stripped_res,
        "shortened response": shorten_res,
        "formatted response": formatted_res,
    }


app = FastAPI()


@app.on_event("startup")
def load_model():
    """Load the model once at application startup and share it as a global."""
    global generator
    model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
    revision = "main"
    generator = initModel(model_name_or_path, revision)
    print("Model loaded")


@app.get("/")
def read_root():
    return {"Hello": "World"}


@app.get("/generateTags")
def testAPI():
    try:
        # generateTags already wraps its output in a dict, so it is returned
        # directly rather than nesting it under another "response" key.
        return generateTags()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
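
# Example call once the server is running (hypothetical: the port depends on
# how the server is launched and is not set anywhere in this file):
#   curl http://localhost:8000/generateTags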


@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    """Log unhandled exceptions and return a generic 500 response."""
    logging.error(f"Error occurred: {exc}")
    return JSONResponse(
        status_code=500, content={"message": "An internal server error occurred."}
    )
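

# Minimal local entry point (a sketch: assumes uvicorn is installed; the
# host/port choices below are assumptions, not taken from this file).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)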