apak committed on
Commit
cc64dce
·
verified ·
1 Parent(s): 40da81b

Update ner_logic.py

Browse files
Files changed (1) hide show
  1. ner_logic.py +30 -59
ner_logic.py CHANGED
@@ -7,94 +7,65 @@ from labels_config import CUSTOM_LABELS
7
  from prompts import SYSTEM_COT_PROMPT
8
  import spaces
9
 
10
- # Dil ayarını Türkçe yapalım
11
  wikipedia.set_lang("tr")
12
 
13
- # 1. NER Modeli (Standart etiketler için)
14
- ner_pipe = pipeline(
15
- "ner",
16
- model="xlm-roberta-large-finetuned-conll03-english",
17
- aggregation_strategy="simple",
18
- device=0 if torch.cuda.is_available() else -1
19
- )
20
 
21
- # 2. LLM Modeli ve Tokenizer (Muhakeme için)
22
- model_name = "Qwen/Qwen2.5-1.5B-Instruct"
23
- tokenizer = AutoTokenizer.from_pretrained(model_name)
24
- llm_model = AutoModelForCausalLM.from_pretrained(
25
- model_name,
26
- dtype="auto",
27
- device_map="auto"
28
- )
29
 
30
  def get_wiki_summary(query):
31
- """Wikipedia'dan varlık hakkında kısa özet getirir."""
32
  try:
33
- # En yakın başlığı bul
34
  search_results = wikipedia.search(query)
35
- if not search_results:
36
- return None
37
- # İlk sonucun özetini al
38
- summary = wikipedia.summary(search_results[0], sentences=2, auto_suggest=False)
39
- return summary
40
- except:
41
- return None
42
 
43
  @spaces.GPU
44
  def refine_labels_batch(misc_items, full_sentence):
45
- """MISC varlıkları LLM ile yeniden analiz eder."""
46
- if not misc_items:
47
- return []
48
 
49
- # Label kurallarını metne dönüştür
50
  label_rules = ""
51
  for k, v in CUSTOM_LABELS.items():
52
  label_rules += f"### {k}\nTANIM: {v['tanim']}\nANAHTARLAR: {', '.join(v['anahtar_kelimeler'])}\n\n"
53
 
54
- # Analiz edilecek varlıkları metne dönüştür
55
  targets_text = "".join([f"- VARLIK: {item['word']} | WIKI: {item['wiki']}\n" for item in misc_items])
56
-
57
- final_prompt = SYSTEM_COT_PROMPT.format(
58
- label_rules=label_rules,
59
- full_sentence=full_sentence,
60
- targets=targets_text
61
- )
62
 
63
  messages = [{"role": "user", "content": final_prompt}]
64
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
65
  model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)
66
 
67
  try:
68
- generated_ids = llm_model.generate(
69
- model_inputs.input_ids,
70
- max_new_tokens=1500,
71
- do_sample=False
72
- )
73
  output = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
74
 
75
- # JSON bloğunu ayıkla
76
  json_match = re.search(r'\[\s*\{.*\}\s*\]', output, re.DOTALL)
77
-
78
  if json_match:
79
  raw_json = json_match.group(0).strip()
80
-
81
- # Basit JSON tamamlama
82
- if raw_json.count('{') > raw_json.count('}'):
83
- raw_json += "}]"
84
-
85
  results = json.loads(raw_json)
86
-
87
- # app.py'de KeyError almamak için anahtar kontrolü
88
  for r in results:
89
- if 'karar' not in r: r['karar'] = "MISC"
90
- if 'reasoning' not in r: r['reasoning'] = "Analiz süreci tamamlanamadı."
91
- if 'varlik' not in r: r['varlik'] = "Bilinmeyen"
92
-
93
  return results
94
- else:
95
- print(f"LLM Yanlış Format Döndürdü: {output}")
96
- return []
97
-
98
  except Exception as e:
99
- print(f"Süreç Hatası: {str(e)}")
100
  return []
 
7
  from prompts import SYSTEM_COT_PROMPT
8
  import spaces
9
 
 
10
# All Wikipedia lookups are performed against the Turkish wiki.
wikipedia.set_lang("tr")

# Lazy-loading state: these module globals start as None and are populated
# on first use by load_models() (inside the GPU-decorated call path).
_tokenizer = None
_llm_model = None
_ner_pipe = None
 
 
 
16
 
17
def load_models():
    """Lazily load and cache the tokenizer, LLM and NER pipeline.

    Models are created on the first call only and cached in module-level
    globals, so repeated invocations do not re-download or re-initialise
    anything.

    Returns:
        tuple: (tokenizer, llm_model, ner_pipeline)
    """
    global _tokenizer, _llm_model, _ner_pipe
    if _tokenizer is None:
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        _tokenizer = AutoTokenizer.from_pretrained(model_name)
        _llm_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        # Fall back to CPU (-1) when CUDA is unavailable: ner_pipe() is not
        # decorated with @spaces.GPU, so this can run on a CPU-only host
        # where the previous hard-coded device=0 crashed.
        _ner_pipe = pipeline(
            "ner",
            model="xlm-roberta-large-finetuned-conll03-english",
            aggregation_strategy="simple",
            device=0 if torch.cuda.is_available() else -1,
        )
    return _tokenizer, _llm_model, _ner_pipe
25
 
26
def get_wiki_summary(query):
    """Return a short (2-sentence) Turkish Wikipedia summary for *query*.

    The query is resolved through a search first, so near-miss entity names
    still find a page. Returns None when no result exists or on any
    Wikipedia API error (disambiguation, missing page, network failure).
    """
    try:
        search_results = wikipedia.search(query)
        if not search_results:
            return None
        # auto_suggest=False keeps the exact search hit instead of letting
        # the library "correct" the title to an unrelated page.
        return wikipedia.summary(search_results[0], sentences=2, auto_suggest=False)
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # still propagate; lookup stays best-effort.
        return None
32
+
33
def ner_pipe(text):
    """Run the cached NER pipeline on *text*, loading models on first use."""
    pipe = load_models()[2]
    return pipe(text)
36
 
37
@spaces.GPU
def refine_labels_batch(misc_items, full_sentence):
    """Re-classify MISC entities with the LLM using the CUSTOM_LABELS rules.

    Args:
        misc_items: list of dicts with 'word' and 'wiki' keys, one per entity
            the base NER model tagged as MISC.
        full_sentence: the sentence the entities came from, passed to the
            prompt as context.

    Returns:
        list of dicts, each guaranteed to contain the 'varlik', 'karar' and
        'reasoning' keys app.py reads (defaults filled in when the LLM omits
        one). Returns [] on any failure — this function is best-effort.
    """
    if not misc_items:
        return []
    tokenizer, llm_model, _ = load_models()

    # Render the label taxonomy into prompt text.
    label_rules = ""
    for k, v in CUSTOM_LABELS.items():
        label_rules += f"### {k}\nTANIM: {v['tanim']}\nANAHTARLAR: {', '.join(v['anahtar_kelimeler'])}\n\n"

    targets_text = "".join(f"- VARLIK: {item['word']} | WIKI: {item['wiki']}\n" for item in misc_items)
    final_prompt = SYSTEM_COT_PROMPT.format(
        label_rules=label_rules,
        full_sentence=full_sentence,
        targets=targets_text,
    )

    messages = [{"role": "user", "content": final_prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)

    try:
        generated_ids = llm_model.generate(model_inputs.input_ids, max_new_tokens=1000, do_sample=False)
        # Decode only the newly generated tokens (strip the prompt echo).
        output = tokenizer.batch_decode(
            generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True
        )[0]

        # Extract the first JSON array of objects from the model output.
        json_match = re.search(r'\[\s*\{.*\}\s*\]', output, re.DOTALL)
        if not json_match:
            return []
        raw_json = json_match.group(0).strip()

        # Repair truncated objects. The regex guarantees the match ends in
        # ']', so any missing '}' must be inserted BEFORE that bracket —
        # the previous `raw_json += "}]"` produced `...]}]`, which never
        # parsed and always fell through to the except branch.
        missing = raw_json.count('{') - raw_json.count('}')
        if missing > 0:
            raw_json = raw_json[:-1].rstrip() + "}" * missing + "]"

        results = json.loads(raw_json)
        if not isinstance(results, list):
            return []

        # Guarantee the keys app.py reads; silently drop non-dict items so
        # a single malformed element cannot discard the whole batch.
        cleaned = []
        for r in results:
            if not isinstance(r, dict):
                continue
            r.setdefault('karar', 'MISC')
            r.setdefault('reasoning', 'Analiz adımları oluşturulamadı.')
            r.setdefault('varlik', 'Bilinmeyen')
            cleaned.append(r)
        return cleaned
    except Exception as e:
        print(f"Hata: {e}")
        return []