"""Load the merged TinyLlama logger model and summarise one sample log entry.

The script loads the model and tokenizer from ``MERGED_MODEL_PATH``, feeds
``SAMPLE_LOG`` followed by a newline as the prompt, samples up to 60 new
tokens, prints the full decoded output, and finally prints the extracted
summary.
"""
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import traceback

MERGED_MODEL_PATH = "./merged_tinyllama_logger"

# Adjacent literals are concatenated into a single string.
SAMPLE_LOG = (
    "2023-03-06 15:38:41 ERROR [Worker-11] "
    "org.hibernate.exception.ConstraintViolationException "
    "at at com.example.CacheManager.land(CacheManager.java:359) "
    "at at com.example.ShippingService.discover(CacheManager.java:436) "
    "at at com.example.HttpClient.work(DatabaseConnector.java:494) "
    "at at com.example.ShippingService.window(OrderModule.java:378) "
    "at at com.example.CacheManager.almost(DatabaseConnector.java:326) "
    "at at com.example.DatabaseConnector.couple(AuthModule.java:13) "
    "at at com.example.PaymentModule.wrong(HttpClient.java:244)."
)


def _extract_summary(completion):
    """Return the summary portion of the model's completion text.

    The completion is cut just after the last "PM" marker if one is present,
    otherwise just after the last "AM" marker. When neither marker occurs the
    whole completion is returned. Surrounding whitespace is stripped in every
    case.

    Args:
        completion: Decoded text of the newly generated tokens only (the
            prompt must already have been removed).

    Returns:
        The stripped summary string; empty if the completion is blank.
    """
    # "PM" keeps precedence over "AM", matching the original lookup order.
    for marker in ("PM", "AM"):
        marker_pos = completion.rfind(marker)
        if marker_pos != -1:
            return completion[:marker_pos + len(marker)].strip()
    return completion.strip()


# Bind both names up front so a failed load is detected by the None check
# below instead of surfacing as a NameError.
model = None
tokenizer = None
try:
    model = AutoModelForCausalLM.from_pretrained(
        MERGED_MODEL_PATH,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    print("AutoModelForCausalLM loaded successfully.")
    print("Loading AutoTokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_PATH)
    print("AutoTokenizer loaded successfully.")
except Exception:
    print("ERROR LOADING MODEL OR TOKENIZER...CHECK PATH")
    traceback.print_exc()

# The tokenizer is loaded last, so it is still None if either load failed.
if tokenizer is None:
    print("error loading tokenizer")
    raise SystemExit(1)

# generate() is given an explicit pad_token_id below; reuse EOS when the
# tokenizer does not define a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

prompt = SAMPLE_LOG + "\n"
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    return_attention_mask=True,
).to(model.device)

with torch.no_grad():
    output_tokens = model.generate(
        **inputs,
        max_new_tokens=60,
        temperature=0.3,
        do_sample=True,
        top_p=0.9,
        top_k=30,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1,
    )

generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
print(f"Generated Text: {generated_text}")
print("END OF GENERATED TEXT")

# The returned sequence starts with the prompt tokens. Slice them off by token
# count rather than by character offset, so the result does not depend on
# decode() reproducing the prompt text exactly.
prompt_token_count = inputs["input_ids"].shape[-1]
completion_text = tokenizer.decode(
    output_tokens[0][prompt_token_count:],
    skip_special_tokens=True,
)

summary = _extract_summary(completion_text)
print(summary)