Spaces:
Running
Running
Commit
•
9119567
1
Parent(s):
218a3ef
Update app.py
Browse files
app.py
CHANGED
@@ -1,411 +1,27 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
import
|
4 |
-
import
|
5 |
-
|
6 |
-
from tqdm.notebook import tqdm
|
7 |
-
|
8 |
-
import PIL.Image
|
9 |
-
|
10 |
-
import transformers
|
11 |
-
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, TrainingArguments, Trainer
|
12 |
-
|
13 |
-
import datasets
|
14 |
-
from datasets import load_dataset, Features, Array3D, DatasetDict, ClassLabel
|
15 |
-
|
16 |
-
import torch
|
17 |
-
|
18 |
-
from torchvision.transforms import (
|
19 |
-
CenterCrop,
|
20 |
-
Compose,
|
21 |
-
Normalize,
|
22 |
-
RandomHorizontalFlip,
|
23 |
-
RandomResizedCrop,
|
24 |
-
Resize,
|
25 |
-
ToTensor,
|
26 |
-
)
|
27 |
-
|
28 |
-
import evaluate
|
29 |
-
|
30 |
import streamlit as st
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
token = st.sidebar.text_input("Enter your Hugging Face token:", type="password")
|
35 |
-
|
36 |
-
logged = False
|
37 |
-
path_set = False
|
38 |
-
|
39 |
-
if st.sidebar.button("Login"):
|
40 |
-
with open("token.txt", "w") as f:
|
41 |
-
f.write(token)
|
42 |
-
|
43 |
-
# Try to log in with huggingface-cli
|
44 |
-
with st.spinner("Logging in..."):
|
45 |
-
exit_code = os.system(f"huggingface-cli login --token {token}")
|
46 |
-
|
47 |
-
if exit_code != 0:
|
48 |
-
st.sidebar.error("Login failed. Please check your token and try again.")
|
49 |
-
else:
|
50 |
-
st.sidebar.success("Logged in successfully!")
|
51 |
-
logged = True
|
52 |
-
|
53 |
-
labels = ["CM05",
|
54 |
-
"FACTURA",
|
55 |
-
"advertisement",
|
56 |
-
"handwritten",
|
57 |
-
"scientific_report",
|
58 |
-
"budget",
|
59 |
-
"scientific_publication",
|
60 |
-
"presentation",
|
61 |
-
"file_folder",
|
62 |
-
"memo",
|
63 |
-
"resume",
|
64 |
-
"invoice",
|
65 |
-
"letter",
|
66 |
-
"questionnaire",
|
67 |
-
"form",
|
68 |
-
"news_article"]
|
69 |
-
|
70 |
-
NUM_OF_LABELS = len(labels)
|
71 |
-
|
72 |
-
label2id, id2label = dict(), dict()
|
73 |
-
|
74 |
-
for i, label in enumerate(labels):
|
75 |
-
label2id[label] = i
|
76 |
-
id2label[i] = label
|
77 |
-
|
78 |
-
st.title("Document AI")
|
79 |
-
|
80 |
-
parent_dir = st.text_input("Enter the directory path:")
|
81 |
-
#parent_dir = "/content/docs"
|
82 |
-
#parent_dir = r"C:\Users\Windows\Documents\AI Ollama\docs\docs"
|
83 |
-
|
84 |
-
subfolders = ['CM05', 'FACTURA']
|
85 |
-
|
86 |
-
selected_subfolder = st.sidebar.selectbox("Selecciona la subcarpeta", subfolders)
|
87 |
-
|
88 |
-
all_files_loaded = False
|
89 |
-
|
90 |
-
if parent_dir:
|
91 |
-
if not os.path.exists(parent_dir):
|
92 |
-
st.error(f"The directory {parent_dir} does not exist.")
|
93 |
-
else:
|
94 |
-
path_set = True
|
95 |
-
st.success("Directory path set successfully.")
|
96 |
-
uploaded_files = st.sidebar.file_uploader("Subir archivos", type=['jpg'], accept_multiple_files=True)
|
97 |
-
|
98 |
-
if path_set:
|
99 |
-
if st.sidebar.button("Cargar"):
|
100 |
-
if uploaded_files:
|
101 |
-
for file in uploaded_files:
|
102 |
-
# Get the file name and its extension
|
103 |
-
filename = file.name
|
104 |
-
file_extension = filename.split(".")[-1].lower()
|
105 |
-
|
106 |
-
subfolder_path = os.path.join(parent_dir, selected_subfolder)
|
107 |
-
|
108 |
-
os.makedirs(subfolder_path, exist_ok=True)
|
109 |
-
|
110 |
-
existing_files = os.listdir(subfolder_path)
|
111 |
-
file_count = len(existing_files)
|
112 |
-
|
113 |
-
new_filename = f"{selected_subfolder}_{file_count + 1}.{file_extension}"
|
114 |
-
|
115 |
-
file_path = os.path.join(subfolder_path, new_filename)
|
116 |
-
|
117 |
-
# unique_filename = f"uploaded_file_{hash(file.getvalue())}.{file_extension}"
|
118 |
-
|
119 |
-
# unique_filename = f"new_filename.{file_extension}"
|
120 |
-
|
121 |
-
#file_path = os.path.join(parent_dir, unique_filename)
|
122 |
-
#file_path = os.path.join(parent_dir, selected_subfolder, unique_filename)
|
123 |
-
file_path = os.path.join(subfolder_path, new_filename)
|
124 |
-
with open(file_path, "wb") as f:
|
125 |
-
f.write(file.getvalue())
|
126 |
-
st.success("Files uploaded successfully.")
|
127 |
-
|
128 |
-
if st.sidebar.button("Mostrar contenido del directorio"):
|
129 |
-
dir_contents = {}
|
130 |
-
for subfolder in subfolders:
|
131 |
-
subfolder_path = os.path.join(parent_dir, subfolder)
|
132 |
-
if os.path.exists(subfolder_path):
|
133 |
-
dir_contents[subfolder] = os.listdir(subfolder_path)
|
134 |
-
else:
|
135 |
-
dir_contents[subfolder] = []
|
136 |
-
|
137 |
-
st.sidebar.write("Contenido actual del directorio:")
|
138 |
-
st.sidebar.json(dir_contents)
|
139 |
-
|
140 |
-
if st.sidebar.button("Cargo Todos Los Archivos"):
|
141 |
-
all_files_loaded = True
|
142 |
-
|
143 |
-
if all_files_loaded:
|
144 |
-
all_files = glob.glob(os.path.join(parent_dir, "**"), recursive=True)
|
145 |
-
|
146 |
-
dir_path = os.path.join(parent_dir, "*", "*.jpg")
|
147 |
-
|
148 |
-
files_and_name = glob.glob(dir_path)
|
149 |
-
|
150 |
-
st.write(f"Files found: {files_and_name}")
|
151 |
-
|
152 |
-
metadata = pd.DataFrame(files_and_name, columns=["file_path"])
|
153 |
-
|
154 |
-
metadata['file_name'] = metadata['file_path'].apply(lambda x: x.split("/")[-2] + "/" + x.split("/")[-1])
|
155 |
-
|
156 |
-
metadata['label'] = metadata['file_path'].apply(lambda x: x.split("/")[-2])
|
157 |
-
metadata['label'].replace(label2id, inplace=True)
|
158 |
-
|
159 |
-
metadata = metadata.drop(columns=["file_path"])
|
160 |
-
|
161 |
-
metadata_file_location = os.path.join(parent_dir, "metadata.csv")
|
162 |
-
metadata.to_csv(metadata_file_location, index=False)
|
163 |
-
|
164 |
-
st.write(metadata.tail())
|
165 |
-
|
166 |
-
dataset = load_dataset(parent_dir)
|
167 |
-
|
168 |
-
dataset = dataset.cast_column("label", ClassLabel(names=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]))
|
169 |
-
|
170 |
-
dataset
|
171 |
-
|
172 |
-
dataset['train'][1]
|
173 |
-
|
174 |
-
train_split = dataset['train'].train_test_split(train_size=0.80)
|
175 |
-
|
176 |
-
ds = DatasetDict({
|
177 |
-
'train' : train_split['train'],
|
178 |
-
'eval' : train_split['test']
|
179 |
-
})
|
180 |
-
|
181 |
-
st.title("Document AI Model Configuration")
|
182 |
-
|
183 |
-
#MODEL_VERSION = st.text_input("Enter the model version:")
|
184 |
-
|
185 |
-
MODEL_CKPT = "microsoft/dit-base"
|
186 |
-
MODEL_NAME = "Classifier_CM05-v2"
|
187 |
-
#MODEL_NAME = MODEL_CKPT.split("/")[-1] + "-Classifier_CM05" + "-" + MODEL_VERSION
|
188 |
-
|
189 |
-
# if MODEL_VERSION:
|
190 |
-
# st.write(f"Model version: {MODEL_VERSION}")
|
191 |
-
|
192 |
-
MODEL_CKPT
|
193 |
-
|
194 |
-
NUM_OF_EPOCHS=18
|
195 |
-
LEARNING_RATE=5e-5
|
196 |
-
|
197 |
-
BATCH_SIZE=32
|
198 |
-
DEVICE = torch.device("cuda")
|
199 |
-
|
200 |
-
REPORTS_TO='tensorboard'
|
201 |
-
STRATEGY = "epoch"
|
202 |
-
|
203 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CKPT)
|
204 |
-
feature_extractor
|
205 |
-
|
206 |
-
# normalize
|
207 |
-
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
208 |
-
|
209 |
-
# train_transforms
|
210 |
-
train_transforms = Compose(
|
211 |
-
[
|
212 |
-
RandomResizedCrop((feature_extractor.size['height'], feature_extractor.size['width'])),# it was a list that used [], not ()
|
213 |
-
RandomHorizontalFlip(),
|
214 |
-
ToTensor(),
|
215 |
-
normalize
|
216 |
-
]
|
217 |
-
)
|
218 |
-
|
219 |
-
# eval_transforms
|
220 |
-
val_transforms = Compose(
|
221 |
-
[
|
222 |
-
Resize((feature_extractor.size['height'], feature_extractor.size['width'])),
|
223 |
-
CenterCrop((feature_extractor.size['height'], feature_extractor.size['width'])),
|
224 |
-
ToTensor(),
|
225 |
-
normalize,
|
226 |
-
]
|
227 |
-
)
|
228 |
-
|
229 |
-
def preprocess_train(example_batch):
|
230 |
-
"""
|
231 |
-
Apply train_transforms across a batch
|
232 |
-
"""
|
233 |
-
example_batch["pixel_values"] = [
|
234 |
-
train_transforms(image.convert("RGB")) for image in example_batch["image"]
|
235 |
-
]
|
236 |
-
return example_batch
|
237 |
-
|
238 |
-
def preprocess_val(example_batch):
|
239 |
-
"""
|
240 |
-
Apply val_transforms across a batch
|
241 |
-
"""
|
242 |
-
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
|
243 |
-
return example_batch
|
244 |
-
|
245 |
-
ds['train'].set_transform(preprocess_train)
|
246 |
-
ds['eval'].set_transform(preprocess_val)
|
247 |
-
|
248 |
-
ds['train'][0]
|
249 |
-
|
250 |
-
model = AutoModelForImageClassification.from_pretrained(
|
251 |
-
MODEL_CKPT,
|
252 |
-
label2id=label2id,
|
253 |
-
id2label=id2label,
|
254 |
-
ignore_mismatched_sizes=True,
|
255 |
-
).to(DEVICE)
|
256 |
-
|
257 |
-
from transformers import TrainingArguments
|
258 |
-
|
259 |
-
args = TrainingArguments(
|
260 |
-
MODEL_NAME,
|
261 |
-
remove_unused_columns=False,
|
262 |
-
evaluation_strategy=STRATEGY,
|
263 |
-
save_strategy=STRATEGY,
|
264 |
-
logging_strategy="steps",
|
265 |
-
logging_steps=8,
|
266 |
-
logging_first_step=True,
|
267 |
-
learning_rate=LEARNING_RATE,
|
268 |
-
per_device_train_batch_size=BATCH_SIZE,
|
269 |
-
per_device_eval_batch_size=BATCH_SIZE,
|
270 |
-
gradient_accumulation_steps=4,
|
271 |
-
num_train_epochs=NUM_OF_EPOCHS,
|
272 |
-
warmup_ratio=0.10,
|
273 |
-
report_to=REPORTS_TO,
|
274 |
-
hub_private_repo=True,
|
275 |
-
push_to_hub=True
|
276 |
-
)
|
277 |
-
|
278 |
-
def compute_metrics(p):
|
279 |
-
accuracy_metric = evaluate.load("accuracy")
|
280 |
-
accuracy = accuracy_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
281 |
-
references=p.label_ids)['accuracy']
|
282 |
-
|
283 |
-
### ------------------- F1 scores -------------------
|
284 |
-
|
285 |
-
f1_score_metric = evaluate.load("f1")
|
286 |
-
weighted_f1_score = f1_score_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
287 |
-
references=p.label_ids,
|
288 |
-
average='weighted')["f1"]
|
289 |
-
|
290 |
-
micro_f1_score = f1_score_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
291 |
-
references=p.label_ids,
|
292 |
-
average='micro')['f1']
|
293 |
-
|
294 |
-
macro_f1_score = f1_score_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
295 |
-
references=p.label_ids,
|
296 |
-
average='macro')["f1"]
|
297 |
-
|
298 |
-
### ------------------- recall -------------------
|
299 |
-
|
300 |
-
recall_metric = evaluate.load("recall")
|
301 |
-
weighted_recall = recall_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
302 |
-
references=p.label_ids,
|
303 |
-
average='weighted')["recall"]
|
304 |
-
|
305 |
-
micro_recall = recall_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
306 |
-
references=p.label_ids,
|
307 |
-
average='micro')["recall"]
|
308 |
-
|
309 |
-
macro_recall = recall_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
310 |
-
references=p.label_ids,
|
311 |
-
average='macro')["recall"]
|
312 |
-
|
313 |
-
### ------------------- precision -------------------
|
314 |
-
|
315 |
-
precision_metric = evaluate.load("precision")
|
316 |
-
weighted_precision = precision_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
317 |
-
references=p.label_ids,
|
318 |
-
average='weighted')["precision"]
|
319 |
-
|
320 |
-
micro_precision = precision_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
321 |
-
references=p.label_ids,
|
322 |
-
average='micro')["precision"]
|
323 |
-
|
324 |
-
macro_precision = precision_metric.compute(predictions=np.argmax(p.predictions, axis=1),
|
325 |
-
references=p.label_ids,
|
326 |
-
average='macro')["precision"]
|
327 |
-
|
328 |
-
return {"accuracy" : accuracy,
|
329 |
-
"Weighted F1" : weighted_f1_score,
|
330 |
-
"Micro F1" : micro_f1_score,
|
331 |
-
"Macro F1" : macro_f1_score,
|
332 |
-
"Weighted Recall" : weighted_recall,
|
333 |
-
"Micro Recall" : micro_recall,
|
334 |
-
"Macro Recall" : macro_recall,
|
335 |
-
"Weighted Precision" : weighted_precision,
|
336 |
-
"Micro Precision" : micro_precision,
|
337 |
-
"Macro Precision" : macro_precision
|
338 |
-
}
|
339 |
-
|
340 |
-
def collate_fn(examples):
|
341 |
-
pixel_values = torch.stack([example['pixel_values'] for example in examples])
|
342 |
-
labels = torch.tensor([example["label"] for example in examples])
|
343 |
-
return {"pixel_values": pixel_values, "labels": labels}
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
st.write("Model Configuration")
|
348 |
-
st.write(model)
|
349 |
|
350 |
-
|
351 |
-
st.write(args)
|
352 |
|
353 |
-
|
354 |
-
st.write(ds)
|
355 |
|
356 |
-
state = st.session_state.get('state', {'training': False, 'pushing': False})
|
357 |
|
358 |
-
|
359 |
-
|
360 |
-
state['training'] = True # Activar el estado de entrenamiento
|
361 |
-
with st.spinner("Training..."):
|
362 |
-
st.write("Training...")
|
363 |
-
trainer = Trainer(
|
364 |
-
model=model,
|
365 |
-
args=args,
|
366 |
-
train_dataset=ds['train'],
|
367 |
-
eval_dataset=ds['eval'],
|
368 |
-
tokenizer=feature_extractor,
|
369 |
-
compute_metrics=compute_metrics,
|
370 |
-
data_collator=collate_fn
|
371 |
-
)
|
372 |
-
train_results = trainer.train()
|
373 |
-
st.write(train_results)
|
374 |
-
st.success("Training completed!")
|
375 |
-
state['training'] = False # Desactivar el estado de entrenamiento
|
376 |
|
377 |
-
|
378 |
-
trainer.save_model()
|
379 |
-
trainer.log_metrics("train", train_results.metrics)
|
380 |
-
trainer.save_metrics("train", train_results.metrics)
|
381 |
-
trainer.save_state()
|
382 |
|
383 |
-
|
384 |
-
state['pushing'] = True # Activar el estado de envío
|
385 |
-
with st.spinner("Pushing to Hub..."):
|
386 |
-
try:
|
387 |
-
# Ensure trainer is initialized if not done earlier
|
388 |
-
if 'trainer' not in locals():
|
389 |
-
trainer = Trainer(
|
390 |
-
model=model,
|
391 |
-
args=args,
|
392 |
-
train_dataset=ds['train'],
|
393 |
-
eval_dataset=ds['eval'],
|
394 |
-
tokenizer=feature_extractor,
|
395 |
-
compute_metrics=compute_metrics,
|
396 |
-
data_collator=collate_fn
|
397 |
-
)
|
398 |
|
399 |
-
|
400 |
-
|
401 |
-
state['pushing'] = False # Desactivar el estado de envío
|
402 |
-
except Exception as push_error:
|
403 |
-
st.error(f"Error pushing model to hub: {str(push_error)}")
|
404 |
|
405 |
-
|
406 |
-
|
407 |
|
408 |
-
|
409 |
-
st.spinner("Training...").spinner_container.markdown('')
|
410 |
-
if state['pushing']:
|
411 |
-
st.spinner("Pushing to Hub...").spinner_container.markdown('')
|
|
|
1 |
+
# Streamlit app: classify an uploaded document image with a fine-tuned
# image-classification checkpoint loaded from the Hugging Face Hub.
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import streamlit as st

# Hub repository holding both the image-processor config and the model weights.
MODEL_ID = "mateoluksenberg/dit-base-Classifier_CM05"

st.title("Image Classification")

uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "png"])

processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

# st.file_uploader returns None until the user picks a file, so only run
# inference once there is something to open; otherwise Image.open would raise.
if uploaded_file is not None:
    # Force three channels: PNG uploads may be RGBA, palette or grayscale.
    image = Image.open(uploaded_file).convert("RGB")
    inputs = processor(image, return_tensors="pt")

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Highest-scoring logit index, mapped to its name via the model config.
    predicted_label = logits.argmax(-1).item()
    predicted_name = model.config.id2label[predicted_label]

    # Keep the stdout log line and also surface the result in the page.
    print(predicted_name)
    st.write(f"Predicted class: {predicted_name}")
|
|
|
|
|
|