Anupam202224 commited on
Commit
c4c8dcf
·
verified ·
1 Parent(s): 8fa43d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -68
app.py CHANGED
@@ -1,26 +1,29 @@
1
  import os
2
  import shutil
3
  import gradio as gr
4
- from transformers import ReactCodeAgent, HfEngine, Tool
5
  import pandas as pd
6
- import spaces
7
  import torch
8
-
9
- from gradio import Chatbot
10
- from streaming import stream_to_gradio
11
- from huggingface_hub import login
12
- from gradio.data_classes import FileData
13
-
14
-
15
- llm_engine = HfEngine("meta-llama/Meta-Llama-3.1-70B-Instruct")
16
-
17
- agent = ReactCodeAgent(
18
- tools=[],
19
- llm_engine=llm_engine,
20
- additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "scipy.stats"],
21
- max_iterations=10,
22
- )
23
-
 
 
 
 
24
  base_prompt = """You are an expert data analyst.
25
  According to the features you have and the data structure given below, determine which feature should be the target.
26
  Then list 3 interesting questions that could be asked on this data, for instance about specific correlations with target variable.
@@ -38,25 +41,24 @@ The data file is passed to you as the variable data_file, it is a pandas datafra
38
  DO NOT try to load data_file, it is already a dataframe pre-loaded in your python interpreter!
39
  """
40
 
41
- example_notes="""This data is about the Titanic wreck in 1912.
42
- The target figure is the survival of passengers, notes by 'Survived'
43
  pclass: A proxy for socio-economic status (SES)
44
  1st = Upper
45
  2nd = Middle
46
  3rd = Lower
47
- age: Age is fractional if less than 1. If the age is estimated, is it in the form of xx.5
48
  sibsp: The dataset defines family relations in this way...
49
  Sibling = brother, sister, stepbrother, stepsister
50
  Spouse = husband, wife (mistresses and fiancés were ignored)
51
  parch: The dataset defines family relations in this way...
52
  Parent = mother, father
53
  Child = daughter, son, stepdaughter, stepson
54
- Some children travelled only with a nanny, therefore parch=0 for them."""
55
 
56
- @spaces.GPU
57
  def get_images_in_directory(directory):
 
58
  image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'}
59
-
60
  image_files = []
61
  for root, dirs, files in os.walk(directory):
62
  for file in files:
@@ -64,73 +66,105 @@ def get_images_in_directory(directory):
64
  image_files.append(os.path.join(root, file))
65
  return image_files
66
 
67
- @spaces.GPU
68
- def interact_with_agent(file_input, additional_notes):
69
- shutil.rmtree("./figures")
70
- os.makedirs("./figures")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- data_file = pd.read_csv(file_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  data_structure_notes = f"""- Description (output of .describe()):
74
- {data_file.describe()}
75
- - Columns with dtypes:
76
- {data_file.dtypes}"""
77
 
 
78
  prompt = base_prompt.format(structure_notes=data_structure_notes)
79
 
80
- if additional_notes and len(additional_notes) > 0:
81
  prompt += "\nAdditional notes on the data:\n" + additional_notes
82
 
83
- messages = [gr.ChatMessage(role="user", content=prompt)]
84
- yield messages + [
85
- gr.ChatMessage(role="assistant", content="⏳ _Starting task..._")
86
- ]
87
-
88
- plot_image_paths = {}
89
- for msg in stream_to_gradio(agent, prompt, data_file=data_file):
90
- messages.append(msg)
91
- for image_path in get_images_in_directory("./figures"):
92
- if image_path not in plot_image_paths:
93
- image_message = gr.ChatMessage(
94
- role="assistant",
95
- content=FileData(path=image_path, mime_type="image/png"),
96
- )
97
- plot_image_paths[image_path] = True
98
- messages.append(image_message)
99
- yield messages + [
100
- gr.ChatMessage(role="assistant", content="⏳ _Still processing..._")
101
- ]
102
- yield messages
103
 
 
 
 
104
 
 
 
 
 
 
 
 
 
105
  with gr.Blocks(
106
  theme=gr.themes.Soft(
107
  primary_hue=gr.themes.colors.yellow,
108
  secondary_hue=gr.themes.colors.blue,
109
  )
110
  ) as demo:
111
- gr.Markdown("""# Llama-3.1 Data analyst 📊🤔
112
-
113
- Drop a `.csv` file below, add notes to describe this data if needed, and **Llama-3.1-70B will analyze the file content and draw figures for you!**""")
114
- file_input = gr.File(label="Your file to analyze")
115
- text_input = gr.Textbox(
116
- label="Additional notes to support the analysis"
117
- )
 
 
 
 
118
  submit = gr.Button("Run analysis!", variant="primary")
 
119
  chatbot = gr.Chatbot(
120
  label="Data Analyst Agent",
121
- type="messages",
122
- avatar_images=(
123
- None,
124
- "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
125
- ),
126
  )
 
127
  gr.Examples(
128
  examples=[["./example/titanic.csv", example_notes]],
129
  inputs=[file_input, text_input],
130
  cache_examples=False
131
  )
 
 
 
 
 
 
 
 
132
 
133
- submit.click(interact_with_agent, [file_input, text_input], [chatbot])
134
-
135
  if __name__ == "__main__":
136
- demo.launch()
 
1
  import os
2
  import shutil
3
  import gradio as gr
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import pandas as pd
 
6
  import torch
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+
10
+ # Define constants
11
+ MODEL_NAME = "meta-llama/Llama-2-7b-hf" # Replace with a smaller model suitable for CPU
12
+ FIGURES_DIR = "./figures"
13
+
14
+ # Ensure the figures directory exists
15
+ os.makedirs(FIGURES_DIR, exist_ok=True)
16
+
17
+ # Initialize tokenizer and model
18
+ # Note: Loading large models on CPU can be very slow and may not be feasible
19
+ try:
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu")
22
+ except Exception as e:
23
+ print(f"Error loading model: {e}")
24
+ exit(1)
25
+
26
+ # Define the base prompt
27
  base_prompt = """You are an expert data analyst.
28
  According to the features you have and the data structure given below, determine which feature should be the target.
29
  Then list 3 interesting questions that could be asked on this data, for instance about specific correlations with target variable.
 
41
  DO NOT try to load data_file, it is already a dataframe pre-loaded in your python interpreter!
42
  """
43
 
44
+ example_notes = """This data is about the Titanic wreck in 1912.
45
+ The target figure is the survival of passengers, noted by 'Survived'.
46
  pclass: A proxy for socio-economic status (SES)
47
  1st = Upper
48
  2nd = Middle
49
  3rd = Lower
50
+ age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5
51
  sibsp: The dataset defines family relations in this way...
52
  Sibling = brother, sister, stepbrother, stepsister
53
  Spouse = husband, wife (mistresses and fiancés were ignored)
54
  parch: The dataset defines family relations in this way...
55
  Parent = mother, father
56
  Child = daughter, son, stepdaughter, stepson
57
+ Some children traveled only with a nanny, therefore parch=0 for them."""
58
 
 
59
  def get_images_in_directory(directory):
60
+ """Retrieve all image file paths from the specified directory."""
61
  image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'}
 
62
  image_files = []
63
  for root, dirs, files in os.walk(directory):
64
  for file in files:
 
66
  image_files.append(os.path.join(root, file))
67
  return image_files
68
 
69
+ def generate_response(prompt):
70
+ """Generate a response from the language model based on the prompt."""
71
+ inputs = tokenizer(prompt, return_tensors="pt")
72
+ inputs = inputs.to('cpu') # Ensure the model runs on CPU
73
+
74
+ # Generate response (adjust parameters as needed)
75
+ with torch.no_grad():
76
+ outputs = model.generate(
77
+ **inputs,
78
+ max_length=2048,
79
+ do_sample=True,
80
+ top_p=0.95,
81
+ temperature=0.7,
82
+ eos_token_id=tokenizer.eos_token_id
83
+ )
84
+
85
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
86
+ return response
87
 
88
+ def interact_with_agent(file_input, additional_notes):
89
+ """Process the uploaded file and interact with the language model to analyze data."""
90
+ # Clear and recreate the figures directory
91
+ if os.path.exists(FIGURES_DIR):
92
+ shutil.rmtree(FIGURES_DIR)
93
+ os.makedirs(FIGURES_DIR, exist_ok=True)
94
+
95
+ # Load the data file into a pandas dataframe
96
+ try:
97
+ data_file = pd.read_csv(file_input.name)
98
+ except Exception as e:
99
+ yield [("Error loading CSV file.",)]
100
+ return
101
+
102
+ # Create structure notes
103
  data_structure_notes = f"""- Description (output of .describe()):
104
+ {data_file.describe()}
105
+ - Columns with dtypes:
106
+ {data_file.dtypes}"""
107
 
108
+ # Construct the prompt
109
  prompt = base_prompt.format(structure_notes=data_structure_notes)
110
 
111
+ if additional_notes and additional_notes.strip():
112
  prompt += "\nAdditional notes on the data:\n" + additional_notes
113
 
114
+ # Initialize chat history
115
+ messages = [("User", prompt)]
116
+ yield messages + [("Assistant", "⏳ _Starting analysis..._")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ # Generate response from the model
119
+ response = generate_response(prompt)
120
+ messages.append(("Assistant", response))
121
 
122
+ # Extract and display generated images
123
+ image_paths = get_images_in_directory(FIGURES_DIR)
124
+ for image_path in image_paths:
125
+ messages.append(("Assistant", gr.Image.update(value=image_path)))
126
+
127
+ yield messages
128
+
129
+ # Define the Gradio interface
130
  with gr.Blocks(
131
  theme=gr.themes.Soft(
132
  primary_hue=gr.themes.colors.yellow,
133
  secondary_hue=gr.themes.colors.blue,
134
  )
135
  ) as demo:
136
+ gr.Markdown("""# Llama-2 Data Analyst 📊🤔
137
+
138
+ Drop a `.csv` file below, add notes to describe this data if needed, and **the model will analyze the file content and draw figures for you!**""")
139
+
140
+ with gr.Row():
141
+ file_input = gr.File(label="Your file to analyze", type="file")
142
+ text_input = gr.Textbox(
143
+ label="Additional notes to support the analysis",
144
+ placeholder="Enter any additional notes here..."
145
+ )
146
+
147
  submit = gr.Button("Run analysis!", variant="primary")
148
+
149
  chatbot = gr.Chatbot(
150
  label="Data Analyst Agent",
151
+ height=400,
 
 
 
 
152
  )
153
+
154
  gr.Examples(
155
  examples=[["./example/titanic.csv", example_notes]],
156
  inputs=[file_input, text_input],
157
  cache_examples=False
158
  )
159
+
160
+ # Connect the submit button to the interact_with_agent function
161
+ submit.click(
162
+ interact_with_agent,
163
+ inputs=[file_input, text_input],
164
+ outputs=[chatbot],
165
+ show_progress=True
166
+ )
167
 
168
+ # Launch the Gradio app
 
169
  if __name__ == "__main__":
170
+ demo.launch()