Q-bert commited on
Commit
b665872
·
verified ·
1 Parent(s): b5e8200

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -75,7 +75,7 @@ def train_stock_model(stock_symbol, start_date, end_date, feature_range=(10, 100
75
  target_tensors = [t[0] for t in target]
76
 
77
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
78
- model = StockLlamaForForecasting.from_pretrained("Q-bert/StockLlama").to(device)
79
  config = LoraConfig(
80
  r=64,
81
  lora_alpha=32,
@@ -107,20 +107,20 @@ def train_stock_model(stock_symbol, start_date, end_date, feature_range=(10, 100
107
  weight_decay=0.01,
108
  lr_scheduler_type="linear",
109
  seed=3407,
110
- output_dir=f"StockLlama-LoRA-{stock_symbol}-{start_date}_{end_date}",
111
  ),
112
  )
113
 
114
  trainer.train()
115
 
116
  model = model.merge_and_unload()
117
- model.push_to_hub(f"Q-bert/StockLlama-tuned-{stock_symbol}-{start_date}_{end_date}")
118
  scaler_path = "scaler.joblib"
119
  joblib.dump(scaler, scaler_path)
120
  upload_file(
121
  path_or_fileobj=scaler_path,
122
  path_in_repo=f"scalers/{scaler_path}",
123
- repo_id=f"Q-bert/StockLlama-tuned-{stock_symbol}-{start_date}_{end_date}"
124
  )
125
  return f"Training completed and model saved for {stock_symbol} from {start_date} to {end_date}."
126
 
@@ -142,7 +142,7 @@ def gradio_train_stock_model(stock_symbol, start_date, end_date, feature_range_m
142
  iface = gr.Interface(
143
  fn=gradio_train_stock_model,
144
  inputs=[
145
- gr.Textbox(label="Stock Symbol", value="LUNC-USD"),
146
  gr.Textbox(label="Start Date", value="2023-01-01"),
147
  gr.Textbox(label="End Date", value="2024-08-24"),
148
  gr.Slider(minimum=0, maximum=100, step=1, label="Feature Range Min", value=10),
 
75
  target_tensors = [t[0] for t in target]
76
 
77
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
78
+ model = StockLlamaForForecasting.from_pretrained("StockLlama/StockLlama").to(device)
79
  config = LoraConfig(
80
  r=64,
81
  lora_alpha=32,
 
107
  weight_decay=0.01,
108
  lr_scheduler_type="linear",
109
  seed=3407,
110
+ output_dir=f"StockLlama/StockLlama-LoRA-{stock_symbol}-{start_date}_{end_date}",
111
  ),
112
  )
113
 
114
  trainer.train()
115
 
116
  model = model.merge_and_unload()
117
+ model.push_to_hub(f"StockLlama/StockLlama-tuned-{stock_symbol}-{start_date}_{end_date}")
118
  scaler_path = "scaler.joblib"
119
  joblib.dump(scaler, scaler_path)
120
  upload_file(
121
  path_or_fileobj=scaler_path,
122
  path_in_repo=f"scalers/{scaler_path}",
123
+ repo_id=f"StockLlama/StockLlama-tuned-{stock_symbol}-{start_date}_{end_date}"
124
  )
125
  return f"Training completed and model saved for {stock_symbol} from {start_date} to {end_date}."
126
 
 
142
  iface = gr.Interface(
143
  fn=gradio_train_stock_model,
144
  inputs=[
145
+ gr.Textbox(label="Stock Symbol", value="BTC-USD"),
146
  gr.Textbox(label="Start Date", value="2023-01-01"),
147
  gr.Textbox(label="End Date", value="2024-08-24"),
148
  gr.Slider(minimum=0, maximum=100, step=1, label="Feature Range Min", value=10),