from typing import Any, Dict, List, Optional
import json
import logging
import os
import re
import zipfile

import requests
import torch
from transformers import pipeline

logger = logging.getLogger("handler.py")


class LawLookup:
    """Look up article text from a ``ChLaw.json`` law dump.

    The JSON file is read with keys ``Laws`` -> ``LawName`` / ``LawArticles``
    -> ``ArticleNo`` / ``ArticleContent``. If the file does not exist it is
    downloaded (as a zip) from ``zip_url`` and extracted first.
    """

    def __init__(self, json_file: str):
        """Load (downloading if needed) the law dump and index it.

        Args:
            json_file: Path of the ``ChLaw.json`` file to read / create.
        """
        self.json_file = json_file
        self.zip_url = 'https://law.moj.gov.tw/api/data/chlaw.json.zip'
        if not os.path.exists(self.json_file):
            self._download_and_extract_zip()
        # utf-8-sig tolerates a leading BOM in the published file.
        with open(self.json_file, 'r', encoding='utf-8-sig') as file:
            self.laws_data = json.load(file)
        self.laws_dict = self._create_laws_dict()

    def _download_and_extract_zip(self, timeout: float = 60.0) -> None:
        """Download the zipped dump and extract ``ChLaw.json`` to ``self.json_file``.

        Args:
            timeout: Seconds to wait on the HTTP connection / each read.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            zipfile.BadZipFile / KeyError: If the archive is invalid or has
                no ``ChLaw.json`` member.
        """
        target = os.path.abspath(self.json_file)
        target_dir = os.path.dirname(target)
        zip_path = os.path.join(target_dir, 'ChLaw.zip')
        try:
            # Stream to disk so the whole archive is never held in memory.
            with requests.get(self.zip_url, timeout=timeout, stream=True) as response:
                response.raise_for_status()
                with open(zip_path, 'wb') as file:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        file.write(chunk)
            # Extract only the ChLaw.json member, next to the requested path.
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                extracted = zip_ref.extract('ChLaw.json', path=target_dir)
            # The member name is fixed; move it if the caller asked for a
            # different file name so __init__ can actually open it.
            if os.path.abspath(extracted) != target:
                os.replace(extracted, target)
        finally:
            # Remove the temporary archive even when download/extract fails.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    def _create_laws_dict(self) -> Dict[str, Dict[str, str]]:
        """Build ``{law name: {article number: article content}}``.

        Articles whose number cannot be extracted are skipped. A later law
        with the same name overwrites an earlier one.
        """
        laws_dict: Dict[str, Dict[str, str]] = {}
        for law in self.laws_data['Laws']:
            articles: Dict[str, str] = {}
            for article in law['LawArticles']:
                article_no = self._extract_article_no(article['ArticleNo'])
                if article_no is not None:
                    articles[article_no] = article['ArticleContent']
            laws_dict[law['LawName']] = articles
        return laws_dict

    def _extract_article_no(self, article_no_str) -> Optional[str]:
        """Strip the ``第`` / ``條`` markers and surrounding whitespace.

        For example ``"第 15-1 條"`` becomes ``"15-1"``.

        Returns:
            The remaining article-number text, or ``None`` when the input is
            not a string or nothing remains after stripping.
        """
        if not isinstance(article_no_str, str):
            return None
        article_no = article_no_str.replace('第', '').replace('條', '').strip()
        return article_no or None

    def get_law(self, law_name: str, article_no: str) -> str:
        """Return an article's text.

        Args:
            law_name: Exact law name as it appears in ``LawName``.
            article_no: Article number (converted with ``str``), e.g. ``"184"``.

        Returns:
            The article content, or the sentinel strings ``"Law not found."``
            / ``"Article not found."``.
        """
        articles = self.laws_dict.get(law_name)
        if articles is None:
            return "Law not found."
        return articles.get(str(article_no), "Article not found.")

    def get_law_from_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Resolve a ``<law name|article no...>`` token to its article text.

        Returns:
            ``None`` if the token contains no ``|``; ``{}`` if it is not
            wrapped in ``<`` ... ``>``; otherwise a dict with ``token``
            (brackets stripped), ``lawName``, ``articleNo`` and ``content``.
        """
        if "|" not in token:
            return None
        # Both brackets must be present; a half-wrapped token is rejected
        # instead of having a real character sliced off below.
        if not (token.startswith("<") and token.endswith(">")):
            return {}
        token = token[1:-1]
        law_name, article_no = token.split("|")[:2]
        return {
            "token": token,
            "lawName": law_name,
            "articleNo": article_no,
            "content": self.get_law(law_name, article_no)}


class EndpointHandler:
    """Rank the model's law tokens for a query and return the cited articles."""

    def __init__(self, path=""):
        """Load the model pipeline, the law dump and the law-token vocabulary.

        Args:
            path: Accepted for the endpoint interface; not used here (the
                model id is hard-coded).
        """
        self.pipeline = pipeline(model="amy011872/LawToken-7B-a2", device=0, torch_dtype=torch.float16)
        self.model = self.pipeline.model
        self.tokenizer = self.pipeline.tokenizer
        self.law_lookup = LawLookup('ChLaw.json')
        self.vocab = self.pipeline.tokenizer.get_vocab()
        # Law tokens look like "<name|...": start with "<" and have at least
        # one character between "<" and the first "|".
        self.law_token_ids = [
            token_id for token, token_id in self.vocab.items()
            if token.startswith("<") and len(token) > 1 and token.find("|") > 1
        ]
        self.law_token_names = self.tokenizer.convert_ids_to_tokens(self.law_token_ids)
        # Index tensor reused on every call instead of rebuilding it from a list.
        self._law_token_index = torch.tensor(
            self.law_token_ids, dtype=torch.long, device=self.model.device)
        # Logits for the constant base prompt; computed once on first use.
        self._base_logits: Optional[torch.Tensor] = None

    def _get_base_logits(self) -> torch.Tensor:
        """Return (and cache) law-token logits for the constant base prompt."""
        if self._base_logits is None:
            base_input = self.tokenizer("", return_tensors="pt").to(self.model.device)
            with torch.no_grad():
                base_output = self.model(**base_input)
            self._base_logits = base_output.logits[0, -1, self._law_token_index]
        return self._base_logits

    def __call__(
        self, query: Dict[str, Any], topk: int = 10, base_lambda: float = 1.0
    ) -> Dict[str, Any]:
        """Return the top law tokens for ``query["inputs"]`` with article text.

        Args:
            query: Request dict; ``"inputs"`` is popped from it (the dict is
                mutated). If absent, the dict itself is used as the input.
            topk: Maximum number of law tokens to return.
            base_lambda: Weight of the base-prompt logits subtracted from the
                query logits.

        Returns:
            ``{"tokens": [...]}`` with one ``LawLookup.get_law_from_token``
            result per selected token, best first.

        Raises:
            TypeError: If the resolved input is not a string.
        """
        inputs = query.pop("inputs", query)
        if not isinstance(inputs, str):
            raise TypeError("query['inputs'] must be a string")
        # NOTE(review): endswith("") is always True, so this is a no-op as
        # written; the suffix literal may have been lost (e.g. a special
        # token) — confirm the intended suffix.
        if not inputs.endswith(""):
            inputs += ""
        logger.info(inputs)

        if not self.law_token_ids:
            return {"tokens": []}

        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**encoded)
        # Next-token logits restricted to the law tokens.
        outputs_logits = outputs.logits[0, -1, self._law_token_index]

        # Subtract the base-prompt logits (presumably to damp tokens the
        # model favours regardless of the query), then shift so the mean
        # matches the raw query logits again.
        raw_mean = outputs_logits.mean()
        outputs_logits = outputs_logits - base_lambda * self._get_base_logits()
        outputs_logits = outputs_logits + (raw_mean - outputs_logits.mean())

        # Rank on logits directly: same order as softmax, without fp16
        # probabilities underflowing to ties.
        k = max(0, min(topk, outputs_logits.numel()))
        top_ids = torch.topk(outputs_logits, k).indices.tolist()
        logger.info([self.law_token_names[x] for x in top_ids])
        token_objects = [
            self.law_lookup.get_law_from_token(self.law_token_names[x])
            for x in top_ids]
        return {"tokens": token_objects}