# CHEMISTral7Bv0.3 / mistral_chat_script.py
# Author: Clemspace — added inference + api wrapper (commit 32fe622, 1.3 kB)
import sys
from pathlib import Path
from typing import Optional

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.model import Transformer
def run_chat(
    model_path: str,
    prompt: str,
    max_tokens: int = 256,
    temperature: float = 1.0,
    instruct: bool = True,
    lora_path: Optional[str] = None,
):
    """Generate a completion for *prompt* with a local Mistral model and print it.

    Args:
        model_path: Directory holding the model weights and the
            ``tokenizer.model.v3`` tokenizer file.
        prompt: Text to complete.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to ``generate``.
        instruct: NOTE(review): accepted for CLI compatibility but currently
            unused — the prompt is always encoded as raw text, never wrapped
            in the instruct chat template. TODO: honour this flag.
        lora_path: Optional path to a LoRA adapter loaded on top of the
            base weights.

    Raises:
        FileNotFoundError: If ``tokenizer.model.v3`` is missing from
            ``model_path``.
    """
    model_dir = Path(model_path)

    # Fail fast before loading any weights: the v3 tokenizer file must sit
    # next to the model checkpoint.
    tokenizer_file = model_dir / "tokenizer.model.v3"
    if not tokenizer_file.is_file():
        raise FileNotFoundError(f"Tokenizer model file not found at {tokenizer_file}")

    mistral_tokenizer = MistralTokenizer.from_file(str(tokenizer_file))
    tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer

    transformer = Transformer.from_folder(
        model_dir, max_batch_size=3, num_pipeline_ranks=1
    )
    # Optionally apply a LoRA adapter on top of the loaded base weights.
    if lora_path is not None:
        transformer.load_lora(Path(lora_path))

    # Encode with BOS only — EOS is what generation is expected to produce.
    tokens = tokenizer.encode(prompt, bos=True, eos=False)
    generated_tokens, _ = generate(
        [tokens],
        transformer,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_id,
    )
    answer = tokenizer.decode(generated_tokens[0])
    print(answer)
if __name__ == "__main__":
import fire
fire.Fire(run_chat)