mahdee987 committed
Commit 5671f77 · verified · Parent: 95c574b

Update app.py

Files changed (1):
  1. app.py +114 -21
app.py CHANGED
@@ -2,16 +2,29 @@ import os
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    Trainer,
+    TrainingArguments,
+    DataCollatorForLanguageModeling
+)
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
 import torch
+from datetime import datetime
+import traceback
+
+# Environment setup
 os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
 os.environ["HF_HOME"] = "/app/cache"
 os.environ["XDG_CACHE_HOME"] = "/app/cache"
 os.makedirs("/app/cache", exist_ok=True)
+os.makedirs("/app/finetuned", exist_ok=True)
 
 app = FastAPI()
 
-# Enable CORS
+# CORS Configuration
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
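A note on the new peft import: `prepare_model_for_int8_training` was deprecated in favor of `prepare_model_for_kbit_training` and has since been removed from recent peft releases, so this import will fail at startup unless peft is pinned to an older version. A minimal defensive-import sketch (the `prepare_for_training` alias is mine, not the commit's):

    try:
        from peft import prepare_model_for_kbit_training as prepare_for_training
    except ImportError:
        # older peft releases only expose the int8-specific name
        from peft import prepare_model_for_int8_training as prepare_for_training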
@@ -19,37 +32,47 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Load model with caching
-model_name = "gpt2"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    device_map="auto",  # Automatically uses GPU if available
-    torch_dtype=torch.float16  # Optimize for GPU
-)
+# Model Loading with error handling
+try:
+    model_name = "gpt2"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        device_map="auto",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+    )
+except Exception as e:
+    print(f"Model loading failed: {str(e)}")
+    raise
 
+# Pydantic Models
 class Query(BaseModel):
     message: str = Field(..., max_length=500)
 
+class FineTuneRequest(BaseModel):
+    epochs: int = Field(1, gt=0, le=5)
+    learning_rate: float = Field(5e-5, gt=0, le=1e-3)
+
+# Response Generation
 def generate_response(user_message):
-    prompt = f"User: {user_message}\nAI:"
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+    prompt = f"<FIN_QA>Question: {user_message}\nAnswer:"
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
-    output = model.generate(
-        input_ids,
-        max_new_tokens=100,
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=150,
         temperature=0.7,
         do_sample=True,
-        no_repeat_ngram_size=2,
+        no_repeat_ngram_size=3,
         repetition_penalty=1.5,
-        early_stopping=True,
         eos_token_id=tokenizer.eos_token_id
     )
 
-    full_response = tokenizer.decode(output[0], skip_special_tokens=True)
-    response = full_response.split("AI:")[-1].split("\nUser:")[0].strip()
-    return response or "I'm not sure how to respond to that."
+    return tokenizer.decode(outputs[0], skip_special_tokens=True).split("Answer:")[-1].strip()
 
+# API Endpoints
 @app.post("/chat")
 async def chat(query: Query):
     try:
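One thing this hunk likely needs: `tokenizer.add_special_tokens({'pad_token': '[PAD]'})` grows the vocab to 50,258 entries, but GPT-2's embedding matrix stays at 50,257 rows, so any batch containing the new pad id can index out of range during fine-tuning. A sketch of the usual fix, resizing the embeddings right after adding the token (or skipping the new token entirely):

    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))  # keep the embedding matrix in sync

    # alternative that avoids resizing: reuse the existing EOS token as padding
    # tokenizer.pad_token = tokenizer.eos_token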
@@ -58,6 +81,76 @@ async def chat(query: Query):
     except Exception as e:
         return {"error": str(e)}
 
+@app.post("/fine-tune")
+async def fine_tune(params: FineTuneRequest):
+    try:
+        # Load and combine datasets
+        alpaca = load_dataset("gbharti/finance-alpaca", split="train[:20%]")  # Sample 20% for demo
+        fiqa = load_dataset("bilalRahib/fiqa-personal-finance-dataset", "full", split="train[:20%]")
+
+        # Formatting function
+        def format_example(ex):
+            if 'instruction' in ex:
+                return {"text": f"Instruction: {ex['instruction']}\nInput: {ex['input']}\nOutput: {ex['output']}"}
+            else:
+                return {"text": f"Question: {ex['question']}\nAnswer: {ex['answer']}"}
+
+        dataset = alpaca.map(format_example) + fiqa.map(format_example)
+
+        # Tokenize
+        def tokenize(ex):
+            return tokenizer(ex["text"], truncation=True, max_length=256, padding="max_length")
+
+        dataset = dataset.map(tokenize, batched=True)
+
+        # LoRA Configuration
+        peft_config = LoraConfig(
+            r=8,
+            lora_alpha=16,
+            target_modules=["c_attn", "c_proj", "c_fc"],
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM"
+        )
+
+        model = prepare_model_for_int8_training(model)
+        model = get_peft_model(model, peft_config)
+
+        # Training
+        trainer = Trainer(
+            model=model,
+            args=TrainingArguments(
+                output_dir="/app/finetuned",
+                per_device_train_batch_size=2,
+                num_train_epochs=params.epochs,
+                learning_rate=params.learning_rate,
+                logging_dir="/app/logs",
+                save_strategy="epoch",
+                fp16=torch.cuda.is_available(),
+            ),
+            train_dataset=dataset,
+            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
+        )
+
+        trainer.train()
+        model.save_pretrained("/app/finetuned")
+
+        return {
+            "status": "success",
+            "trained_samples": len(dataset),
+            "training_time": datetime.now().isoformat()
+        }
+
+    except Exception as e:
+        return {
+            "error": str(e),
+            "traceback": traceback.format_exc()
+        }
+
 @app.get("/")
 def health_check():
-    return {"status": "OK"}
+    return {
+        "status": "healthy",
+        "model": model_name,
+        "device": str(model.device)
+    }
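The new `dataset = alpaca.map(format_example) + fiqa.map(format_example)` line should fail at runtime: as far as the `datasets` API goes, `Dataset` objects do not support the `+` operator, and the two sources carry different column sets. The supported route is `concatenate_datasets`, dropping the original columns after formatting so both sides share only `text`. A sketch under those assumptions:

    from datasets import concatenate_datasets

    alpaca = alpaca.map(format_example, remove_columns=alpaca.column_names)
    fiqa = fiqa.map(format_example, remove_columns=fiqa.column_names)
    dataset = concatenate_datasets([alpaca, fiqa])  # both now expose a single "text" column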
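Relatedly, after tokenization the examples still carry the raw `text` string column, which `DataCollatorForLanguageModeling` will try to pad and tensorize. Dropping it during the tokenizing map is the usual pattern; with `mlm=False` the collator then builds `labels` from `input_ids` on its own, so no label column is needed:

    dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])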
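Even if the dataset step were fixed, `fine_tune` would still error earlier than expected: assigning `model = prepare_model_for_int8_training(model)` makes `model` a local variable for the whole function body, so the right-hand-side reference raises UnboundLocalError. Either declare `global model` or bind a new name; a sketch of the latter (`ft_model` is my name, not the commit's):

    ft_model = prepare_model_for_int8_training(model)
    ft_model = get_peft_model(ft_model, peft_config)
    # ...then pass ft_model (not model) to Trainer and call save_pretrained on ft_model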
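There is also a precision mismatch: `prepare_model_for_int8_training` assumes the model was loaded in 8-bit through bitsandbytes, while this commit loads it in fp16/fp32. If int8 training is actually intended, the load would need to look roughly like this sketch (requires `bitsandbytes` and a CUDA device; recent transformers express this via `BitsAndBytesConfig`):

    from transformers import BitsAndBytesConfig

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
    )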
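`trainer.train()` runs synchronously inside an `async def` endpoint, so it blocks the event loop (and any upstream proxy timeout) for the whole training run. FastAPI's `BackgroundTasks` is one way to return immediately; a sketch, where `run_training` is a hypothetical wrapper around the Trainer setup shown in the diff:

    from fastapi import BackgroundTasks

    @app.post("/fine-tune")
    async def fine_tune(params: FineTuneRequest, background_tasks: BackgroundTasks):
        background_tasks.add_task(run_training, params)  # kick off training out of band
        return {"status": "training started"}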
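A smaller observation: the fine-tuning text is formatted as "Instruction:/Question:" pairs, but `generate_response` prompts with a "<FIN_QA>" marker that never appears in the training data and tokenizes into odd subwords. Aligning the inference prompt with the training format seems safer (my suggestion, not the commit's):

    # match the fine-tuning text format instead of the unseen "<FIN_QA>" marker
    prompt = f"Question: {user_message}\nAnswer:"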
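Last note: once `get_peft_model` wraps the model, `save_pretrained("/app/finetuned")` writes only the LoRA adapter weights, and `generate_response` keeps serving whatever is in memory. To serve the fine-tuned weights after a restart, the adapter has to be reattached to a fresh base model; a sketch assuming the adapter was saved to /app/finetuned as in the diff:

    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
    model = PeftModel.from_pretrained(base, "/app/finetuned")  # reattach the saved LoRA adapter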