EL GHAFRAOUI AYOUB
C
3a5be9b
raw
history blame contribute delete
3.86 kB
import logging
from transformers import pipeline
import asyncio
class TextGenerationHandler:
    """Generate structured multi-section documents with a seq2seq model.

    Wraps a Hugging Face ``text2text-generation`` pipeline (FLAN-T5 small)
    and, for a given prompt, generates one chunk of text per outline
    section, then joins them into a single document.
    """

    # Default outline used when the caller does not supply one.
    DEFAULT_SECTIONS = (
        "1. Executive Summary",
        "2. Project Scope and Objectives",
        "3. Architecture Overview",
        "4. Component Design",
        "5. Security and Compliance",
        "6. Deployment Strategy",
        "7. Team Requirements",
        "8. Cost Estimates",
        "9. Project Timeline",
    )

    def __init__(self):
        # Initialize the text generation pipeline once. Model weights are
        # downloaded/loaded here, so construction can be slow on first run.
        self.pipe = pipeline(
            "text2text-generation",
            model="google/flan-t5-small",
            max_length=2048,  # Increase max length for longer responses
            num_return_sequences=1,
        )
        self.logger = logging.getLogger(__name__)

    async def generate_response(self, prompt: str, sections=None) -> str:
        """Generate a complete response using the T5 model pipeline.

        Args:
            prompt (str): Input text to generate from.
            sections: Optional iterable of section titles; defaults to
                ``DEFAULT_SECTIONS``.

        Returns:
            str: All generated sections joined with blank lines.

        Raises:
            Exception: Re-raises any pipeline failure after logging it.
        """
        try:
            titles = self.DEFAULT_SECTIONS if sections is None else sections
            complete_response = []
            for section in titles:
                section_prompt = f"{prompt}\nGenerate content for: {section}"
                self.logger.info("Generating section: %s", section)
                self.logger.debug("Section prompt: %s", section_prompt)
                # The pipeline call is synchronous and CPU-heavy; run it in
                # a worker thread so the event loop is not blocked.
                output = await asyncio.to_thread(
                    self.pipe,
                    section_prompt,
                    max_length=512,
                    do_sample=True,
                    temperature=0.7,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                )
                section_text = output[0]['generated_text'].strip()
                self.logger.info("Generated text for %s:\n%s\n", section, section_text)
                complete_response.append(f"{section}\n{section_text}")
            final_response = "\n\n".join(complete_response)
            self.logger.info("Complete response:\n%s", final_response)
            return final_response
        except Exception as e:
            self.logger.error("Error generating text: %s", e, exc_info=True)
            raise

    async def stream_response(self, prompt: str):
        """Stream the generated response section by section.

        Args:
            prompt (str): Input text to generate from.

        Yields:
            dict: ``{"type": "content", "content": <text so far>}`` chunks,
            or a single ``{"type": "error", "content": <message>}`` on
            failure.
        """
        try:
            # Generate the complete response first, then replay it in
            # cumulative chunks.
            response = await self.generate_response(prompt)
            accumulated_response = ""
            # NOTE(review): splitting on blank lines assumes the generated
            # section bodies contain no "\n\n" themselves — confirm.
            for section in response.split('\n\n'):
                accumulated_response += section + "\n\n"
                self.logger.debug("Streaming section:\n%s\n", section)
                yield {
                    "type": "content",
                    "content": accumulated_response.strip()
                }
                # Brief pause so consumers observe incremental updates.
                await asyncio.sleep(0.1)
        except Exception as e:
            self.logger.error("Error in stream_response: %s", e, exc_info=True)
            yield {
                "type": "error",
                "content": str(e)
            }
# Configure module-wide logging: mirror every record at INFO and above to
# both the console and a log file.
_log_handlers = [
    logging.StreamHandler(),
    logging.FileHandler('f5_model.log'),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
f5_model = TextGenerationHandler()