import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM


def setup_llm():
    """Set up a more capable LLM for CSV analysis."""
    try:
        # Try to load FLAN-T5-small, which is better for instruction following
        # while still being relatively small (~300MB)
        model_name = "google/flan-t5-small"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=512
        )

        # Create a wrapper class that matches the expected interface
        class FlanT5LLM:
            def complete(self, prompt):
                class Response:
                    def __init__(self, text):
                        self.text = text

                try:
                    # For FLAN-T5, we don't need to strip the prompt from the output
                    result = generator(prompt, max_length=150, do_sample=False)[0]
                    response_text = result["generated_text"].strip()
                    if not response_text:
                        response_text = "I couldn't generate a proper response."
                    return Response(response_text)
                except Exception as e:
                    print(f"Error generating response: {e}")
                    return Response(f"Error generating response: {str(e)}")

        return FlanT5LLM()

    except Exception as e:
        print(f"Error setting up FLAN-T5 model: {e}")

        # Fallback to a simpler model if FLAN-T5 fails
        try:
            # Try T5-small as a fallback
            model_name = "t5-small"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            generator = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=512
            )

            class T5LLM:
                def complete(self, prompt):
                    class Response:
                        def __init__(self, text):
                            self.text = text

                    try:
                        result = generator(prompt, max_length=150, do_sample=False)[0]
                        return Response(result["generated_text"].strip())
                    except Exception as e:
                        return Response(f"Error: {str(e)}")

            return T5LLM()

        except Exception as e2:
            print(f"Error setting up fallback model: {e2}")

            # Last resort - dummy LLM
            class DummyLLM:
                def complete(self, prompt):
                    class Response:
                        def __init__(self, text):
                            self.text = text

                    return Response("Model initialization failed. Please check logs.")

            return DummyLLM()
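

# Usage sketch (illustrative addition, not part of the original setup code):
# one way the wrapper returned by setup_llm() might be exercised. The example
# prompt text and the __main__ guard are assumptions for demonstration only.
if __name__ == "__main__":
    llm = setup_llm()
    # Whichever model loaded, the wrapper exposes a single complete() method
    # that returns an object with a .text attribute.
    example_prompt = (
        "The CSV has columns: name, age, city, salary. "
        "Which of these columns most likely contains numeric data?"
    )
    response = llm.complete(example_prompt)
    print(response.text)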