from typing import Optional, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS


class LocalLLM:
    """Thin wrapper around a locally loaded causal LM served through a transformers pipeline."""

    def __init__(self):
        self.pipe = None
        self.model_id = None
        self._load_any()

    def _load_any(self):
        # Try each candidate model id in order and keep the first one that loads successfully.
        for mid in OPEN_LLM_CANDIDATES:
            try:
                tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
                mdl = AutoModelForCausalLM.from_pretrained(
                    mid,
                    device_map="auto",
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    trust_remote_code=True,
                )
                self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
                self.model_id = mid
                return
            except Exception:
                # Loading failed (missing weights, out of memory, gated repo, ...); try the next candidate.
                continue
        self.pipe = None

    def chat(self, prompt: str) -> Optional[str]:
        if not self.pipe:
            return None
        try:
            out = self.pipe(
                prompt,
                max_new_tokens=LOCAL_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.3,
                top_p=0.9,
                repetition_penalty=1.12,
                eos_token_id=self.pipe.tokenizer.eos_token_id,
            )
            text = out[0]["generated_text"]
            # Return only the continuation if the prompt is echoed back in the generated text
            return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()
        except Exception:
            return None