mynewmodel / chat_cli.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
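# Dependencies (assumption: recent releases work): torch and transformers,
# e.g. `pip install torch transformers`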
def chat_with_model(model_path: str):
    # Use the first GPU when CUDA is available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the model and tokenizer, and put the model in inference mode
    model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.eval()

    # Wrap the model with DataParallel to use multiple GPUs
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = torch.nn.DataParallel(model)

    # DataParallel does not forward generate(), so keep a handle to the underlying model
    generator = model.module if isinstance(model, torch.nn.DataParallel) else model
print("You're now chatting with the model. Type 'quit' to exit.")
while True:
# Get user input
input_text = input("You: ")
if input_text.lower() == 'quit':
break
# Encode the input text
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
# Generate a response
with torch.no_grad():
generated_text_samples = model.generate(input_ids, max_length=50, pad_token_id=tokenizer.eos_token_id)
# Decode and print the model's response
response_text = tokenizer.decode(generated_text_samples[0], skip_special_tokens=True)
print("AI:", response_text)
if __name__ == "__main__":
    model_path = '/home/energyxadmin/UI2/merge'
    chat_with_model(model_path)
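
# Example usage (a sketch, assuming this file is saved as chat_cli.py):
#
#   python chat_cli.py
#
# The hard-coded model_path above is machine-specific; from_pretrained also
# accepts a Hugging Face Hub model ID, so the path can be swapped for any
# causal LM checkpoint you have access to.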