# sketch2pnml/endpoints/converter.py
import base64
import os
import pickle
from collections import Counter, defaultdict
from typing import List, Tuple

import jinja2
import pm4py
from groq import Groq
from pm4py.visualization.petri_net import visualizer as pn_visualizer

from config.path_config import (
    OUTPUT_DIR, TEMPLATES_DIR, VISUALIZATIONS_DIR, PIPELINE_OUTPUT_DIR,
    PLACES_PKL_PATH, TRANSITIONS_PKL_PATH, ARCS_PKL_PATH,
    PLACES_FIXED_PKL_PATH, TRANSITIONS_FIXED_PKL_PATH, ARCS_FIXED_PKL_PATH,
    OUTPUT_PNML_PATH, OUTPUT_PETRIOBJ_PATH, OUTPUT_JSON_PATH, OUTPUT_PNG_PATH, OUTPUT_GV_PATH,
    WORKING_IMAGE_PATH, ensure_directories_exist, get_visualization_path, get_output_file_path,
)
from pipeline.commons import here
from pipeline.models import Place, Transition, Arc
from pipeline.workflow import recognize_graph
def process_elements(places: List[Place], transitions: List[Transition], arcs: List[Arc]) -> Tuple[List[Place], List[Transition], List[Arc]]:
    """Drop disconnected places and under-connected transitions, then ensure
    every remaining transition has at least one incoming and one outgoing arc."""
    # Remove places that have no connected arcs
places_to_remove = set()
for place in places:
if not any(arc.source == place or arc.target == place for arc in arcs):
places_to_remove.add(place)
new_places = [p for p in places if p not in places_to_remove]
# Remove arcs connected to removed places
arcs_after_places = [arc for arc in arcs if arc.source not in places_to_remove and arc.target not in places_to_remove]
# Process transitions to remove those with less than two connected arcs
transitions_to_remove = set()
arcs_to_remove = set()
for transition in transitions:
connected_arcs = [arc for arc in arcs_after_places if arc.source == transition or arc.target == transition]
if len(connected_arcs) < 2:
transitions_to_remove.add(transition)
arcs_to_remove.update(connected_arcs)
new_transitions = [t for t in transitions if t not in transitions_to_remove]
arcs_after_transitions = [arc for arc in arcs_after_places if arc not in arcs_to_remove]
# Adjust transitions to have both incoming and outgoing arcs
for transition in new_transitions:
connected_arcs = [arc for arc in arcs_after_transitions if arc.source == transition or arc.target == transition]
outgoing = sum(1 for arc in connected_arcs if arc.source == transition)
incoming = sum(1 for arc in connected_arcs if arc.target == transition)
if outgoing == 0 and incoming >= 1:
# Flip one incoming arc to outgoing
for arc in connected_arcs:
if arc.target == transition:
arc.source, arc.target = arc.target, arc.source
break
elif incoming == 0 and outgoing >= 1:
# Flip one outgoing arc to incoming
for arc in connected_arcs:
if arc.source == transition:
arc.source, arc.target = arc.target, arc.source
break
return new_places, new_transitions, arcs_after_transitions
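

# A minimal sanity-check sketch (an addition, not part of the original
# pipeline): it verifies the invariants process_elements is meant to
# establish, using only the source/target attributes already relied on above.
def check_net_invariants(places: List[Place], transitions: List[Transition], arcs: List[Arc]) -> None:
    for place in places:
        # Every place must touch at least one arc
        assert any(arc.source == place or arc.target == place for arc in arcs), f"Hanging place: {place}"
    for transition in transitions:
        # Every transition must have at least one incoming and one outgoing arc
        assert any(arc.target == transition for arc in arcs), f"No input arc: {transition}"
        assert any(arc.source == transition for arc in arcs), f"No output arc: {transition}"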
def fix_petri_net():
    """Check the recognized Petri net for errors, log them, and apply fixes
    where one is readily available."""
with open(PLACES_PKL_PATH, "rb") as f:
places = pickle.load(f)
with open(TRANSITIONS_PKL_PATH, "rb") as f:
transitions = pickle.load(f)
with open(ARCS_PKL_PATH, "rb") as f:
arcs = pickle.load(f)
    ### Log duplicate ids across places, transitions and arcs (no automatic fix yet)
all_ids = []
all_ids.extend(place.id for place in places)
all_ids.extend(transition.id for transition in transitions)
all_ids.extend(arc.id for arc in arcs)
id_duplicates = [id for id, count in Counter(all_ids).items() if count > 1]
for duplicate_id in id_duplicates:
duplicate_elements = []
duplicate_elements.extend([place for place in places if place.id == duplicate_id])
duplicate_elements.extend([transition for transition in transitions if transition.id == duplicate_id])
duplicate_elements.extend([arc for arc in arcs if arc.id == duplicate_id])
print(f"Duplicate ID {duplicate_id} found in elements: {duplicate_elements}")
    ### Drop invalid arcs connecting two elements of the same type
    ### (place-place or transition-transition); this also removes self-loops
arcs = [arc for arc in arcs if type(arc.source) != type(arc.target)]
### Fix weights if any are less than 1
for arc in arcs:
if arc.weight < 1:
print(f"Arc found with weight less than 1: {arc}")
print("Applying fix to set the weight to 1")
arc.weight = 1
    ### Merge parallel arcs: when several arcs share the same source and target,
    ### collapse them into a single arc whose weight is the sum of the group
    # Group arcs by their (source, target) pair
    arc_groups = defaultdict(list)
    for arc in arcs:
        key = (arc.source.id, arc.target.id)
        arc_groups[key].append(arc)
# For each group of arcs with same source/target, merge them
for (source_id, target_id), group in arc_groups.items():
if len(group) > 1:
print(f"Found {len(group)} parallel arcs between same source and target: {source_id} -> {target_id}")
total_weight = sum(arc.weight for arc in group)
merged_arc = group[0]
merged_arc.weight = total_weight
# Remove other arcs from the original list
for arc in group[1:]:
if arc in arcs:
arcs.remove(arc)
### There should be no hanging places, every place must have at least one arc
### Every transition must have at least one input and one output arc
# places, transitions, arcs = sanitize_petri_net(places, transitions, arcs)
places, transitions, arcs = process_elements(places, transitions, arcs)
print(f"len(places): {len(places)}")
print(f"len(transitions): {len(transitions)}")
print(f"len(arcs): {len(arcs)}")
### save results as pickles
ensure_directories_exist()
with open(PLACES_FIXED_PKL_PATH, "wb") as f:
pickle.dump(places, f)
with open(TRANSITIONS_FIXED_PKL_PATH, "wb") as f:
pickle.dump(transitions, f)
with open(ARCS_FIXED_PKL_PATH, "wb") as f:
pickle.dump(arcs, f)
def run_and_save_pipeline(config_path: str, image_path: str):
    """Run the recognition pipeline on an image and persist the raw results."""
    result = recognize_graph(image_path, config_path)
    # Unpack the recognition results
places = result["places"]
transitions = result["transitions"]
arcs = result["arcs"]
# Saving logic
ensure_directories_exist()
for name, img in result["visualizations"].items():
img.save(get_visualization_path(name))
with open(PLACES_PKL_PATH, "wb") as f:
pickle.dump(places, f)
with open(TRANSITIONS_PKL_PATH, "wb") as f:
pickle.dump(transitions, f)
with open(ARCS_PKL_PATH, "wb") as f:
pickle.dump(arcs, f)
print(f"Recognition complete. Found {len(places)} places, {len(transitions)} transitions, and {len(arcs)} arcs.")
def render_diagram_to(file_type: str):
    """Render the fixed elements into the final model string and write it to disk.

    Parameters
    ----------
    file_type : str
        The template to use: "pnml" or "petriobj".

    Returns
    -------
    str
        The rendered model.
    """
with open(PLACES_FIXED_PKL_PATH, "rb") as f:
places = pickle.load(f)
with open(TRANSITIONS_FIXED_PKL_PATH, "rb") as f:
transitions = pickle.load(f)
with open(ARCS_FIXED_PKL_PATH, "rb") as f:
arcs = pickle.load(f)
template_loader = jinja2.FileSystemLoader(
searchpath=TEMPLATES_DIR
)
template_env = jinja2.Environment(loader=template_loader)
if file_type == "pnml":
template = template_env.get_template(f"template.{file_type}.jinja")
output_text = template.render({"places": places, "transitions": transitions, "arcs": arcs})
output_file_path = get_output_file_path(f"output.{file_type}")
with open(output_file_path, "w", encoding="utf-8") as f:
f.write(output_text)
elif file_type == "petriobj":
template = template_env.get_template(f"template.{file_type}.jinja")
place_to_index = {place.id: index for index, place in enumerate(places)}
transition_to_index = {transition.id: index for index, transition in enumerate(transitions)}
output_text = template.render({"places": places, "transitions": transitions, "arcs": arcs, "place_to_index": place_to_index, "transition_to_index": transition_to_index})
output_file_path = get_output_file_path(f"output.{file_type}")
with open(output_file_path, "w", encoding="utf-8") as f:
f.write(output_text)
else:
raise ValueError(f"Invalid file type: {file_type}")
return output_text
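

# The real templates live in TEMPLATES_DIR and are not reproduced here. As a
# rough, hypothetical illustration of the shape render_diagram_to expects, a
# PNML template could look like the sketch below; it assumes only the id,
# source, target and weight attributes already used throughout this module.
_EXAMPLE_PNML_TEMPLATE_SKETCH = """\
<?xml version="1.0" encoding="UTF-8"?>
<pnml>
  <net id="net1" type="http://www.pnml.org/version-2009/grammar/ptnet">
    <page id="page1">
      {% for place in places %}<place id="{{ place.id }}"/>
      {% endfor %}{% for transition in transitions %}<transition id="{{ transition.id }}"/>
      {% endfor %}{% for arc in arcs %}<arc id="{{ arc.id }}" source="{{ arc.source.id }}" target="{{ arc.target.id }}">
        <inscription><text>{{ arc.weight }}</text></inscription>
      </arc>
      {% endfor %}</page>
  </net>
</pnml>
"""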
def render_to_graphviz():
    """Read the generated PNML and render it as Graphviz (.gv) and PNG files."""
net, im, fm = pm4py.read_pnml(OUTPUT_PNML_PATH)
gviz = pn_visualizer.apply(net, im, fm)
pn_visualizer.save(gviz, OUTPUT_GV_PATH)
pm4py.save_vis_petri_net(net, im, fm, OUTPUT_PNG_PATH)
def render_to_json():
    """Send the working image to a vision model and save its JSON description of the net."""
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
image_path = WORKING_IMAGE_PATH
base64_image = encode_image(image_path)
# Load API key from environment variable
groq_api_key = os.getenv('GROQ_API_KEY')
if not groq_api_key:
raise ValueError("GROQ_API_KEY environment variable is not set. Please check your .env file.")
client = Groq(api_key=groq_api_key)
completion = client.chat.completions.create(
model="meta-llama/llama-4-scout-17b-16e-instruct",
messages=[
{
"role": "system",
"content": "{\n \"places\": [\n {\"id\": \"string\", \"tokens\": \"integer\"}\n ],\n \"transitions\": [\n {\"id\": \"string\", \"delay\": \"number_or_string\"}\n ],\n \"arcs\": [\n {\"source\": \"string\", \"target\": \"string\", \"weight\": \"integer\"}\n ]\n}"
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "Take image of a Petri net as input and provide the textual representation of the graph in json format, according to this json template."
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
},
}
]
}
],
temperature=1,
max_completion_tokens=2048,
top_p=1,
stream=False,
response_format={"type": "json_object"},
stop=None,
)
output_file_path = OUTPUT_JSON_PATH
with open(output_file_path, "w", encoding="utf-8") as f:
f.write(completion.choices[0].message.content)
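

# A small, optional validation sketch (an addition, not in the original flow):
# it parses the model's reply and checks for the top-level keys promised by
# the system prompt before the file is trusted downstream. Stdlib only.
def validate_net_json(raw: str) -> dict:
    import json

    data = json.loads(raw)  # raises json.JSONDecodeError on malformed output
    for key in ("places", "transitions", "arcs"):
        assert key in data, f"Missing key in model output: {key}"
    return data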
if __name__ == "__main__":
# run_and_save_pipeline(config_path=here("../data/config.yaml"), image_path=here("../data/local/mid_petri_2.png"))
fix_petri_net()
    ## The following steps are independent and could run in parallel
render_diagram_to("pnml")
render_diagram_to("petriobj")
render_to_graphviz()
render_to_json()
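
    # Typical end-to-end run (assuming GROQ_API_KEY is set and the recognition
    # step above has been executed at least once, so the pickles exist):
    #   python endpoints/converter.py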