# langflow-owt / inference.py
# uploaded by chumengl ("upload model"), commit d443994, 1.97 kB
# (provenance header from the HuggingFace Hub file page)
"""Simple inference script to test the HuggingFace LangFlow model."""
import argparse
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
def main():
    """Generate text samples from a LangFlow checkpoint and print them.

    Command-line flags select the model directory, the number of samples,
    the number of denoising steps, the sequence length, and the RNG seed.
    Prints each generated sample (truncated to 500 characters) to stdout.
    """
    parser = argparse.ArgumentParser(description="Generate samples with LangFlow")
    parser.add_argument(
        "--model_path", type=str, default="hf_release/model_weights",
        help="Path to the HuggingFace model directory")
    parser.add_argument(
        "--num_samples", type=int, default=5,
        help="Number of samples to generate")
    parser.add_argument(
        "--num_steps", type=int, default=128,
        help="Number of denoising steps")
    parser.add_argument(
        "--seq_length", type=int, default=1024,
        help="Sequence length")
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed")
    args = parser.parse_args()

    # Seed the CPU RNG and, when CUDA is present, every CUDA device,
    # so repeated runs with the same flags reproduce the same samples.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # The tokenizer is the stock GPT-2 one from the hub; trust_remote_code
    # lets transformers load the custom model class shipped alongside the
    # checkpoint (presumably where generate_samples is defined — not
    # visible in this file).
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForMaskedLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
    )
    model = model.to(device)
    model.eval()

    print(f"\nGenerating {args.num_samples} samples with {args.num_steps} steps...")
    # Inference only — disable autograd to avoid building a graph.
    with torch.no_grad():
        samples = model.generate_samples(
            num_samples=args.num_samples,
            seq_length=args.seq_length,
            num_steps=args.num_steps,
            device=device,
        )

    texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
    for i, text in enumerate(texts):
        print(f"\n--- Sample {i+1} ---")
        # Print first 500 characters to keep output manageable
        print(text[:500] + ("..." if len(text) > 500 else ""))


if __name__ == "__main__":
    main()