"""This file provides utilities for **batch processing** spectral data files (such as Raman spectra) for polymer classification. Its main goal is to process multiple files efficiently—either synchronously or asynchronously—using one or more machine learning models, and to collect, summarize, and export the results. It is designed for integration with a Streamlit-based UI, supporting file uploads and batch inference."""

import json
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, asdict

import pandas as pd
import numpy as np

from utils.preprocessing import preprocess_spectrum
from utils.multifile import parse_spectrum_data
from utils.async_inference import submit_batch_inference, wait_for_batch_completion
from core_logic import run_inference


@dataclass
class BatchProcessingResult:
    """Result from batch processing operation."""

    filename: str
    model_name: str
    prediction: int
    confidence: float
    logits: List[float]
    inference_time: float
    status: str = "success"
    error: Optional[str] = None
    ground_truth: Optional[int] = None


class BatchProcessor:
    """Handles batch processing of spectral data files."""

    def __init__(self, modality: str = "raman"):
        self.modality = modality
        self.results: List[BatchProcessingResult] = []

    def process_files_sync(
        self,
        file_data: List[Tuple[str, str]],  # (filename, content)
        model_names: List[str],
        target_len: int = 500,
    ) -> List[BatchProcessingResult]:
        """Process files synchronously."""
        results = []

        for filename, content in file_data:
            for model_name in model_names:
                try:
                    # Parse spectrum data
                    x_raw, y_raw = parse_spectrum_data(content)

                    # Preprocess
                    x_proc, y_proc = preprocess_spectrum(
                        x_raw, y_raw, modality=self.modality, target_len=target_len
                    )

                    # Run inference; run_inference reports its own timing,
                    # so no manual timer is needed here.
                    prediction, logits_list, probs, inference_time, _ = (
                        run_inference(y_proc, model_name)
                    )

                    if prediction is not None:
                        confidence = float(max(probs)) if probs is not None else 0.0

                        result = BatchProcessingResult(
                            filename=filename,
                            model_name=model_name,
                            prediction=int(prediction),
                            confidence=confidence,
                            logits=logits_list or [],
                            inference_time=inference_time or 0.0,
                            ground_truth=self._extract_ground_truth(filename),
                        )
                    else:
                        result = BatchProcessingResult(
                            filename=filename,
                            model_name=model_name,
                            prediction=-1,
                            confidence=0.0,
                            logits=[],
                            inference_time=0.0,
                            status="failed",
                            error="Inference failed",
                        )

                    results.append(result)

                except Exception as e:
                    result = BatchProcessingResult(
                        filename=filename,
                        model_name=model_name,
                        prediction=-1,
                        confidence=0.0,
                        logits=[],
                        inference_time=0.0,
                        status="failed",
                        error=str(e),
                    )
                    results.append(result)

        self.results.extend(results)
        return results

    def process_files_async(
        self,
        file_data: List[Tuple[str, str]],
        model_names: List[str],
        target_len: int = 500,
        max_concurrent: int = 3,
    ) -> List[BatchProcessingResult]:
        """Process files asynchronously."""
        results = []

        # Process files in chunks to manage concurrency
        chunk_size = max_concurrent
        file_chunks = [
            file_data[i : i + chunk_size] for i in range(0, len(file_data), chunk_size)
        ]

        for chunk in file_chunks:
            chunk_results = self._process_chunk_async(chunk, model_names, target_len)
            results.extend(chunk_results)

        self.results.extend(results)
        return results

    def _process_chunk_async(
        self, file_chunk: List[Tuple[str, str]], model_names: List[str], target_len: int
    ) -> List[BatchProcessingResult]:
        """Process a chunk of files asynchronously."""
        results = []

        for filename, content in file_chunk:
            try:
                # Parse and preprocess
                x_raw, y_raw = parse_spectrum_data(content)
                x_proc, y_proc = preprocess_spectrum(
                    x_raw, y_raw, modality=self.modality, target_len=target_len
                )

                # Submit async inference for all models
                task_ids = submit_batch_inference(
                    model_names=model_names,
                    input_data=y_proc,
                    inference_func=run_inference,
                )

                # Wait for completion
                inference_results = wait_for_batch_completion(task_ids, timeout=60.0)

                # Process results
                for model_name in model_names:
                    if model_name in inference_results:
                        model_result = inference_results[model_name]

                        if "error" not in model_result:
                            prediction, logits_list, probs, inference_time, _ = (
                                model_result
                            )
                            confidence = float(max(probs)) if probs is not None else 0.0

                            result = BatchProcessingResult(
                                filename=filename,
                                model_name=model_name,
                                # Guard with "is not None": class 0 is falsy,
                                # so "prediction or -1" would misreport it.
                                prediction=prediction if prediction is not None else -1,
                                confidence=confidence,
                                logits=logits_list or [],
                                inference_time=inference_time or 0.0,
                                ground_truth=self._extract_ground_truth(filename),
                            )
                        else:
                            result = BatchProcessingResult(
                                filename=filename,
                                model_name=model_name,
                                prediction=-1,
                                confidence=0.0,
                                logits=[],
                                inference_time=0.0,
                                status="failed",
                                error=model_result["error"],
                            )
                    else:
                        result = BatchProcessingResult(
                            filename=filename,
                            model_name=model_name,
                            prediction=-1,
                            confidence=0.0,
                            logits=[],
                            inference_time=0.0,
                            status="failed",
                            error="No result received",
                        )

                    results.append(result)

            except Exception as e:
                # Create error results for all models
                for model_name in model_names:
                    result = BatchProcessingResult(
                        filename=filename,
                        model_name=model_name,
                        prediction=-1,
                        confidence=0.0,
                        logits=[],
                        inference_time=0.0,
                        status="failed",
                        error=str(e),
                    )
                    results.append(result)

        return results

    def _extract_ground_truth(self, filename: str) -> Optional[int]:
        """Extract ground truth label from filename."""
        try:
            # Imported here rather than at module level, presumably to avoid
            # a circular import with core_logic.
            from core_logic import label_file

            return label_file(filename)
        except Exception:
            return None

    def get_summary_statistics(self) -> Dict[str, Any]:
        """Calculate summary statistics for batch processing results."""
        if not self.results:
            return {}

        successful_results = [r for r in self.results if r.status == "success"]
        failed_results = [r for r in self.results if r.status == "failed"]

        stats = {
            "total_files": len(set(r.filename for r in self.results)),
            "total_inferences": len(self.results),
            "successful_inferences": len(successful_results),
            "failed_inferences": len(failed_results),
            "success_rate": (
                len(successful_results) / len(self.results) if self.results else 0
            ),
            "models_used": list(set(r.model_name for r in self.results)),
            "average_inference_time": (
                np.mean([r.inference_time for r in successful_results])
                if successful_results
                else 0
            ),
            "total_processing_time": sum(r.inference_time for r in successful_results),
        }

        # Calculate accuracy if ground truth is available
        gt_results = [r for r in successful_results if r.ground_truth is not None]
        if gt_results:
            correct_predictions = sum(
                1 for r in gt_results if r.prediction == r.ground_truth
            )
            stats["accuracy"] = correct_predictions / len(gt_results)
            stats["samples_with_ground_truth"] = len(gt_results)

        return stats

    def export_results(self, format: str = "csv") -> str:
        """Export results as a CSV (default) or JSON string."""
        # Minimal sketch: serialize the result dataclasses directly.
        records = [asdict(r) for r in self.results]
        if format == "json":
            return json.dumps(records, indent=2, default=float)
        return pd.DataFrame(records).to_csv(index=False)
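

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the module API): the file
    # contents and the model name "example_model" are hypothetical stand-ins;
    # substitute real spectrum files and a model name that
    # core_logic.run_inference accepts. Failures are captured per file/model
    # pair, so the summary prints either way.
    demo_files = [
        ("PE_sample_01.txt", "400 0.12\n401 0.15\n402 0.18"),
        ("PS_sample_02.txt", "400 0.22\n401 0.25\n402 0.28"),
    ]
    processor = BatchProcessor(modality="raman")
    processor.process_files_sync(demo_files, model_names=["example_model"])
    print(json.dumps(processor.get_summary_statistics(), indent=2, default=float))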