File size: 1,954 Bytes
8d1b471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from QBModelConfig import QBModelConfig
from QBModelWrapper import QBModelWrapper
from transformers import AutoConfig, AutoModel, AutoModelForQuestionAnswering
import torch 
import numpy as np
from transformers import QuestionAnsweringPipeline
from transformers import PretrainedConfig
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering
from transformers import pipeline
from QAPipeline import QApipeline

# Make the "TFIDF-QA" model_type resolvable through AutoConfig.
# NOTE(review): assumes QBModelConfig.model_type == "TFIDF-QA" (transformers
# rejects a mismatch here) — confirm in QBModelConfig.
AutoConfig.register("TFIDF-QA", QBModelConfig)

# Let both the generic and the QA-specific auto factories build a
# QBModelWrapper whenever they are given a QBModelConfig.
for _auto_factory in (AutoModel, AutoModelForQuestionAnswering):
    _auto_factory.register(QBModelConfig, QBModelWrapper)

# Tag the custom config/model classes with the auto classes they belong to,
# so that saving/pushing them can record how to load the custom code
# (see transformers `register_for_auto_class`).
QBModelConfig.register_for_auto_class()
for _auto_factory_name in ("AutoModel", "AutoModelForQuestionAnswering"):
    QBModelWrapper.register_for_auto_class(_auto_factory_name)

# Drop the loop variables so the module namespace matches the unrolled form.
del _auto_factory, _auto_factory_name



#qbmodel_config.save_pretrained("model-config")
#qbmodel.save_pretrained(save_directory='TriviaAnsweringMachine8', safe_serialization= False, push_to_hub=True)
#print(qbmodel.config.torch_dtype.split(".")[1])
import os

from huggingface_hub import Repository

# Local git clone of the model repository to publish. The default is the
# original author's WSL checkout; set TRIVIA_REPO_DIR to point at a clone on
# another machine without editing this script.
_repo_dir = os.environ.get(
    "TRIVIA_REPO_DIR",
    "/mnt/c/Users/backe/Documents/GitHub/TriviaAnsweringMachine/",
)
repo = Repository(_repo_dir)

# `Repository.push_to_hub`'s first positional parameter is `commit_message`:
# the push goes to the clone's configured git remote, NOT to a repo named by
# this string. Passing it by keyword keeps the original behaviour while making
# it impossible to misread as the destination repo id.
# NOTE(review): `Repository` is deprecated in huggingface_hub in favour of
# `HfApi.upload_folder` / `HfApi.create_commit`; consider migrating.
repo.push_to_hub(commit_message="TriviaAnsweringMachine10")

#qbmodel_config = QBModelConfig()
#qbmodel = QBModelWrapper(qbmodel_config)

#qbmodel.push_to_hub("TriviaAnsweringMachine10", safe_serialization=False)

#model = AutoModelForQuestionAnswering.from_pretrained("backedman/TriviaAnsweringMachine6", config=QBModelConfig(), trust_remote_code = True)
#tokenizer = AutoTokenizer.from_pretrained(model_name)


# Register the custom QApipeline under the task name "demo-qa" so that
# `pipeline("demo-qa", ...)` below can construct it.
# NOTE(review): QBModelWrapper is only registered with the PyTorch auto class
# above, so the TF entry is presumably never used for this model — confirm.
_demo_qa_task = "demo-qa"
PIPELINE_REGISTRY.register_pipeline(
    _demo_qa_task,
    pipeline_class=QApipeline,
    pt_model=AutoModelForQuestionAnswering,
    tf_model=TFAutoModelForQuestionAnswering,
)

# Model and tokenizer are both loaded from the same Hub repository.
_source_repo = "backedman/TriviaAnsweringMachine10"
qa_pipe = pipeline(_demo_qa_task, model=_source_repo, tokenizer=_source_repo)

# Publish the assembled pipeline; safe_serialization=False saves weights in
# the non-safetensors format.
qa_pipe.push_to_hub("TriviaAnsweringMachineREAL", safe_serialization=False)

# Manual smoke test, disabled:
#qa_pipe("test")