ariankhalfani commited on
Commit
0409d7f
·
verified ·
1 Parent(s): f72931a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def generate_script(dataset_code, task, model_size, epochs, batch_size):
4
+ # Extract the necessary information from the dataset code
5
+ api_key_match = re.search(r'api_key="(.*?)"', dataset_code)
6
+ workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
7
+ project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
8
+ version_number_match = re.search(r'version\((\d+)\)', dataset_code)
9
+
10
+ if not (api_key_match and workspace_match and project_name_match and version_number_match):
11
+ return "Error: Could not extract necessary information from the dataset code."
12
+
13
+ api_key = api_key_match.group(1)
14
+ workspace = workspace_match.group(1)
15
+ project_name = project_name_match.group(1)
16
+ version_number = int(version_number_match.group(1))
17
+
18
+ # Generate the script
19
+ script = f"""
20
+ import yaml
21
+ from ultralytics import YOLO
22
+ from roboflow import Roboflow
23
+ import logging
24
+ import re
25
+ import threading
26
+ import time
27
+ from io import StringIO
28
+
29
+ # Set up logging
30
+ logging.basicConfig(level=logging.INFO)
31
+ logger = logging.getLogger(__name__)
32
+
33
+ def auto_train():
34
+ log_stream = StringIO()
35
+ log_handler = logging.StreamHandler(log_stream)
36
+ log_handler.setLevel(logging.INFO)
37
+ logger.addHandler(log_handler)
38
+
39
+ try:
40
+ api_key = "{api_key}"
41
+ workspace = "{workspace}"
42
+ project_name = "{project_name}"
43
+ version_number = {version_number}
44
+
45
+ # Load the Roboflow dataset
46
+ rf = Roboflow(api_key=api_key)
47
+ project = rf.workspace(workspace).project(project_name)
48
+ version = project.version(version_number)
49
+ dataset = version.download("yolov8")
50
+
51
+ # Modify the data structure
52
+ yaml_file_path = f'{{dataset.location}}/data.yaml'
53
+ with open(yaml_file_path, 'r') as file:
54
+ data = yaml.safe_load(file)
55
+
56
+ data['val'] = '../valid/images'
57
+ data['test'] = '../test/images'
58
+ data['train'] = '../train/images'
59
+
60
+ with open(yaml_file_path, 'w') as file:
61
+ yaml.safe_dump(data, file)
62
+
63
+ # Determine the model name based on the selected size and task
64
+ model_type = "seg" if task == "Segmentation" else "cls"
65
+ model_name = f"yolov8{model_size[0]}-{model_type}.pt"
66
+
67
+ # Load and train the model
68
+ model = YOLO(model_name)
69
+ model.info()
70
+
71
+ # Function to read logs in real-time and update the Streamlit textbox
72
+ def update_logs():
73
+ while getattr(threading.currentThread(), "do_run", True):
74
+ time.sleep(1)
75
+ log_stream.seek(0)
76
+ print(log_stream.read())
77
+
78
+ # Start a thread to update logs in real-time
79
+ log_thread = threading.Thread(target=update_logs)
80
+ log_thread.start()
81
+
82
+ results = model.train(data=yaml_file_path, epochs={epochs}, imgsz=640, batch={batch_size})
83
+
84
+ # Stop the log update thread
85
+ logger.removeHandler(log_handler)
86
+ log_thread.do_run = False
87
+ log_thread.join()
88
+
89
+ # Return the result path and logs
90
+ log_stream.seek(0)
91
+ log_output = log_stream.read()
92
+ print("Results Directory:", results.results_dir)
93
+ print("Final Training Logs:", log_output)
94
+
95
+ except Exception as e:
96
+ logger.error(f"An error occurred: {{e}}")
97
+ log_stream.seek(0)
98
+ log_output = log_stream.read()
99
+ print(f"Error: {{e}}")
100
+ print(log_output)
101
+
102
+ finally:
103
+ logger.removeHandler(log_handler)
104
+
105
+ if __name__ == "__main__":
106
+ auto_train()
107
+ """
108
+ return script
109
+
110
+ # Streamlit interface
111
+ st.title("Auto Train Script Generator")
112
+ st.write("Generate a YOLOv8 training script using a Roboflow dataset")
113
+
114
+ dataset_code = st.text_input("Roboflow Dataset Code", placeholder="Paste your Roboflow dataset code here")
115
+ task = st.selectbox("Task", ["Object Detection", "Segmentation"], index=0)
116
+ model_size = st.selectbox("Model Size", ["n", "s", "m", "l", "x"], index=0)
117
+ epochs = st.selectbox("Epochs", [50, 100, 200, 300, 400, 500], index=3)
118
+ batch_size = st.selectbox("Batch Size", [1, 2, 4, 8, 16, 32], index=0)
119
+
120
+ if st.button("Generate Script"):
121
+ script = generate_script(dataset_code, task, model_size, epochs, batch_size)
122
+ st.code(script, language="python")