Spaces:
Runtime error
Runtime error
| import langid | |
| import os | |
| from haystack import Pipeline | |
| from haystack.nodes import TextConverter, PreProcessor, BM25Retriever, FARMReader | |
| from haystack.document_stores import InMemoryDocumentStore | |
| from haystack.utils import print_answers | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
class Sejarah:
    """Bilingual (Malay/English) question answering over local history documents.

    Builds an indexing pipeline (TextConverter -> PreProcessor -> in-memory
    BM25 document store) over the text files in ``documents/``, a querying
    pipeline (BM25Retriever -> Malay FARMReader), and two Helsinki-NLP MarianMT
    model pairs used to translate English questions into Malay and Malay
    answers back into English.
    """

    def __init__(self):
        # BM25 must be enabled on the store for the sparse retriever below.
        document_store = InMemoryDocumentStore(use_bm25=True)

        # Indexing pipeline: raw text file -> cleaned, overlapping word chunks -> store.
        indexing_pipeline = Pipeline()
        text_converter = TextConverter()
        preprocessor = PreProcessor(
            clean_whitespace=True,
            clean_header_footer=True,
            clean_empty_lines=True,
            split_by="word",
            split_length=200,
            split_overlap=20,
            split_respect_sentence_boundary=True,
        )
        indexing_pipeline.add_node(component=text_converter, name="TextConverter", inputs=["File"])
        indexing_pipeline.add_node(component=preprocessor, name="PreProcessor", inputs=["TextConverter"])
        indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["PreProcessor"])

        # Index every regular file in the documents directory.
        # `doc_dir` renamed from `dir` (shadowed the builtin); the isfile()
        # guard skips subdirectories that would crash TextConverter.
        doc_dir = "documents"
        files_to_index = [
            os.path.join(doc_dir, f)
            for f in os.listdir(doc_dir)
            if os.path.isfile(os.path.join(doc_dir, f))
        ]
        indexing_pipeline.run_batch(file_paths=files_to_index)

        # Query pipeline: BM25 retrieval feeding a Malay extractive QA reader.
        retriever = BM25Retriever(document_store=document_store)
        reader = FARMReader(model_name_or_path="primasr/malaybert-for-eqa-finetuned", use_gpu=True)
        self.querying_pipeline = Pipeline()
        self.querying_pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
        self.querying_pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])

        # Malay/Indonesian -> English translation pair.
        self.id_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-id-en")
        self.id_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-id-en")
        # English -> Malay/Indonesian translation pair.
        self.en_id_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-id")
        self.en_id_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-id")

    def _translate(self, content, tokenizer, model):
        """Translate ``content`` with a MarianMT tokenizer/model pair.

        Fix: the original called ``tokenizer.prepare_seq2seq_batch``, which is
        deprecated and removed in transformers v5; calling the tokenizer
        directly is the documented equivalent and returns the same batch.
        """
        batch = tokenizer([content], return_tensors="pt")
        generated = model.generate(**batch)
        return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    def language_converter(self, content, lang, method):
        """Translate ``content`` only when the user's language is English.

        The QA pipeline operates in Malay, so English questions are translated
        en->id before querying and Malay answers id->en afterwards. Content in
        any other detected language (assumed Malay) is returned unchanged.
        """
        if lang == "en":
            if method == "question":
                content = self._translate(content, self.en_id_tokenizer, self.en_id_model)
            else:
                content = self._translate(content, self.id_en_tokenizer, self.id_en_model)
        return content

    def detect_language(self, content):
        """Return the ISO 639-1 language code langid assigns to ``content``."""
        # langid.classify returns (lang_code, score); only the code is needed.
        return langid.classify(content)[0]

    def interface(self, question):
        """Answer ``question`` and return ``(answer, context)`` in its language.

        Fix: the original indexed ``result['answers'][0]`` unconditionally,
        raising IndexError whenever the reader produced no answers; now returns
        a pair of empty strings in that case.
        """
        language = self.detect_language(question)
        converted_question = self.language_converter(question, language, "question")
        result = self.querying_pipeline.run(
            query=converted_question,
            params={
                "Retriever": {"top_k": 10},
                "Reader": {"top_k": 5},
            },
        )
        answers = result.get("answers") or []
        if not answers:
            return "", ""
        best = answers[0]
        answer = self.language_converter(best.answer, language, "answer")
        context = self.language_converter(best.context, language, "answer")
        return answer, context