Update LLM.py
Browse files
LLM.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os, traceback, asyncio, json
|
| 2 |
from datetime import datetime
|
| 3 |
from functools import wraps
|
|
@@ -5,7 +6,7 @@ from backoff import on_exception, expo
|
|
| 5 |
from openai import OpenAI, RateLimitError, APITimeoutError
|
| 6 |
import numpy as np
|
| 7 |
from sentiment_news import NewsFetcher
|
| 8 |
-
from helpers import parse_json_from_response, validate_required_fields, format_technical_indicators, format_strategy_scores
|
| 9 |
|
| 10 |
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
|
| 11 |
PRIMARY_MODEL = "nvidia/llama-3.1-nemotron-ultra-253b-v1"
|
|
@@ -15,7 +16,8 @@ class PatternAnalysisEngine:
|
|
| 15 |
self.llm = llm_service
|
| 16 |
|
| 17 |
def _format_chart_data_for_llm(self, ohlcv_data):
|
| 18 |
-
if not ohlcv_data or len(ohlcv_data) < 20:
|
|
|
|
| 19 |
try:
|
| 20 |
candles_to_analyze = ohlcv_data[-50:] if len(ohlcv_data) > 50 else ohlcv_data
|
| 21 |
chart_description = ["CANDLE DATA FOR PATTERN ANALYSIS:", f"Total candles available: {len(ohlcv_data)}", f"Candles used for analysis: {len(candles_to_analyze)}", ""]
|
|
@@ -49,7 +51,8 @@ class PatternAnalysisEngine:
|
|
| 49 |
chart_description.extend(["", "VOLUME ANALYSIS:", f"Current Volume: {current_volume:,.0f}", f"Volume Ratio: {volume_ratio:.2f}x average", f"Volume Signal: {volume_signal}"])
|
| 50 |
|
| 51 |
return "\n".join(chart_description)
|
| 52 |
-
except Exception as e:
|
|
|
|
| 53 |
|
| 54 |
async def analyze_chart_patterns(self, symbol, ohlcv_data):
|
| 55 |
try:
|
|
@@ -68,11 +71,13 @@ class PatternAnalysisEngine:
|
|
| 68 |
def _parse_pattern_response(self, response_text):
|
| 69 |
try:
|
| 70 |
json_str = parse_json_from_response(response_text)
|
| 71 |
-
if not json_str:
|
|
|
|
| 72 |
|
| 73 |
pattern_data = json.loads(json_str)
|
| 74 |
required = ['pattern_detected', 'pattern_confidence', 'predicted_direction']
|
| 75 |
-
if not validate_required_fields(pattern_data, required):
|
|
|
|
| 76 |
|
| 77 |
return pattern_data
|
| 78 |
except Exception as e:
|
|
@@ -88,11 +93,13 @@ class LLMService:
|
|
| 88 |
self.news_fetcher = NewsFetcher()
|
| 89 |
self.pattern_engine = PatternAnalysisEngine(self)
|
| 90 |
self.semaphore = asyncio.Semaphore(5)
|
|
|
|
| 91 |
|
| 92 |
def _rate_limit_nvidia_api(func):
|
| 93 |
@wraps(func)
|
| 94 |
@on_exception(expo, RateLimitError, max_tries=5)
|
| 95 |
-
async def wrapper(*args, **kwargs):
|
|
|
|
| 96 |
return wrapper
|
| 97 |
|
| 98 |
async def get_trading_decision(self, data_payload: dict):
|
|
@@ -104,29 +111,51 @@ class LLMService:
|
|
| 104 |
pattern_analysis = await self._get_pattern_analysis(data_payload)
|
| 105 |
prompt = self._create_enhanced_trading_prompt(data_payload, news_text, pattern_analysis)
|
| 106 |
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
decision_dict = self._parse_llm_response_enhanced(response, target_strategy, symbol)
|
| 110 |
if decision_dict:
|
| 111 |
decision_dict['model_source'] = self.model_name
|
| 112 |
decision_dict['pattern_analysis'] = pattern_analysis
|
| 113 |
return decision_dict
|
| 114 |
-
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
except Exception as e:
|
| 116 |
-
print(f"
|
| 117 |
-
|
|
|
|
| 118 |
|
| 119 |
def _parse_llm_response_enhanced(self, response_text: str, fallback_strategy: str, symbol: str) -> dict:
|
| 120 |
try:
|
| 121 |
json_str = parse_json_from_response(response_text)
|
| 122 |
-
if not json_str:
|
|
|
|
| 123 |
|
| 124 |
decision_data = json.loads(json_str)
|
| 125 |
required_fields = ['action', 'reasoning', 'risk_assessment', 'trade_type', 'stop_loss', 'take_profit', 'expected_target_minutes', 'confidence_level']
|
| 126 |
-
if not validate_required_fields(decision_data, required_fields):
|
|
|
|
| 127 |
|
| 128 |
strategy_value = decision_data.get('strategy')
|
| 129 |
-
if not strategy_value or strategy_value == 'unknown':
|
|
|
|
| 130 |
|
| 131 |
return decision_data
|
| 132 |
except Exception as e:
|
|
@@ -138,11 +167,13 @@ class LLMService:
|
|
| 138 |
symbol = data_payload['symbol']
|
| 139 |
if 'raw_ohlcv' in data_payload and '1h' in data_payload['raw_ohlcv']:
|
| 140 |
ohlcv_data = data_payload['raw_ohlcv']['1h']
|
| 141 |
-
if ohlcv_data and len(ohlcv_data) >= 20:
|
|
|
|
| 142 |
|
| 143 |
if 'advanced_indicators' in data_payload and '1h' in data_payload['advanced_indicators']:
|
| 144 |
ohlcv_data = data_payload['advanced_indicators']['1h']
|
| 145 |
-
if ohlcv_data and len(ohlcv_data) >= 20:
|
|
|
|
| 146 |
|
| 147 |
return None
|
| 148 |
except Exception as e:
|
|
@@ -214,7 +245,8 @@ class LLMService:
|
|
| 214 |
return prompt
|
| 215 |
|
| 216 |
def _format_pattern_analysis(self, pattern_analysis):
|
| 217 |
-
if not pattern_analysis:
|
|
|
|
| 218 |
confidence = pattern_analysis.get('pattern_confidence', 0)
|
| 219 |
pattern_name = pattern_analysis.get('pattern_detected', 'unknown')
|
| 220 |
analysis_lines = [f"Pattern: {pattern_name}", f"Confidence: {confidence:.1%}", f"Predicted Move: {pattern_analysis.get('predicted_direction', 'N/A')}", f"Analysis: {pattern_analysis.get('pattern_analysis', 'No detailed analysis')}"]
|
|
@@ -234,25 +266,45 @@ class LLMService:
|
|
| 234 |
pattern_analysis = await self._get_pattern_analysis(processed_data)
|
| 235 |
prompt = self._create_re_analysis_prompt(trade_data, processed_data, news_text, pattern_analysis)
|
| 236 |
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
re_analysis_dict = self._parse_re_analysis_response(response, original_strategy, symbol)
|
| 240 |
if re_analysis_dict:
|
| 241 |
re_analysis_dict['model_source'] = self.model_name
|
| 242 |
return re_analysis_dict
|
| 243 |
-
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
except Exception as e:
|
| 245 |
-
print(f"
|
| 246 |
-
|
|
|
|
| 247 |
|
| 248 |
def _parse_re_analysis_response(self, response_text: str, fallback_strategy: str, symbol: str) -> dict:
|
| 249 |
try:
|
| 250 |
json_str = parse_json_from_response(response_text)
|
| 251 |
-
if not json_str:
|
|
|
|
| 252 |
|
| 253 |
decision_data = json.loads(json_str)
|
| 254 |
strategy_value = decision_data.get('strategy')
|
| 255 |
-
if not strategy_value or strategy_value == 'unknown':
|
|
|
|
| 256 |
|
| 257 |
return decision_data
|
| 258 |
except Exception as e:
|
|
@@ -265,8 +317,11 @@ class LLMService:
|
|
| 265 |
current_price = processed_data.get('current_price', 'N/A')
|
| 266 |
strategy = trade_data.get('strategy', 'GENERIC')
|
| 267 |
|
| 268 |
-
try:
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
indicators_summary = format_technical_indicators(processed_data.get('advanced_indicators', {}))
|
| 272 |
pattern_summary = self._format_pattern_analysis(pattern_analysis)
|
|
|
|
| 1 |
+
# LLM.py
|
| 2 |
import os, traceback, asyncio, json
|
| 3 |
from datetime import datetime
|
| 4 |
from functools import wraps
|
|
|
|
| 6 |
from openai import OpenAI, RateLimitError, APITimeoutError
|
| 7 |
import numpy as np
|
| 8 |
from sentiment_news import NewsFetcher
|
| 9 |
+
from helpers import parse_json_from_response, validate_required_fields, format_technical_indicators, format_strategy_scores
|
| 10 |
|
| 11 |
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
|
| 12 |
PRIMARY_MODEL = "nvidia/llama-3.1-nemotron-ultra-253b-v1"
|
|
|
|
| 16 |
self.llm = llm_service
|
| 17 |
|
| 18 |
def _format_chart_data_for_llm(self, ohlcv_data):
|
| 19 |
+
if not ohlcv_data or len(ohlcv_data) < 20:
|
| 20 |
+
return "Insufficient chart data for pattern analysis"
|
| 21 |
try:
|
| 22 |
candles_to_analyze = ohlcv_data[-50:] if len(ohlcv_data) > 50 else ohlcv_data
|
| 23 |
chart_description = ["CANDLE DATA FOR PATTERN ANALYSIS:", f"Total candles available: {len(ohlcv_data)}", f"Candles used for analysis: {len(candles_to_analyze)}", ""]
|
|
|
|
| 51 |
chart_description.extend(["", "VOLUME ANALYSIS:", f"Current Volume: {current_volume:,.0f}", f"Volume Ratio: {volume_ratio:.2f}x average", f"Volume Signal: {volume_signal}"])
|
| 52 |
|
| 53 |
return "\n".join(chart_description)
|
| 54 |
+
except Exception as e:
|
| 55 |
+
return f"Error formatting chart data: {str(e)}"
|
| 56 |
|
| 57 |
async def analyze_chart_patterns(self, symbol, ohlcv_data):
|
| 58 |
try:
|
|
|
|
| 71 |
def _parse_pattern_response(self, response_text):
|
| 72 |
try:
|
| 73 |
json_str = parse_json_from_response(response_text)
|
| 74 |
+
if not json_str:
|
| 75 |
+
return {"pattern_detected": "parse_error", "pattern_confidence": 0.1, "pattern_analysis": "Could not parse pattern analysis response"}
|
| 76 |
|
| 77 |
pattern_data = json.loads(json_str)
|
| 78 |
required = ['pattern_detected', 'pattern_confidence', 'predicted_direction']
|
| 79 |
+
if not validate_required_fields(pattern_data, required):
|
| 80 |
+
return {"pattern_detected": "incomplete_data", "pattern_confidence": 0.1, "pattern_analysis": "Incomplete pattern analysis data"}
|
| 81 |
|
| 82 |
return pattern_data
|
| 83 |
except Exception as e:
|
|
|
|
| 93 |
self.news_fetcher = NewsFetcher()
|
| 94 |
self.pattern_engine = PatternAnalysisEngine(self)
|
| 95 |
self.semaphore = asyncio.Semaphore(5)
|
| 96 |
+
self.r2_service = None # سيتم تعيينه من app.py
|
| 97 |
|
| 98 |
def _rate_limit_nvidia_api(func):
|
| 99 |
@wraps(func)
|
| 100 |
@on_exception(expo, RateLimitError, max_tries=5)
|
| 101 |
+
async def wrapper(*args, **kwargs):
|
| 102 |
+
return await func(*args, **kwargs)
|
| 103 |
return wrapper
|
| 104 |
|
| 105 |
async def get_trading_decision(self, data_payload: dict):
|
|
|
|
| 111 |
pattern_analysis = await self._get_pattern_analysis(data_payload)
|
| 112 |
prompt = self._create_enhanced_trading_prompt(data_payload, news_text, pattern_analysis)
|
| 113 |
|
| 114 |
+
# ✅ حفظ الـ Prompt في R2 قبل إرساله للنموذج
|
| 115 |
+
if self.r2_service:
|
| 116 |
+
analysis_data = {
|
| 117 |
+
'current_price': data_payload.get('current_price'),
|
| 118 |
+
'final_score': data_payload.get('final_score'),
|
| 119 |
+
'enhanced_final_score': data_payload.get('enhanced_final_score'),
|
| 120 |
+
'target_strategy': target_strategy,
|
| 121 |
+
'pattern_analysis': pattern_analysis
|
| 122 |
+
}
|
| 123 |
+
await self.r2_service.save_llm_prompts_async(
|
| 124 |
+
symbol, 'trading_decision', prompt, analysis_data
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
async with self.semaphore:
|
| 128 |
+
response = await self._call_llm(prompt)
|
| 129 |
|
| 130 |
decision_dict = self._parse_llm_response_enhanced(response, target_strategy, symbol)
|
| 131 |
if decision_dict:
|
| 132 |
decision_dict['model_source'] = self.model_name
|
| 133 |
decision_dict['pattern_analysis'] = pattern_analysis
|
| 134 |
return decision_dict
|
| 135 |
+
else:
|
| 136 |
+
# ❌ لا نستخدم أي محاكاة - نرجع None في حالة الفشل
|
| 137 |
+
print(f"❌ فشل تحليل النموذج الضخم لـ {symbol} - لا توجد قرارات بديلة")
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
except Exception as e:
|
| 141 |
+
print(f"❌ خطأ في قرار التداول لـ {data_payload.get('symbol', 'unknown')}: {e}")
|
| 142 |
+
# ❌ لا نستخدم أي محاكاة
|
| 143 |
+
return None
|
| 144 |
|
| 145 |
def _parse_llm_response_enhanced(self, response_text: str, fallback_strategy: str, symbol: str) -> dict:
|
| 146 |
try:
|
| 147 |
json_str = parse_json_from_response(response_text)
|
| 148 |
+
if not json_str:
|
| 149 |
+
return None
|
| 150 |
|
| 151 |
decision_data = json.loads(json_str)
|
| 152 |
required_fields = ['action', 'reasoning', 'risk_assessment', 'trade_type', 'stop_loss', 'take_profit', 'expected_target_minutes', 'confidence_level']
|
| 153 |
+
if not validate_required_fields(decision_data, required_fields):
|
| 154 |
+
return None
|
| 155 |
|
| 156 |
strategy_value = decision_data.get('strategy')
|
| 157 |
+
if not strategy_value or strategy_value == 'unknown':
|
| 158 |
+
decision_data['strategy'] = fallback_strategy
|
| 159 |
|
| 160 |
return decision_data
|
| 161 |
except Exception as e:
|
|
|
|
| 167 |
symbol = data_payload['symbol']
|
| 168 |
if 'raw_ohlcv' in data_payload and '1h' in data_payload['raw_ohlcv']:
|
| 169 |
ohlcv_data = data_payload['raw_ohlcv']['1h']
|
| 170 |
+
if ohlcv_data and len(ohlcv_data) >= 20:
|
| 171 |
+
return await self.pattern_engine.analyze_chart_patterns(symbol, ohlcv_data)
|
| 172 |
|
| 173 |
if 'advanced_indicators' in data_payload and '1h' in data_payload['advanced_indicators']:
|
| 174 |
ohlcv_data = data_payload['advanced_indicators']['1h']
|
| 175 |
+
if ohlcv_data and len(ohlcv_data) >= 20:
|
| 176 |
+
return await self.pattern_engine.analyze_chart_patterns(symbol, ohlcv_data)
|
| 177 |
|
| 178 |
return None
|
| 179 |
except Exception as e:
|
|
|
|
| 245 |
return prompt
|
| 246 |
|
| 247 |
def _format_pattern_analysis(self, pattern_analysis):
|
| 248 |
+
if not pattern_analysis:
|
| 249 |
+
return "No clear patterns detected"
|
| 250 |
confidence = pattern_analysis.get('pattern_confidence', 0)
|
| 251 |
pattern_name = pattern_analysis.get('pattern_detected', 'unknown')
|
| 252 |
analysis_lines = [f"Pattern: {pattern_name}", f"Confidence: {confidence:.1%}", f"Predicted Move: {pattern_analysis.get('predicted_direction', 'N/A')}", f"Analysis: {pattern_analysis.get('pattern_analysis', 'No detailed analysis')}"]
|
|
|
|
| 266 |
pattern_analysis = await self._get_pattern_analysis(processed_data)
|
| 267 |
prompt = self._create_re_analysis_prompt(trade_data, processed_data, news_text, pattern_analysis)
|
| 268 |
|
| 269 |
+
# ✅ حفظ الـ Prompt في R2 قبل إرساله للنموذج
|
| 270 |
+
if self.r2_service:
|
| 271 |
+
analysis_data = {
|
| 272 |
+
'entry_price': trade_data.get('entry_price'),
|
| 273 |
+
'current_price': processed_data.get('current_price'),
|
| 274 |
+
'original_strategy': original_strategy,
|
| 275 |
+
'pattern_analysis': pattern_analysis
|
| 276 |
+
}
|
| 277 |
+
await self.r2_service.save_llm_prompts_async(
|
| 278 |
+
symbol, 'trade_reanalysis', prompt, analysis_data
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
async with self.semaphore:
|
| 282 |
+
response = await self._call_llm(prompt)
|
| 283 |
|
| 284 |
re_analysis_dict = self._parse_re_analysis_response(response, original_strategy, symbol)
|
| 285 |
if re_analysis_dict:
|
| 286 |
re_analysis_dict['model_source'] = self.model_name
|
| 287 |
return re_analysis_dict
|
| 288 |
+
else:
|
| 289 |
+
# ❌ لا نستخدم أي محاكاة - نرجع None في حالة الفشل
|
| 290 |
+
print(f"❌ فشل إعادة تحليل ا��نموذج الضخم لـ {symbol} - لا توجد قرارات بديلة")
|
| 291 |
+
return None
|
| 292 |
+
|
| 293 |
except Exception as e:
|
| 294 |
+
print(f"❌ خطأ في إعادة تحليل LLM: {e}")
|
| 295 |
+
# ❌ لا نستخدم أي محاكاة
|
| 296 |
+
return None
|
| 297 |
|
| 298 |
def _parse_re_analysis_response(self, response_text: str, fallback_strategy: str, symbol: str) -> dict:
|
| 299 |
try:
|
| 300 |
json_str = parse_json_from_response(response_text)
|
| 301 |
+
if not json_str:
|
| 302 |
+
return None
|
| 303 |
|
| 304 |
decision_data = json.loads(json_str)
|
| 305 |
strategy_value = decision_data.get('strategy')
|
| 306 |
+
if not strategy_value or strategy_value == 'unknown':
|
| 307 |
+
decision_data['strategy'] = fallback_strategy
|
| 308 |
|
| 309 |
return decision_data
|
| 310 |
except Exception as e:
|
|
|
|
| 317 |
current_price = processed_data.get('current_price', 'N/A')
|
| 318 |
strategy = trade_data.get('strategy', 'GENERIC')
|
| 319 |
|
| 320 |
+
try:
|
| 321 |
+
price_change = ((current_price - entry_price) / entry_price) * 100
|
| 322 |
+
price_change_display = f"{price_change:+.2f}%"
|
| 323 |
+
except (TypeError, ZeroDivisionError):
|
| 324 |
+
price_change_display = "N/A"
|
| 325 |
|
| 326 |
indicators_summary = format_technical_indicators(processed_data.get('advanced_indicators', {}))
|
| 327 |
pattern_summary = self._format_pattern_analysis(pattern_analysis)
|