import shutil
import os

# Swap in pysqlite3 before chromadb is imported: Chroma needs a newer sqlite3
# than many base images ship.
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

from sentence_transformers import SentenceTransformer
import chromadb
from datasets import load_dataset
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Set environment variables to address warnings
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
torch.random.manual_seed(0)

# Use the standard checkpoint: the "-gguf" repo only hosts GGUF files, which
# AutoModelForCausalLM cannot load directly.
model_name = "microsoft/Phi-3-mini-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Helper to clear a stale sentence-transformers cache (not called automatically)
def clear_cache(model_name):
    cache_dir = os.path.expanduser(
        f'~/.cache/torch/sentence_transformers/{model_name.replace("/", "_")}'
    )
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
        print(f"Cleared cache directory: {cache_dir}")
    else:
        print(f"No cache directory found for: {cache_dir}")


# Embedding vector store backed by ChromaDB
class VectorStore:
    def __init__(self, collection_name):
        try:
            self.embedding_model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
        self.chroma_client = chromadb.Client()
        self.collection = self.chroma_client.create_collection(name=collection_name)

    def populate_vectors(self, batch_size=20):
        dataset = load_dataset('Thefoodprocessor/recipe_new_with_features_full', split='train')
        dataset = dataset.select(range(1500))
        texts = []
        for i, example in enumerate(dataset):
            title = example['title_cleaned']
            recipe = example['recipe_new']
            meal_type = example['meal_type']
            allergy = example['allergy_type']
            ingredients_alternative = example['ingredients_alternatives']
            texts.append(f"{title} {recipe} {meal_type} {allergy} {ingredients_alternative}")
            if len(texts) == batch_size:
                # The batch started batch_size - 1 positions before the current index.
                self._process_batch(texts, i - batch_size + 1)
                texts = []
        if texts:
            self._process_batch(texts, len(dataset) - len(texts))

    def _process_batch(self, texts, batch_start_idx):
        # Add the whole batch in one call instead of one document at a time.
        embeddings = self.embedding_model.encode(texts).tolist()
        self.collection.add(
            embeddings=embeddings,
            documents=texts,
            ids=[str(batch_start_idx + j) for j in range(len(texts))],
        )

    def search_context(self, query, n_results=1):
        # encode() takes a list of inputs; query() expects a list of embeddings.
        query_embeddings = self.embedding_model.encode([query]).tolist()
        return self.collection.query(query_embeddings=query_embeddings, n_results=n_results)


vector_store = VectorStore("embedding_vector")
vector_store.populate_vectors()


def fine_tune_model():
    dataset = load_dataset('Thefoodprocessor/recipe_new_with_features_full', split='train')
    dataset = dataset.select(range(1500))

    def tokenize_function(examples):
        return tokenizer(
            [" ".join([title, recipe]) for title, recipe in zip(examples['title_cleaned'], examples['recipe_new'])],
            truncation=True,
            max_length=512,  # cap sequence length; the collator pads per batch
        )

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        batch_size=8,
        remove_columns=dataset.column_names,  # drop raw text columns the Trainer cannot collate
    )

    # No eval split is prepared, so evaluation stays disabled (the default).
    training_args = TrainingArguments(
        output_dir="./results",
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        num_train_epochs=3,
        weight_decay=0.01,
    )

    # mlm=False makes the collator copy input_ids to labels for causal LM training.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets,
        data_collator=data_collator,
    )
    trainer.train()


fine_tune_model()

conversation_history = []


def chatbot_response(user_input):
    global conversation_history
    results = vector_store.search_context(user_input, n_results=1)
    # Chroma returns one list of documents per query, so index twice for the text.
    documents = results['documents'][0] if results['documents'] else []
    context = documents[0] if documents else ""
    conversation_history.append(f"User: {user_input}\nContext: {context[:150]}\nBot:")
    inputs = tokenizer("\n".join(conversation_history), return_tensors="pt")
    # max_new_tokens bounds only the generated text; max_length also counts the
    # prompt, which can already exceed 150 tokens.
    outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens; generate() echoes the prompt.
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    conversation_history.append(response)
    return response


def chat(user_input):
    return chatbot_response(user_input)


css = ".gradio-container {background: url(https://upload.wikimedia.org/wikipedia/commons/f/f5/Spring_Kitchen_Line-Up_%28Unsplash%29.jpg)}"
iface = gr.Interface(fn=chat, inputs="text", outputs="text", css=css)
iface.launch()
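
# Quick retrieval sanity check (a minimal sketch; the query string below is a
# hypothetical example). Run before iface.launch() to verify the vector store
# returns something sensible:
#
#   results = vector_store.search_context("gluten-free breakfast ideas")
#   print(results['documents'][0][0])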