import os
import sys
import traceback

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory containing the merged (base + fine-tuned adapter) model.
MERGED_MODEL_PATH = "./merged_tinyllama_logger"

# Sample log line used as the generation prompt.
SAMPLE_LOG = """2023-03-06 15:38:41 ERROR [Worker-11] org.hibernate.exception.ConstraintViolationException at at com.example.CacheManager.land(CacheManager.java:359) at at com.example.ShippingService.discover(CacheManager.java:436) at at com.example.HttpClient.work(DatabaseConnector.java:494) at at com.example.ShippingService.window(OrderModule.java:378) at at com.example.CacheManager.almost(DatabaseConnector.java:326) at at com.example.DatabaseConnector.couple(AuthModule.java:13) at at com.example.PaymentModule.wrong(HttpClient.java:244)."""
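
# Initialise to None so a failed load inside the try block can be detected afterwards.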
model = None
tokenizer = None

try:
    model = AutoModelForCausalLM.from_pretrained(
        MERGED_MODEL_PATH,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    print("AutoModelForCausalLM loaded successfully.")

    print("Loading AutoTokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_PATH)
    print("AutoTokenizer loaded successfully.")
except Exception:
    print("ERROR LOADING MODEL OR TOKENIZER... CHECK PATH")
    traceback.print_exc()

if model is None or tokenizer is None:
    print("Error: model or tokenizer failed to load.")
    sys.exit(1)

# Ensure a pad token exists (reuse EOS) and pad on the left so generated tokens
# follow the prompt directly.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

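# Prompt the model with the raw log line plus a newline; the fine-tuned model is
# expected to continue it with a summary.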
prompt = SAMPLE_LOG + "\n"

inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(model.device)

with torch.no_grad():
    output_tokens = model.generate(
        **inputs,
        max_new_tokens=60,
        temperature=0.3,
        do_sample=True,
        top_p=0.9,
        top_k=30,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1,
    )

generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
print(f"Generated Text: {generated_text}")
print("END OF GENERATED TEXT")

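# Heuristic post-processing: extract the text generated after the prompt, cutting at
# the last "AM"/"PM" occurrence when one appears beyond the prompt.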
summary_start_index = len(SAMPLE_LOG) + 1
summary = ""
summary_end_index = -1

if "PM" in generated_text:
    summary_end_index = generated_text.rfind("PM") + len("PM")
elif "AM" in generated_text:
    summary_end_index = generated_text.rfind("AM") + len("AM")

if summary_end_index != -1 and summary_end_index > summary_start_index:
    summary = generated_text[summary_start_index:summary_end_index].strip()
else:
    # Fall back to everything after the echoed prompt, or to the raw output if the
    # prompt was not echoed back verbatim.
    prompt_end_index = generated_text.find(SAMPLE_LOG + "\n")
    if prompt_end_index != -1:
        summary = generated_text[prompt_end_index + len(SAMPLE_LOG + "\n"):].strip()
    else:
        summary = generated_text.strip()

print(summary)