import os
import re
from io import StringIO

import numpy as np  # kept imported so LLM-generated code exec'd in globals() can use it
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq

from .clean_df_fallback import clean_dataframe_fallback


# Load environment variables
load_dotenv()



groq_api_key = os.getenv("GROQ_API_KEY")
gemini_api_key = os.getenv("GEMINI_API_KEY")


if not gemini_api_key:
    raise ValueError("GEMINI_API_KEY not found in environment variables")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY not found in environment variables")


# Initialize the primary LLM (Gemini), falling back to Groq if it fails
try:
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash-lite-preview-02-05", 
        google_api_key=gemini_api_key
    )
    print("Primary Gemini LLM loaded successfully.")

except Exception as e:
    print(f"Error initializing primary Gemini LLM: {e}")
    
    # Fallback to a different LLM from Groq
    try:
        llm = ChatGroq(
            model="gemma2-9b-it",  # replace with your desired Groq model identifier
            groq_api_key=groq_api_key
        )
        print("Fallback Groq LLM loaded successfully.")
    
    except Exception as e2:
        print(f"Error initializing fallback Groq LLM: {e2}")
        llm = None



# Cache the clean_csv function to prevent redundant cleaning
@st.cache_data(ttl=3600, show_spinner=False)
def cached_clean_csv(df_json, skip_cleaning=False):
    """Cached version of the clean_csv function to prevent redundant cleaning.
    
    Args:
        df_json: JSON string representation of the dataframe (for hashing)
        skip_cleaning: Whether to skip cleaning
        
    Returns:
        Tuple of (cleaned_df, insights)
    """
    # Convert JSON back to a DataFrame (wrapping in StringIO avoids the pandas
    # deprecation warning for passing a literal JSON string to read_json)
    df = pd.read_json(StringIO(df_json), orient='records')
    
    # If skip_cleaning is True, return the dataframe as is
    if skip_cleaning:
        return df, "No cleaning performed (user skipped)."
    
    # Reset any test results if we're cleaning a new dataset
    if "test_results_calculated" in st.session_state:
        st.session_state.test_results_calculated = False
        # Clear any previous test metrics to avoid using stale data
        for key in ['test_metrics', 'test_y_pred', 'test_y_test', 'test_cm', 'sampling_message']:
            if key in st.session_state:
                del st.session_state[key]
    
    # Call the actual cleaning function
    return clean_csv(df)
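
# A minimal usage sketch (hypothetical caller; in the app, `df` would come from
# something like st.file_uploader + pd.read_csv):
#
#     df_json = df.to_json(orient='records')
#     cleaned_df, insights = cached_clean_csv(df_json, skip_cleaning=False)
#
# Serializing the DataFrame to JSON first gives st.cache_data a hashable
# argument, so the expensive LLM-driven cleaning only reruns when the data
# or the skip flag actually changes.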


def clean_csv(df):
    """Original clean_csv function that performs the actual cleaning."""
    # ---------------------------
    # Early fallback if LLM initialization failed
    # ---------------------------
    if llm is None:
        print("LLM initialization failed; using hardcoded cleaning function.")
        fallback_df = clean_dataframe_fallback(df)

        return fallback_df, "LLM initialization failed; using hardcoded cleaning function, so no insights were generated."



    # ---------------------------
    # LLM-based cleaning function generation
    # ---------------------------


    # Escape curly braces in the JSON sample and column names
    sample_data = df.head(3).to_json(orient='records')
    escaped_sample_data = sample_data.replace("{", "{{").replace("}", "}}")

    escaped_columns = [
        col.replace("{", "{{").replace("}", "}}") for col in df.columns
    ]
    column_names_str = ", ".join(escaped_columns)
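    # Example of why the escaping above matters: a sample row such as
    # {"price": "Rs.2,100"} must be embedded as {{"price": "Rs.2,100"}},
    # otherwise PromptTemplate would parse the braces as template variables
    # and fail at format time.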



    # Define the prompt for generating the cleaning function
    initial_prompt = PromptTemplate.from_template(f'''
        You are given the following sample data from a pandas DataFrame:
            {escaped_sample_data}

        The column names are: [{column_names_str}].

        Generate a Python function named clean_dataframe(df) that:

        1. Performs thorough data cleaning without any feature engineering. Ensure all necessary cleaning steps are included.
        2. Uses assignment operations (e.g., df = df.drop(...)) and avoids inplace=True for clarity.
        3. First deeply analyzes each column's content (this is the most important step) to infer its predominant data type. For example, if rows contain 'Rs.2100', strip the 'Rs.'; if they contain '89%', strip the '%'. A column with only text and no numbers is a text column, one with both numbers and text is a mixed column, and one with only numbers is a numeric column.
        4. For columns that are intended to be numeric but contain extra characters (such as '%' in percentage values, currency symbols like 'Rs.' or '$', and commas), removes all non-digit characters (except the decimal point) and converts them to a numeric type.
        5. For columns that are clearly text or categorical, preserves the content without removing digits or altering the textual information.
        6. Handles missing values appropriately: fills numeric columns with the median (or 0 if the median is not available) and non-numeric columns with 'Unknown'.
        7. For columns where more than 50% of values are strings and less than 10% are numeric, performs conservative string cleaning by removing unwanted special symbols while preserving meaningful digits.
        8. For columns whose names contain 'name' or 'names' (case-insensitive), converts to string type and removes extraneous numeric characters only if they are not part of the essential text.
        9. Preserves other categorical or text columns (such as Gender, City, State, Country, etc.) unless explicitly specified for removal.
        10. Handles edge cases such as completely empty columns appropriately.

        Return only the Python code for the function, with no explanations or extra formatting.
        '''
    )



    # Define the refinement prompt
    refine_prompt = PromptTemplate.from_template(
        "The following Python code for cleaning a DataFrame caused an error: {error}\n"
        "Original code:\n{code}\n"
        "Please correct the code to fix the error and ensure it returns a cleaned DataFrame. "
        "Return only the corrected Python code for the function, no explanations or formatting."
    )

    # Create the chains using the modern LangChain pipe (LCEL) syntax
    initial_chain = initial_prompt | llm
    refine_chain = refine_prompt | llm
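    # `prompt | llm` builds a RunnableSequence: .invoke() first formats the
    # prompt with the supplied variables, then sends the result to the LLM.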







    def extract_code(response):
        """Extract Python code from an LLM response, stripping Markdown fences if present."""
        # Normalize LLM response objects (e.g. AIMessage) to a plain string
        content = response if isinstance(response, str) else getattr(response, 'content', str(response))

        if "```python" in content:
            match = re.search(r'```python\n(.*?)\n```', content, re.DOTALL)
            return match.group(1).strip() if match else content
        elif "```" in content:
            match = re.search(r'```\n(.*?)\n```', content, re.DOTALL)
            return match.group(1).strip() if match else content
        return content.strip()
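    # For example, a response wrapped in a fenced block:
    #     ```python
    #     def clean_dataframe(df):
    #         ...
    #     ```
    # is reduced to just the function source, ready to be passed to exec().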
    




    
    
    try:
        # Invoke the initial chain and extract the generated cleaning code
        cleaning_function_code = extract_code(initial_chain.invoke({}))
        print("Initial generated cleaning function code (not yet executed):\n", cleaning_function_code)

        # Iterative refinement loop with a bounded number of attempts
        max_attempts = 5

        for attempt in range(max_attempts):
            print(f"Attempt {attempt + 1} code:\n{cleaning_function_code}")
            try:
                # Execute the generated code in the module's global namespace so it
                # can see pd/np and so clean_dataframe becomes visible below
                exec(cleaning_function_code, globals())

                if 'clean_dataframe' not in globals():
                    raise NameError("Cleaning function not defined in generated code")

                # Call the generated function and assign the result back to df
                df = clean_dataframe(df)

                print(f"Cleaning successful on attempt {attempt + 1}")
                break

            # If the cleaning fails, capture the error for the refinement step
            except Exception as e:
                error_message = str(e)
                print(f"Error on attempt {attempt + 1}: {error_message}")

            if attempt < max_attempts - 1:
                # Ask the LLM to repair the code using the error message
                refined_response = refine_chain.invoke({"error": error_message, "code": cleaning_function_code})
                cleaning_function_code = extract_code(refined_response)

                print("Refined cleaning function code:\n", cleaning_function_code)
            else:
                print(f"Failed to clean DataFrame after {max_attempts} attempts")
                # After all attempts failed, fall back to the hardcoded cleaning logic
                df = clean_dataframe_fallback(df)

    except Exception as e:
        print(f"⚡ No successful cleaning performed; enforcing fallback: {e}")
        df = clean_dataframe_fallback(df)
    
    cleaned_df = df     


    insights_prompt = f"""
    Analyze this cleaned dataset:
    - Columns: {cleaned_df.columns.tolist()}
    - Sample data: {cleaned_df.head(3).to_dict()}
    - Numeric stats: {cleaned_df.describe().to_dict()}
    Provide key data quality insights and recommendations.
    """
    
    try:
        insights_response = llm.invoke(insights_prompt)
        analysis_insights = insights_response.content    
    except Exception as e:
        analysis_insights = f"Insight generation failed: {str(e)}"



    # Return the cleaned DataFrame and the generated insights
    return cleaned_df, analysis_insights
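
# A minimal local smoke test: a sketch that assumes both API keys are set and
# that the module is run with `python -m <package>.<module>` so the relative
# import of clean_df_fallback resolves; the sample frame is purely illustrative.
if __name__ == "__main__":
    sample = pd.DataFrame({
        "Name": ["Alice", "Bob", "Carol"],
        "Price": ["Rs.2,100", "Rs.900", None],
        "Discount": ["10%", "5%", "7%"],
    })
    cleaned, notes = clean_csv(sample)
    print(cleaned.dtypes)
    print(notes)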