# generate_data / app.py
# Author: aledraa
# Commit: 9f6c17a ("Update app.py", verified)
import ast
import functools
import json
import os
import random
import re
import tempfile

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class TableDataGenerator:
    """Generate synthetic table rows by prompting a chat LLM.

    Each column is described by a free-text command (e.g. "arabic name").
    The model is asked for rows in Python-list syntax. Those rows are parsed
    and de-duplicated, and keyword-based heuristic values fill the gap when
    the model does not produce enough rows.
    """

    # Outer list-of-lists such as [['a', 'b'], ['c', 'd']]; DOTALL lets it span lines.
    _LIST_PATTERN = re.compile(r'\[\s*\[.*?\]\s*\]', re.DOTALL)
    # A single bracketed row with no nested brackets, e.g. ['a', 'b'].
    _ROW_PATTERN = re.compile(r'\[([^\[\]]+)\]')
    # Numeric range inside a command, e.g. "between 1 to 20", "5 and 10", "5-10".
    _RANGE_PATTERN = re.compile(r'(\d+)\s*(?:to|and|-)\s*(\d+)')
    # Whole-word keywords, so "image"/"average" are not ages and
    # "ethnicity"/"capacity" are not cities.
    _AGE_PATTERN = re.compile(r'\bage[sd]?\b')
    _CITY_PATTERN = re.compile(r'\bcit(?:y|ies)\b')

    _ARABIC_NAMES = (
        'محمد', 'أحمد', 'عبدالله', 'خالد', 'سعد', 'فهد', 'عبدالعزيز', 'ناصر', 'سلطان', 'طلال',
        'فاطمة', 'عائشة', 'خديجة', 'مريم', 'زينب', 'سارة', 'نورا', 'هند', 'لطيفة', 'منى',
    )
    _NAMES = ('John', 'Jane', 'Michael', 'Sarah', 'David', 'Lisa', 'Robert', 'Emily', 'James', 'Jessica')
    _CITIES = ('New York', 'London', 'Tokyo', 'Paris', 'Sydney', 'Cairo', 'Dubai', 'Berlin', 'Rome', 'Madrid')

    def __init__(self, model_name="Qwen/Qwen2.5-3B-Instruct"):
        """Load ``model_name`` and its tokenizer via ``from_pretrained``.

        ``torch_dtype="auto"`` and ``device_map="auto"`` leave dtype and
        device placement to transformers.
        """
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def generate_batch_data(self, llm_commands, num_rows=1000, batch_size=30):
        """Generate ``num_rows`` unique-ish table rows by prompting in batches.

        Args:
            llm_commands: One natural-language description per column.
            num_rows: Total number of rows to return.
            batch_size: Maximum number of rows requested per batch.

        Returns:
            A list of at most ``num_rows`` rows, each a list of strings with
            one cell per column. Model rows are de-duplicated against every
            row collected so far. Any shortfall in a batch is filled by
            ``_generate_fallback_data``.
        """
        if num_rows <= 0:
            return []
        batch_size = max(1, batch_size)  # 0 would otherwise divide by zero below

        all_rows = []
        # Tuples of every collected row: O(1) duplicate checks instead of
        # scanning the ever-growing row list for each candidate.
        seen = set()
        columns_desc = ", ".join(
            f"Column {i + 1}: {cmd}" for i, cmd in enumerate(llm_commands)
        )
        num_batches = (num_rows + batch_size - 1) // batch_size  # ceiling division
        max_attempts = 5

        for batch_idx in range(num_batches):
            current_batch_size = min(batch_size, num_rows - len(all_rows))
            batch_rows = []
            # Re-prompt a few times for whatever this batch is still missing.
            for attempt in range(max_attempts):
                remaining_needed = current_batch_size - len(batch_rows)
                if remaining_needed <= 0:
                    break
                seed = batch_idx * 10 + attempt
                messages = self._build_messages(columns_desc, remaining_needed, seed)
                response = self._generate_response(messages, seed)
                for row in self._parse_response(response, len(llm_commands)):
                    key = tuple(row)
                    if key in seen:
                        continue
                    seen.add(key)
                    batch_rows.append(row)
                    if len(batch_rows) >= current_batch_size:
                        break
            all_rows.extend(batch_rows)

            # Top up this batch heuristically if the model under-delivered.
            if len(all_rows) < num_rows and len(batch_rows) < current_batch_size:
                fallback_rows = self._generate_fallback_data(
                    llm_commands, current_batch_size - len(batch_rows), len(all_rows)
                )
                seen.update(tuple(row) for row in fallback_rows)
                all_rows.extend(fallback_rows)

            if len(all_rows) >= num_rows:
                break
        return all_rows[:num_rows]

    @staticmethod
    def _build_messages(columns_desc, row_count, seed):
        """Build the system/user chat messages asking for ``row_count`` rows."""
        prompt = (
            f"Generate exactly {row_count} rows of data for a table with columns:\n"
            f"{columns_desc}\n"
            "Format: [['value1', 'value2'], ['value3', 'value4']]\n"
            "Requirements:\n"
            "- Each row must be different and realistic\n"
            "- Return ONLY the list, no explanations\n"
            "- Make data diverse and creative\n"
            f"- Seed: {seed}\n"
            f"Generate exactly {row_count} rows:"
        )
        return [
            {"role": "system", "content": "You are a precise data generator. Return only valid Python list format with exactly the requested number of rows."},
            {"role": "user", "content": prompt},
        ]

    def _generate_response(self, messages, seed=None, max_new_tokens=300):
        """Run one sampled chat completion and return only the reply text.

        Args:
            messages: Chat messages passed to the tokenizer's chat template.
            seed: Seed for torch's global RNG; a random one is drawn if None.
            max_new_tokens: Generation length cap (default keeps prior behavior).
        """
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        # Re-seeding per call gives each (batch, attempt) pair a fixed RNG state
        # while still varying the sample between attempts.
        torch.manual_seed(seed if seed is not None else random.randint(1, 10000))
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.9,
            do_sample=True,
            top_p=0.95,
            repetition_penalty=1.1
        )
        # Strip the prompt tokens so only the new continuation is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    @staticmethod
    def _safe_literal(text):
        """Parse a Python literal without executing code; return None on failure."""
        try:
            return ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None

    def _parse_response(self, response, expected_columns):
        """Extract rows with exactly ``expected_columns`` cells from model text.

        First tries the largest ``[[...]]`` literal in the text. If that
        yields nothing (e.g. output truncated mid-list), each standalone
        ``[...]`` row is parsed instead. Text is only ever parsed with
        ``ast.literal_eval``, never executed, because it is model output
        shaped by user-supplied column descriptions.

        Returns:
            A list of rows, each a list of strings; empty if nothing parsed.
        """
        rows = []
        try:
            matches = self._LIST_PATTERN.findall(response)
            if matches:
                parsed_data = self._safe_literal(max(matches, key=len))
                if isinstance(parsed_data, list):
                    for row in parsed_data:
                        if isinstance(row, (list, tuple)) and len(row) == expected_columns:
                            rows.append([str(item) for item in row])

            if not rows:
                for match in self._ROW_PATTERN.findall(response):
                    # A literal parse keeps commas inside quoted values intact.
                    parsed_row = self._safe_literal(f"[{match}]")
                    if isinstance(parsed_row, list) and len(parsed_row) == expected_columns:
                        rows.append([str(item) for item in parsed_row])
                        continue
                    # Last resort for non-literal text: naive comma split.
                    items = [item.strip().strip('"\'') for item in match.split(',')]
                    if len(items) == expected_columns:
                        rows.append(items)
        except Exception as e:  # best-effort parser: never fail the whole batch
            print(f"Error parsing response: {e}")
        return rows

    def _fallback_value(self, cmd_lower, index):
        """Return one heuristic cell for a lower-cased column command.

        ``index`` numbers generic ``data_N`` values for unrecognised commands.
        """
        if self._AGE_PATTERN.search(cmd_lower):
            bounds = self._RANGE_PATTERN.search(cmd_lower)
            if bounds:
                low, high = sorted(int(b) for b in bounds.groups())
                return str(random.randint(low, high))
            return str(random.randint(18, 65))
        if 'arabic' in cmd_lower and 'name' in cmd_lower:
            return random.choice(self._ARABIC_NAMES)
        if 'name' in cmd_lower:
            return random.choice(self._NAMES)
        if 'price' in cmd_lower or 'cost' in cmd_lower:
            return str(random.randint(10, 1000))
        if self._CITY_PATTERN.search(cmd_lower):
            return random.choice(self._CITIES)
        return f"data_{index}"

    def _generate_fallback_data(self, llm_commands, needed_rows, current_count):
        """Build ``needed_rows`` rows from keyword heuristics on each command.

        Used when the model produces too few valid rows. ``current_count`` is
        the number of rows already collected; it offsets the ``data_N``
        numbering so generic values continue from the existing rows.
        """
        fallback_rows = []
        for i in range(needed_rows):
            index = current_count + i + 1
            fallback_rows.append(
                [self._fallback_value(cmd.lower(), index) for cmd in llm_commands]
            )
        return fallback_rows
@functools.lru_cache(maxsize=1)
def _get_generator(model_name="Qwen/Qwen2.5-3B-Instruct"):
    """Return a cached TableDataGenerator so the model loads once per process.

    Constructing TableDataGenerator calls ``from_pretrained`` for both the
    model and the tokenizer. Doing that on every request made each call pay
    the full load cost.
    """
    return TableDataGenerator(model_name)


def generate_table_data(json_input, num_rows=1000):
    """Generate table rows from a JSON spec, save them, and build a preview.

    Args:
        json_input: JSON text of the form ``{"llm_commands": ["...", ...]}``,
            with one command (column description) per column.
        num_rows: Number of rows to generate.

    Returns:
        ``(result_text, json_data)``. On success, ``json_data`` is a dict with
        ``columns``, ``rows`` and ``total_rows``. The same dict is also written
        to ``./train/data.json``. On failure, ``result_text`` starts with
        ``"Error:"`` and ``json_data`` is ``{}``.
    """
    try:
        data = json.loads(json_input)
        llm_commands = data.get('llm_commands', []) if isinstance(data, dict) else []
        if not llm_commands:
            return "Error: No llm_commands found in JSON input", {}
        if not isinstance(llm_commands, list):
            # A bare string would otherwise be iterated one character per column.
            return "Error: llm_commands must be a list of column descriptions", {}
        # Commands are used as text (prompt building, .lower() in the fallback).
        llm_commands = [str(cmd) for cmd in llm_commands]

        rows = _get_generator().generate_batch_data(llm_commands, num_rows)
        json_data = {
            "columns": llm_commands,
            "rows": rows,
            "total_rows": len(rows),
        }

        # ensure_ascii=False keeps Arabic (and other non-ASCII) text readable.
        os.makedirs('./train', exist_ok=True)
        with open('./train/data.json', "w", encoding="utf-8") as f:
            json.dump(json_data, f, ensure_ascii=False, indent=2)

        preview = "".join(f"{i}: {row}\n" for i, row in enumerate(rows[:10], start=1))
        result = (
            f"Generated {len(rows)} rows (requested: {num_rows}):\n"
            f"Columns: {llm_commands}\n"
            "Saved to: ./train/data.json\n\n"
            "First 10 rows:\n"
            f"{preview}"
        )
        if len(rows) > 10:
            result += f"\n... and {len(rows) - 10} more rows"
        return result, json_data
    except json.JSONDecodeError:
        return "Error: Invalid JSON format", {}
    except Exception as e:
        return f"Error: {str(e)}", {}
# Gradio Interface
def process_json_input(json_input, num_rows):
    """Run generation for the UI and serialise the payload for download.

    Returns:
        ``(result_text, json_content)``. ``json_content`` is the
        pretty-printed JSON (non-ASCII kept as-is). It is ``""`` when the
        generator returned no payload containing ``rows``.
    """
    summary, payload = generate_table_data(json_input, int(num_rows))
    has_rows = bool(payload) and 'rows' in payload
    if not has_rows:
        return summary, ""
    return summary, json.dumps(payload, ensure_ascii=False, indent=2)
# Create Gradio interface
with gr.Blocks(title="Table Data Generator") as demo:
    gr.Markdown("# Table Data Generator using LLM")
    gr.Markdown("Generate realistic table data based on column descriptions")
    with gr.Row():
        with gr.Column():
            json_input = gr.Textbox(
                label="JSON Input",
                placeholder='{"llm_commands": ["ages between 1 to 20", "arabic name"]}',
                lines=3,
                value='{"llm_commands": ["ages between 1 to 20", "arabic name"]}'
            )
            num_rows = gr.Slider(
                minimum=10,
                maximum=2000,
                value=100,
                step=10,
                label="Number of rows to generate"
            )
            generate_btn = gr.Button("Generate Data", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Data Preview",
                lines=15,
                max_lines=20
            )
            download_json = gr.File(
                label="Download JSON",
                visible=True
            )

    def generate_and_save(json_text, row_count):
        """Generate data and write the JSON payload to a downloadable file.

        Parameters are named differently from the ``json_input``/``num_rows``
        components to avoid shadowing them.

        Returns:
            ``(result_text, path)``. ``path`` is ``None`` when there is
            nothing to download (e.g. on an error).
        """
        result_text, json_content = process_json_input(json_text, row_count)
        if not json_content:
            return result_text, None
        # delete=False: the returned path is handed to the gr.File output,
        # so the file must outlive this function.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as f:
            f.write(json_content)
        return result_text, f.name

    generate_btn.click(
        fn=generate_and_save,
        inputs=[json_input, num_rows],
        outputs=[output_text, download_json]
    )

    # Example inputs
    gr.Examples(
        examples=[
            ['{"llm_commands": ["ages between 1 to 20", "arabic name"]}', 50],
            ['{"llm_commands": ["random city", "population number", "country"]}', 100],
            ['{"llm_commands": ["product name", "price in USD", "category"]}', 75],
            ['{"llm_commands": ["email address", "phone number", "job title"]}', 60]
        ],
        inputs=[json_input, num_rows]
    )
# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()
# Example usage (programmatic, without the UI):
# json_input = '{"llm_commands": ["ages between 1 to 20", "arabic name"]}'
# result_text, json_data = generate_table_data(json_input, 1000)
# print(result_text)
# print(f"Actual rows generated: {len(json_data.get('rows', []))}")