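# uptest / glms.py
# lykeven — "first model commit" (54abf22)
# NOTE: the header above was captured from a web file view; the page chrome
# ("raw", "history blame contribute delete", "No virus", "1 kB") is not code.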
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoModelForMultipleChoice, AutoModel
from transformers import TrainingArguments, Trainer
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from datasets import load_dataset
from datasets import load_metric
import torchsnooper
# Hub id of the checkpoint to load. ``model.chat`` below is not part of the
# core transformers API; it comes from the checkpoint's own modelling code,
# which is why ``trust_remote_code=True`` is passed when loading.
model_name = "THUDM/chatglm-6b"


def prepare_for_inference(model, use_cuda):
    """Put *model* on a usable device/dtype and switch it to eval mode.

    Args:
        model: A ``torch.nn.Module`` (here, the loaded chat model).
        use_cuda: If true, cast to float16 and move to the current CUDA
            device; otherwise keep the model on the CPU in float32.

    Returns:
        The prepared module. ``half``/``cuda``/``float``/``eval`` on a
        ``torch.nn.Module`` all return the module itself.
    """
    if use_cuda:
        model = model.half().cuda()
    else:
        # Calling .cuda() without a GPU raises. Many PyTorch CPU kernels
        # also lack float16 support, so the CPU fallback stays in float32.
        # That needs far more RAM (roughly 4 bytes per parameter).
        model = model.float()
    # Make sure training-only behaviour such as dropout is disabled.
    return model.eval()


def main():
    """Load the model and print its replies to a two-turn conversation."""
    use_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    model = prepare_for_inference(model, use_cuda)

    # Turn 1: "你好" ("Hello"), starting from an empty history.
    response, history = model.chat(tokenizer, "你好", history=[])
    print(response)
    # Turn 2: "What should I do if I can't sleep at night?". The history
    # returned by turn 1 is passed back in, presumably so the model sees the
    # earlier exchange.
    response, _ = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
    print(response)


if __name__ == "__main__":
    main()