|
|
|
"""Untitled1 (2).ipynb |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1W44vqcumLa_CtuLGpbS8dEk4WtCFUr-z |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Step 2""" |
|
|
|
# Report this run to Haystack's tutorial telemetry (the argument is the tutorial number).
from haystack.telemetry import tutorial_running

tutorial_running(1)

import logging

# Default to WARNING globally, but surface INFO-level progress messages from Haystack itself.
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)
|
|
|
from haystack.nodes import PreProcessor
from haystack.utils import convert_files_to_docs

# Load every file in the scraped-Wikipedia Drive folder into Haystack Document objects.
all_docs = convert_files_to_docs(dir_path='/content/drive/MyDrive/data/malaysia/')

# Clean the text and split it into ~100-word passages without cutting sentences in half.
preprocessor = PreProcessor(
    clean_empty_lines=True,
    clean_whitespace=True,
    clean_header_footer=False,
    split_by="word",
    split_length=100,
    split_respect_sentence_boundary=True,
)
docs = preprocessor.process(all_docs)

# NOTE(review): `docs` is only printed here — the document store below is populated
# by TextIndexingPipeline instead, so this preprocessing result is otherwise unused.
print(f"n_files_input: {len(all_docs)}\nn_docs_output: {len(docs)}")
|
|
|
from haystack.document_stores import InMemoryDocumentStore

# In-memory document store with BM25 sparse retrieval enabled (no external DB needed).
document_store = InMemoryDocumentStore(use_bm25=True)

import os
from haystack.pipelines.standard_pipelines import TextIndexingPipeline

# Index the scraped .txt files into the document store.
# Fixes vs. original: build paths with os.path.join instead of string concatenation,
# hoist the twice-repeated hard-coded directory into one variable, and skip
# subdirectories — os.listdir also returns folders (e.g. a previously saved
# "MyCustomReader" model directory), which would make the indexing pipeline fail.
data_path = '/content/drive/MyDrive/data/malaysia'
files_to_index = [
    os.path.join(data_path, name)
    for name in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, name))
]
indexing_pipeline = TextIndexingPipeline(document_store)
indexing_pipeline.run_batch(file_paths=files_to_index)
|
|
|
from haystack.nodes import FARMReader
from haystack.utils import fetch_archive_from_http

# Start from a distilled English SQuAD reader and fine-tune it on the Malay QA set.
reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=True)

# (Removed leftover `data_dir = "data/squad20"` from the upstream tutorial — it was
# never used; training reads from the Drive folder passed below.)
reader.train(
    data_dir='/content/drive/MyDrive/data/malaysia',
    train_filename='ms-train-2.0.json',
    use_gpu=True,
    n_epochs=1,
    save_dir="MyCustomReader",
)

# Persist the fine-tuned reader so it can be reloaded later without retraining.
reader.save(directory="/content/drive/MyDrive/data/malaysia/MyCustomReader")
|
|
|
from haystack.nodes import BM25Retriever

# Sparse keyword retriever over the BM25-enabled in-memory store built above.
retriever = BM25Retriever(document_store=document_store)

from haystack.nodes import TransformersReader  # NOTE(review): imported but never used below

# Reload the fine-tuned reader saved by the training step.
new_reader = FARMReader(model_name_or_path="/content/drive/MyDrive/data/malaysia/MyCustomReader", use_gpu=True)
|
|
|
|
|
from haystack.pipelines import ExtractiveQAPipeline

# BM25 narrows the corpus to candidate passages; the fine-tuned reader then
# extracts answer spans from those passages.
pipe = ExtractiveQAPipeline(new_reader, retriever)

# Ask a Malay question: fetch the 10 best passages, keep the 5 best answer spans.
search_params = {"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}
prediction = pipe.run(query="siapakah najib razak", params=search_params)

# Bare expression: displays the answer list when run as a notebook cell.
prediction['answers']
|
|
|
from getpass import getpass

# API key for the model provider used by the PromptNode further below.
model_api_key = getpass("Enter model provider API key:")

import requests

API_URL = "https://api-inference.huggingface.co/models/yewsam1277/question-answering-bahasa-malaysia"

# SECURITY FIX: the original committed a real Hugging Face token ("hf_...") in
# source. A leaked token must be revoked on huggingface.co; prompt for it at
# runtime instead of hard-coding it.
hf_api_token = getpass("Enter Hugging Face Inference API token:")
headers = {"Authorization": f"Bearer {hf_api_token}"}
|
|
|
def query(payload, *, timeout=30):
    """POST `payload` to the Hugging Face Inference API and return the parsed JSON.

    Args:
        payload: JSON-serializable request body, e.g. {"inputs": {"question": ..., "context": ...}}.
        timeout: Seconds to wait for the API response. The original call had no
            timeout and could hang indefinitely on a stalled connection.

    Returns:
        The decoded JSON response (the model's answer, or an error payload from the API).
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
|
|
|
# Smoke-test the hosted model with a toy English QA example.
sample_payload = {
    "inputs": {
        "question": "What's my name?",
        "context": "My name is Clara and I live in Berkeley.",
    },
}
output = query(sample_payload)

print(output)
|
|
|
from haystack.nodes import PromptNode

# Wrap the hosted Malay QA model as a PromptNode so the agent can call it.
hf_model_id = "yewsam1277/question-answering-bahasa-malaysia"
prompt_node = PromptNode(hf_model_id, api_key=model_api_key, max_length=256)

from haystack.agents.memory import ConversationSummaryMemory

# Summarized chat history lets the agent keep context across turns cheaply.
summary_memory = ConversationSummaryMemory(prompt_node)

from haystack.agents.conversational import ConversationalAgent

conversational_agent = ConversationalAgent(prompt_node=prompt_node, memory=summary_memory)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Step 1""" |
|
|
|
import wikipediaapi

# Malay-language Wikipedia client.
# NOTE(review): newer wikipediaapi releases require a `user_agent` argument before
# the language code — confirm the pinned version accepts a bare 'ms'.
wiki = wikipediaapi.Wikipedia('ms')
page = wiki.page('Malaysia')

# Seed the corpus with the main article, then add every page it links to.
pages = {'Malaysia': page}
pages.update(page.links)

len(pages)  # notebook display: corpus size

# Mount Google Drive so the scraped text can be written under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
|
|
|
# Write each page's text to Drive as a lowercase, single-line .txt file.
# Fixes vs. original: the `except Exception: pass` silently swallowed failures
# while the counter and the "Written" message still advanced, so failed pages were
# reported as written. Failures are now logged and skipped, and files are written
# explicitly as UTF-8 (page titles/text are Malay; the default encoding is
# platform-dependent).
done = 0
for key in pages:
    title = pages[key].title
    try:
        with open(f'/content/drive/MyDrive/data/malaysia/{title}.txt', 'w', encoding='utf-8') as f:
            f.write(pages[key].text.lower().replace('\n', ' '))
    except Exception as exc:
        print(f"Skipped: {title} ({exc})")
        continue
    done += 1
    print(f"Written: {title}\t(done {done})", end='\r')

len(pages)  # notebook display: total pages attempted
|
|
|
"""Training""" |
|
|
|
|