from pipeline.workflow import recognize_graph
from pipeline.commons import here
import os
import jinja2
import pickle
import pm4py
from pm4py.visualization.petri_net import visualizer as pn_visualizer
from collections import Counter, defaultdict
from typing import List, Tuple
from pipeline.models import Place, Transition, Arc
from groq import Groq
import os
import base64
from config.path_config import (
OUTPUT_DIR, TEMPLATES_DIR, VISUALIZATIONS_DIR, PIPELINE_OUTPUT_DIR,
PLACES_PKL_PATH, TRANSITIONS_PKL_PATH, ARCS_PKL_PATH,
PLACES_FIXED_PKL_PATH, TRANSITIONS_FIXED_PKL_PATH, ARCS_FIXED_PKL_PATH,
OUTPUT_PNML_PATH, OUTPUT_PETRIOBJ_PATH, OUTPUT_JSON_PATH, OUTPUT_PNG_PATH, OUTPUT_GV_PATH,
WORKING_IMAGE_PATH, ensure_directories_exist, get_visualization_path, get_output_file_path
)
def process_elements(places: List[Place], transitions: List[Transition], arcs: List[Arc]) -> Tuple[List[Place], List[Transition], List[Arc]]:
    """Prune structurally invalid elements from the recognized net.

    Three repairs are applied, in order:
    1. places touched by no arc are dropped (together with any arc that
       referenced a dropped place);
    2. transitions attached to fewer than two arcs are dropped, along with
       their attached arcs;
    3. every surviving transition is guaranteed at least one input and one
       output arc by reversing a single arc in place when all of its arcs
       point the same way.

    Returns the filtered (places, transitions, arcs) lists; arc objects may
    be mutated (source/target swapped) by step 3.
    """
    def _touches(arc, node):
        # An arc is attached to a node if it starts or ends there.
        return arc.source == node or arc.target == node

    # Step 1: drop places that no arc touches, then drop arcs that
    # referenced a dropped place.
    orphan_places = {p for p in places if not any(_touches(a, p) for a in arcs)}
    kept_places = [p for p in places if p not in orphan_places]
    surviving_arcs = [
        a for a in arcs
        if a.source not in orphan_places and a.target not in orphan_places
    ]

    # Step 2: drop transitions with fewer than two attached arcs, together
    # with those attached arcs.
    doomed_transitions = set()
    doomed_arcs = set()
    for t in transitions:
        attached = [a for a in surviving_arcs if _touches(a, t)]
        if len(attached) < 2:
            doomed_transitions.add(t)
            doomed_arcs.update(attached)
    kept_transitions = [t for t in transitions if t not in doomed_transitions]
    final_arcs = [a for a in surviving_arcs if a not in doomed_arcs]

    # Step 3: if a surviving transition has only inputs (or only outputs),
    # reverse the first attached arc pointing the wrong way.
    for t in kept_transitions:
        attached = [a for a in final_arcs if _touches(a, t)]
        n_out = sum(a.source == t for a in attached)
        n_in = sum(a.target == t for a in attached)
        if n_out == 0 and n_in >= 1:
            flip = next(a for a in attached if a.target == t)
            flip.source, flip.target = flip.target, flip.source
        elif n_in == 0 and n_out >= 1:
            flip = next(a for a in attached if a.source == t)
            flip.source, flip.target = flip.target, flip.source
    return kept_places, kept_transitions, final_arcs
def fix_petri_net() -> None:
    """Check the recognized Petri net for errors, log them, and apply fixes where one is readily available.

    Reads the raw places/transitions/arcs pickles (written by
    run_and_save_pipeline) and writes the repaired element lists to the
    *_FIXED_* pickle paths consumed by the rendering steps.

    Fixes applied: same-type arcs dropped, non-positive arc weights raised
    to 1, parallel arcs merged, hanging places and under-connected
    transitions removed. Duplicate IDs are only logged, not repaired.
    """
    with open(PLACES_PKL_PATH, "rb") as f:
        places = pickle.load(f)
    with open(TRANSITIONS_PKL_PATH, "rb") as f:
        transitions = pickle.load(f)
    with open(ARCS_PKL_PATH, "rb") as f:
        arcs = pickle.load(f)
    ### Detect duplicate ids across places, transitions and arcs.
    ### NOTE(review): duplicates are only reported here — no fix is applied.
    all_ids = []
    all_ids.extend(place.id for place in places)
    all_ids.extend(transition.id for transition in transitions)
    all_ids.extend(arc.id for arc in arcs)
    # (`id` shadows the builtin here; harmless in this comprehension scope.)
    id_duplicates = [id for id, count in Counter(all_ids).items() if count > 1]
    for duplicate_id in id_duplicates:
        duplicate_elements = []
        duplicate_elements.extend([place for place in places if place.id == duplicate_id])
        duplicate_elements.extend([transition for transition in transitions if transition.id == duplicate_id])
        duplicate_elements.extend([arc for arc in arcs if arc.id == duplicate_id])
        print(f"Duplicate ID {duplicate_id} found in elements: {duplicate_elements}")
    ### Drop arcs that connect two nodes of the same kind (place->place or
    ### transition->transition) — a valid Petri net is bipartite. This also
    ### removes self-loop arcs, since a self-loop's endpoints share a type.
    arcs = [arc for arc in arcs if type(arc.source) != type(arc.target)]
    ### Fix weights if any are less than 1 (arc weights must be positive)
    for arc in arcs:
        if arc.weight < 1:
            print(f"Arc found with weight less than 1: {arc}")
            print("Applying fix to set the weight to 1")
            arc.weight = 1
    ### Merge parallel arcs (same source AND same target) into a single arc
    ### carrying the sum of the group's weights.
    # Group arcs by their (source id, target id) pair.
    arc_groups = defaultdict(list)
    for arc in arcs:
        key = (arc.source.id, arc.target.id)
        arc_groups[key].append(arc)
    # For each group of arcs with same source/target, merge them
    for (source_id, target_id), group in arc_groups.items():
        if len(group) > 1:
            print(f"Found {len(group)} parallel arcs between same source and target: {source_id} -> {target_id}")
            total_weight = sum(arc.weight for arc in group)
            # Reuse the first arc of the group as the merged arc.
            merged_arc = group[0]
            merged_arc.weight = total_weight
            # Remove the remaining arcs of the group from the list.
            for arc in group[1:]:
                if arc in arcs:
                    arcs.remove(arc)
    ### There should be no hanging places, every place must have at least one arc
    ### Every transition must have at least one input and one output arc
    # places, transitions, arcs = sanitize_petri_net(places, transitions, arcs)
    places, transitions, arcs = process_elements(places, transitions, arcs)
    print(f"len(places): {len(places)}")
    print(f"len(transitions): {len(transitions)}")
    print(f"len(arcs): {len(arcs)}")
    ### Persist the repaired element lists for the rendering stages.
    ensure_directories_exist()
    with open(PLACES_FIXED_PKL_PATH, "wb") as f:
        pickle.dump(places, f)
    with open(TRANSITIONS_FIXED_PKL_PATH, "wb") as f:
        pickle.dump(transitions, f)
    with open(ARCS_FIXED_PKL_PATH, "wb") as f:
        pickle.dump(arcs, f)
def run_and_save_pipeline(config_path: str, image_path: str):
    """Run the graph-recognition pipeline on an image and persist its outputs.

    Saves every intermediate visualization image plus the recognized places,
    transitions and arcs (as pickles) under the configured output paths, and
    prints a summary of what was found.
    """
    result = recognize_graph(image_path, config_path)
    places = result["places"]
    transitions = result["transitions"]
    arcs = result["arcs"]

    ensure_directories_exist()
    # Persist every intermediate visualization produced by the pipeline.
    for name, image in result["visualizations"].items():
        image.save(get_visualization_path(name))
    # Persist the recognized elements for the later fix/render stages.
    for path, payload in (
        (PLACES_PKL_PATH, places),
        (TRANSITIONS_PKL_PATH, transitions),
        (ARCS_PKL_PATH, arcs),
    ):
        with open(path, "wb") as fh:
            pickle.dump(payload, fh)
    print(f"Recognition complete. Found {len(places)} places, {len(transitions)} transitions, and {len(arcs)} arcs.")
def render_diagram_to(file_type: str) -> str:
    """Render the fixed Petri-net elements into an output file via a Jinja template.

    Parameters
    ----------
    file_type : str
        The target format, either ``"pnml"`` or ``"petriobj"``. Selects the
        template ``template.<file_type>.jinja`` and the output file name
        ``output.<file_type>``.

    Returns
    -------
    str
        The rendered document; it is also written to
        ``get_output_file_path(f"output.{file_type}")``.

    Raises
    ------
    ValueError
        If ``file_type`` is not a supported format.
    """
    # Fail fast on unsupported formats, before touching the filesystem.
    if file_type not in ("pnml", "petriobj"):
        raise ValueError(f"Invalid file type: {file_type}")
    with open(PLACES_FIXED_PKL_PATH, "rb") as f:
        places = pickle.load(f)
    with open(TRANSITIONS_FIXED_PKL_PATH, "rb") as f:
        transitions = pickle.load(f)
    with open(ARCS_FIXED_PKL_PATH, "rb") as f:
        arcs = pickle.load(f)
    template_env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(searchpath=TEMPLATES_DIR)
    )
    template = template_env.get_template(f"template.{file_type}.jinja")
    context = {"places": places, "transitions": transitions, "arcs": arcs}
    if file_type == "petriobj":
        # The petriobj template addresses elements by positional index.
        context["place_to_index"] = {place.id: index for index, place in enumerate(places)}
        context["transition_to_index"] = {transition.id: index for index, transition in enumerate(transitions)}
    output_text = template.render(context)
    output_file_path = get_output_file_path(f"output.{file_type}")
    with open(output_file_path, "w", encoding="utf-8") as f:
        f.write(output_text)
    return output_text
def render_to_graphviz() -> None:
    """Read the generated PNML file and export Graphviz (.gv) and PNG renderings.

    Expects render_diagram_to("pnml") to have written OUTPUT_PNML_PATH first.
    """
    net, im, fm = pm4py.read_pnml(OUTPUT_PNML_PATH)
    # Graphviz source of the net, saved as a .gv file.
    gviz = pn_visualizer.apply(net, im, fm)
    pn_visualizer.save(gviz, OUTPUT_GV_PATH)
    # Also save a rasterized PNG of the same net.
    pm4py.save_vis_petri_net(net, im, fm, OUTPUT_PNG_PATH)
def render_to_json() -> None:
    """Ask a Groq-hosted vision model to transcribe the working image into JSON.

    Sends the base64-encoded Petri-net image together with a JSON schema
    (carried in the system prompt) to the Groq chat-completions API and
    writes the model's JSON answer verbatim to OUTPUT_JSON_PATH.

    Raises
    ------
    ValueError
        If the GROQ_API_KEY environment variable is not set.
    """
    def encode_image(image_path):
        # Return the file's contents as a base64 string, as the API expects.
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    image_path = WORKING_IMAGE_PATH
    base64_image = encode_image(image_path)
    # Load API key from environment variable
    groq_api_key = os.getenv('GROQ_API_KEY')
    if not groq_api_key:
        raise ValueError("GROQ_API_KEY environment variable is not set. Please check your .env file.")
    client = Groq(api_key=groq_api_key)
    completion = client.chat.completions.create(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        messages=[
            {
                # The system message holds the JSON template the model must follow.
                "role": "system",
                "content": "{\n  \"places\": [\n    {\"id\": \"string\", \"tokens\": \"integer\"}\n  ],\n  \"transitions\": [\n    {\"id\": \"string\", \"delay\": \"number_or_string\"}\n  ],\n  \"arcs\": [\n    {\"source\": \"string\", \"target\": \"string\", \"weight\": \"integer\"}\n  ]\n}"
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Take image of a Petri net as input and provide the textual representation of the graph in json format, according to this json template."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            # NOTE(review): MIME type is hard-coded to image/jpeg —
                            # confirm WORKING_IMAGE_PATH is actually a JPEG.
                            "url": f"data:image/jpeg;base64,{base64_image}",
                        },
                    }
                ]
            }
        ],
        temperature=1,
        max_completion_tokens=2048,
        top_p=1,
        stream=False,
        # Forces the model to answer with a single JSON object.
        response_format={"type": "json_object"},
        stop=None,
    )
    output_file_path = OUTPUT_JSON_PATH
    # Write the model's answer verbatim; no validation is performed here.
    with open(output_file_path, "w", encoding="utf-8") as f:
        f.write(completion.choices[0].message.content)
if __name__ == "__main__":
    # Step 1 (recognition) is normally run once to produce the raw pickles:
    # run_and_save_pipeline(config_path=here("../data/config.yaml"), image_path=here("../data/local/mid_petri_2.png"))
    fix_petri_net()
    ## the next steps are independent of each other and could run in parallel
    render_diagram_to("pnml")
    render_diagram_to("petriobj")
    render_to_graphviz()
    render_to_json()