import os
from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi.responses import JSONResponse
from transformers import (
AutoTokenizer,
LlamaForCausalLM,
AutoModelForCausalLM,
AutoConfig,
)
from accelerate import infer_auto_device_map, init_empty_weights
import time
import torch
import transformers
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
import pynvml
import logging
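# Few-shot example strings used to build the prompts below (defined in examples_metadata.py).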
from examples_metadata import (
bomber_shorten_example,
bomber_format_example,
shirt_format_example,
pants_format_example,
bag_metadata,
dress_example,
bomber_example,
)
logging.basicConfig(level=logging.DEBUG)
os.system("pip list")  # debug: dump the installed packages to the logs

# Print the CUDA driver version and the CUDA version PyTorch was built with
print("GPU CUDA Driver Version:")
nvmlInit()
print(pynvml.nvmlSystemGetCudaDriverVersion())
cuda_version = torch.version.cuda
print("PyTorch is using CUDA Version:", cuda_version)
# Check if CUDA is available and pick the device accordingly
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    device = torch.device("cuda")
    torch.zeros(1).cuda()  # tiny allocation to force CUDA context initialization
else:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")
def print_gpu_utilization():
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(0)
info = nvmlDeviceGetMemoryInfo(handle)
print(f"GPU memory occupied: {info.used//1024**2} MB.")
def print_summary(result):
print(f"Time: {result.metrics['train_runtime']:.2f}")
print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
print_gpu_utilization()
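# Helper used below to strip the echoed prompt (everything up to and including "[/INST]")
# from the pipeline's generated text,
# e.g. remove_before_word("prompt [/INST] answer", "[/INST]") -> " answer"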
def remove_before_word(text, word):
parts = text.split(word)
# Join the parts after the word if the word is found, otherwise return the original text
return "".join(parts[1:]) if word in text else text
def generateMods(
generator,
metadata,
example1="",
example2="",
):
prompt = f"""
Given a fashion item's metadata, propose 5 possible variations of the item.
## Examples
{example1}
{example2}
#### Input
Metadata: {metadata};
#### Output
"""
prompt_template = f"""<s> [INST]
{prompt}
[/INST]
"""
print("before inference")
print_gpu_utilization()
with torch.no_grad():
res = generator(prompt_template)
return res[0]["generated_text"]
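# Note: the "<s> [INST] ... [/INST]" wrapper used above matches the Mistral-7B-Instruct
# chat format (a single user turn); newer transformers releases can build the same string
# with tokenizer.apply_chat_template, but the manual template is kept here.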
def shortenMods(generator, res):
prompt = f"""
Given some possible variations of a fashion item, format the answer in this way:
Reduce every variation to under 4 words.
Use user-friendly terms.
Example:
{bomber_shorten_example}
Input:
{res}
Output:
"""
prompt_template = f"""<s> [INST]
{prompt}
[/INST]
"""
print("before inference")
print_gpu_utilization()
with torch.no_grad():
res = generator(prompt_template)
# print(res)
return res[0]["generated_text"]
def formatMods(generator, res):
prompt = f"""
I have a list like:
-var 1
-var 2
-var 3...
Rewrite the list and put it in square brackets:
[var1, var2, var3, ...]
No code, just the list.
It must begin with "[" and end with "]".
Examples:
{bomber_format_example}
{shirt_format_example}
{pants_format_example}
You:
Input:
{res}
Output:
"""
prompt_template = f"""<s> [INST]
{prompt}
[/INST]
"""
print("before inference")
print_gpu_utilization()
res = generator(prompt_template)
# print(res)
return res[0]["generated_text"]
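# Note: unlike generateMods/shortenMods, formatMods does not wrap the generator call in
# torch.no_grad(); the pipeline's generate() already runs without gradients, so this
# should not change memory behavior.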
def initModel(model_name_or_path, revision):
config = AutoConfig.from_pretrained(model_name_or_path)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
device_map = infer_auto_device_map(model)
print(device_map)  # printed for inspection only; the actual load below relies on device_map="auto"
print("Loading the model")
start = time.time()
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
trust_remote_code=False,
device_map="auto",
offload_folder="offload",
offload_state_dict=True,
revision=revision,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
print("type of model:", type(model))
print("model loaded")
print_gpu_utilization()
generator = transformers.pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.95,
top_k=40,
repetition_penalty=1.1,
)
print("type of model:", type(generator))
end = time.time()
print("Time spent for model loading:", end - start)
print("model initialized, generator created")
print_gpu_utilization()
return generator
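# Minimal usage sketch for the generator (commented out; model name mirrors load_model below):
#   gen = initModel("mistralai/Mistral-7B-Instruct-v0.1", "main")
#   out = gen("<s> [INST] Describe a denim jacket in one sentence. [/INST]")
#   print(out[0]["generated_text"])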
def generateTags():
start = time.time()
res = generateMods(generator, bag_metadata, dress_example, bomber_example)
stripped_res = remove_before_word(res, "[/INST]")
print("generation mods response:")
print(res)
shorten_res = shortenMods(generator, stripped_res)
print("shortened response:")
print(shorten_res)
shorten_res = remove_before_word(shorten_res, "[/INST]")
formatted_res = formatMods(generator, shorten_res)
print("formatted response:")
print(formatted_res)
formatted_res = remove_before_word(formatted_res, "[/INST]")
end = time.time()
print("time spent for generating tags:", end - start)
return {"response": stripped_res, "shortened response:": shorten_res, "formatted response": formatted_res}
app = FastAPI()
@app.on_event("startup")
def load_model():
global generator
model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
revision = "main"
generator = initModel(model_name_or_path, revision)
print("Model loaded")
@app.get("/")
def read_root():
return {"Hello": "World"}
# TODO: add async
@app.get("/generateTags")
def testAPI():
try:
res = generateTags()
return {"response": res}
except Exception as e:
raise HTTPException(
status_code=500, detail=str(e)  # not safe: avoid leaking internal error details to the client
)
@app.exception_handler(Exception)
async def custom_exception_handler(request, exc):
logging.error(f"Error occurred: {exc}")
return JSONResponse(
status_code=500, content={"message": "An internal server error occurred."}
)
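
# To run this API locally (sketch; assumes this file is saved as app.py and uvicorn is installed):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# (7860 is the default port on Hugging Face Spaces), then call the endpoint, for example:
#   curl http://localhost:7860/generateTags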