yakine committed
Commit 170c5f1 · verified · 1 Parent(s): 9977298

Create app.py

Files changed (1)
  app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import pandas as pd
import os
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
from io import StringIO
import accelerate  # required by transformers for device_map='auto'

app = FastAPI()

# Access the Hugging Face API token from environment variables
hf_token = os.getenv('HF_API_TOKEN')

if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# Load the GPT-2 tokenizer and model
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)

# Load the Llama 3.1 model and tokenizer once during startup
tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=hf_token)
model_llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    torch_dtype='auto',
    device_map='auto',
    token=hf_token
)

# Define the prompt template
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """

class DataGenerationRequest(BaseModel):
    description: str
    columns: list[str]

# Expand the user's description with GPT-2 before it is inserted into the prompt
def preprocess_user_prompt(user_prompt):
    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1, truncation=True)[0]["generated_text"]
    return generated_text

def format_prompt(description, columns):
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt

generation_params = {
    "top_p": 0.90,
    "temperature": 0.8,
    "max_new_tokens": 512,
}

def generate_synthetic_data(description, columns):
    try:
        # Prepare the input for the Llama model
        formatted_prompt = format_prompt(description, columns)

        # Tokenize the prompt with truncation enabled
        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model_llama.device)

        # Generate synthetic data; do_sample=True is required for top_p and
        # temperature to take effect, and max_new_tokens (unlike max_length,
        # which counts the prompt tokens too) reserves room for new output
        with torch.no_grad():
            outputs = model_llama.generate(
                **inputs,
                max_new_tokens=generation_params["max_new_tokens"],
                do_sample=True,
                top_p=generation_params["top_p"],
                temperature=generation_params["temperature"],
                num_return_sequences=1,
            )

        # Decode only the newly generated tokens so the prompt text is not
        # mixed into the CSV output
        generated_text = tokenizer_llama.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

        return generated_text
    except Exception as e:
        return f"Error: {e}"

@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    description = request.description.strip()
    columns = [col.strip() for col in request.columns]
    generated_data = generate_synthetic_data(description, columns)

    # generate_synthetic_data signals failure with an "Error: ..." string;
    # checking the prefix avoids false positives on data containing "Error"
    if generated_data.startswith("Error:"):
        return JSONResponse(content={"error": generated_data}, status_code=500)

    # Process the generated CSV data into a DataFrame
    df_synthetic = process_generated_data(generated_data)
    return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})

def process_generated_data(csv_data):
    data = StringIO(csv_data)
    df = pd.read_csv(data)
    return df

@app.get("/")
def greet_json():
    return {"Hello": "World!"}
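For reference, a minimal client sketch against this endpoint might look like the following. It is not part of the commit: it assumes the app is being served locally on the default port 8000 (e.g. via `uvicorn app:app`) and that the `requests` package is available; the payload field names match the DataGenerationRequest model above.

# Hypothetical client sketch; localhost:8000 and `requests` are assumptions
import requests

payload = {
    "description": "Generate a dataset for predicting house prices",
    "columns": ["Size", "Location", "Number of Bedrooms", "Price"],
}

resp = requests.post("http://localhost:8000/generate/", json=payload)
resp.raise_for_status()

# On success the endpoint returns {"data": [...]}, one dict per generated row
for row in resp.json()["data"][:5]:
    print(row)

Note that nothing guarantees the model emits well-formed CSV, so pd.read_csv in process_generated_data can still raise on malformed output; the explicit 500 path only covers exceptions caught inside generate_synthetic_data, while parsing failures surface through FastAPI's default error handling.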