import torch
from transformers import pipeline
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_huggingface import HuggingFacePipeline
# Initialize a Hugging Face pipeline for text generation
text_generator = pipeline(
    "text-generation",  # Task type
    model="google/gemma-3n-E2B-it",
    # model="google/gemma-3n-E4B-it",  # larger Gemma 3n variant
    # model="Qwen/Qwen3-Embedding-0.6B",  # embedding model; not valid for text-generation
    # device="cuda" if torch.cuda.is_available() else "cpu",
    device="cpu",
    torch_dtype=torch.bfloat16,
    max_new_tokens=500,  # Limit output length
    return_full_text=False,  # Return only the completion, not the echoed prompt
)
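
# Optional sanity check of the raw pipeline (hypothetical prompt, shown for
# illustration only; left commented out so importing this module does not
# trigger a generation run):
# print(text_generator("Say hello in one sentence.")[0]["generated_text"])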
# Wrap pipeline for LangChain compatibility
model = HuggingFacePipeline(pipeline=text_generator)

def generate_sentences(topic, n=1):
    """Generate up to n short, family-friendly sentences about the given topic."""
    prompt = ChatPromptTemplate.from_template(
        "You are a helpful assistant. Generate exactly {n} simple sentences about the topic: {topic}. "
        "Each sentence must be in English and appropriate for all audiences. "
        "Return each sentence on a new line without any numbering or bullets."
    )
    chain = prompt | model | StrOutputParser()
    response = chain.invoke({"topic": topic, "n": n})
    # Filter out blank lines, echoed instruction headers, and overly long lines
    return [
        line.strip()
        for line in response.splitlines()
        if (
            line.strip()
            and not line.strip().startswith(("###", "Instruction", "Output Format"))
            and len(line.split()) <= 15  # Enforce a 15-word limit per sentence
        )
    ][:n]
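
# Minimal usage sketch (the topic "space exploration" and n=3 are arbitrary
# illustrative values, not part of the original module); the __main__ guard
# keeps the model from loading when this file is imported elsewhere.
if __name__ == "__main__":
    for sentence in generate_sentences("space exploration", n=3):
        print(sentence)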