tmmdev commited on
Commit
805f7bd
·
verified ·
1 Parent(s): 5c43f7a

Update pattern_analyzer.py

Browse files
Files changed (1) hide show
  1. pattern_analyzer.py +22 -22
pattern_analyzer.py CHANGED
@@ -2,29 +2,30 @@ import os
2
  os.environ['HF_HOME'] = '/tmp/huggingface'
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
 
 
 
5
 
6
  class PatternAnalyzer:
7
  def __init__(self):
8
- # Check if CUDA is available
9
- if torch.cuda.is_available():
10
- self.model = AutoModelForCausalLM.from_pretrained(
11
- "tmmdev/codellama-pattern-analysis",
12
- load_in_8bit=True,
13
- device_map="auto",
14
- torch_dtype="auto"
15
- )
16
- else:
17
- # CPU fallback configuration
18
- self.model = AutoModelForCausalLM.from_pretrained(
19
- "tmmdev/codellama-pattern-analysis",
20
- device_map="auto",
21
- torch_dtype=torch.float32,
22
- low_cpu_mem_usage=True
23
- )
 
24
 
25
- self.tokenizer = AutoTokenizer.from_pretrained("tmmdev/codellama-pattern-analysis")
26
-
27
-
28
  self.basic_patterns = {
29
  'channel': {'min_points': 4, 'confidence_threshold': 0.7},
30
  'triangle': {'min_points': 3, 'confidence_threshold': 0.75},
@@ -36,7 +37,7 @@ class PatternAnalyzer:
36
  self.pattern_logic = PatternLogic()
37
 
38
  def analyze_data(self, ohlcv_data):
39
- data_prompt = f"""TASK: Identify high-confidence technical patterns only.
40
  Minimum confidence threshold: 0.8
41
  Required pattern criteria:
42
  1. Channel: Must have at least 3 touching points
@@ -54,7 +55,7 @@ class PatternAnalyzer:
54
  analysis = self.tokenizer.decode(outputs[0])
55
 
56
  return self.parse_analysis(analysis)
57
-
58
  def parse_analysis(self, analysis_text):
59
  try:
60
  json_start = analysis_text.find('{')
@@ -66,7 +67,6 @@ class PatternAnalyzer:
66
 
67
  for pattern in analysis_data.get('patterns', []):
68
  pattern_type = pattern.get('type')
69
-
70
  if pattern_type in self.basic_patterns:
71
  threshold = self.basic_patterns[pattern_type]['confidence_threshold']
72
  if pattern.get('confidence', 0) >= threshold:
 
2
  os.environ['HF_HOME'] = '/tmp/huggingface'
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
+ import json
6
+ import pandas as pd
7
+ from pattern_logic import PatternLogic
8
 
9
  class PatternAnalyzer:
10
  def __init__(self):
11
+ model_kwargs = {
12
+ "device_map": "auto",
13
+ "torch_dtype": torch.float32,
14
+ "low_cpu_mem_usage": True,
15
+ "max_memory": {"cpu": "4GB"},
16
+ "offload_folder": "/tmp/offload"
17
+ }
18
+
19
+ self.model = AutoModelForCausalLM.from_pretrained(
20
+ "tmmdev/codellama-pattern-analysis",
21
+ **model_kwargs
22
+ )
23
+
24
+ self.tokenizer = AutoTokenizer.from_pretrained(
25
+ "tmmdev/codellama-pattern-analysis",
26
+ use_fast=True
27
+ )
28
 
 
 
 
29
  self.basic_patterns = {
30
  'channel': {'min_points': 4, 'confidence_threshold': 0.7},
31
  'triangle': {'min_points': 3, 'confidence_threshold': 0.75},
 
37
  self.pattern_logic = PatternLogic()
38
 
39
  def analyze_data(self, ohlcv_data):
40
+ data_prompt = f"""TASK: Identify high-confidence technical patterns only.
41
  Minimum confidence threshold: 0.8
42
  Required pattern criteria:
43
  1. Channel: Must have at least 3 touching points
 
55
  analysis = self.tokenizer.decode(outputs[0])
56
 
57
  return self.parse_analysis(analysis)
58
+
59
  def parse_analysis(self, analysis_text):
60
  try:
61
  json_start = analysis_text.find('{')
 
67
 
68
  for pattern in analysis_data.get('patterns', []):
69
  pattern_type = pattern.get('type')
 
70
  if pattern_type in self.basic_patterns:
71
  threshold = self.basic_patterns[pattern_type]['confidence_threshold']
72
  if pattern.get('confidence', 0) >= threshold: