import torch
from transformers import pipeline
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_huggingface import HuggingFacePipeline
# Initialize an HF pipeline for text generation
text_generator = pipeline(
    "text-generation",                  # task type
    model="google/gemma-3n-E2B-it",
    # model="google/gemma-3n-E4B-it",  # larger Gemma variant
    # device="cuda" if torch.cuda.is_available() else "cpu",
    device="cpu",
    torch_dtype=torch.bfloat16,
    max_new_tokens=500,                 # limit output length
)
# Wrap the pipeline so it can be used as a LangChain LLM
model = HuggingFacePipeline(pipeline=text_generator)
def generate_sentences(topic, n=1):
    """Generate n simple, family-friendly sentences about a topic."""
    prompt = ChatPromptTemplate.from_template(
        "You are a helpful assistant. Generate exactly {n} simple sentences about the topic: {topic}. "
        "Each sentence must be in English and appropriate for all audiences. "
        "Return each sentence on a new line without any numbering or bullets."
    )
    chain = prompt | model | StrOutputParser()
    response = chain.invoke({"topic": topic, "n": n})
    # Filter the raw output: drop blank lines, leaked prompt/instruction
    # headers, and anything longer than 15 words, then cap at n sentences.
    # (Filtering the stripped line also catches headers with leading whitespace.)
    return [
        stripped for line in response.splitlines()
        if (stripped := line.strip())
        and not stripped.startswith(("###", "Instruction", "Output Format"))
        and len(stripped.split()) <= 15
    ][:n]
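
# Usage sketch. Assumptions not in the original source: the topic "the ocean"
# and n=3 are illustrative values only, and the script is run directly on a
# machine with the Gemma weights available (the first call downloads them,
# so expect a cold-start delay on CPU).
if __name__ == "__main__":
    for sentence in generate_sentences("the ocean", n=3):
        print(sentence)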