poyum commited on
Commit
4a3df21
·
1 Parent(s): f709e5e
Files changed (2) hide show
  1. app.py +13 -0
  2. pipeline.py +4 -4
app.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
3
+ from pipeline import DiscoursePipeline # ton code custom
4
+
5
+ model_id = "poyum/test_discut"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
7
+ pipe = DiscoursePipeline(model_id=model_id, tokenizer=tokenizer)
8
+
9
+ def predict(text):
10
+ return pipe(text)
11
+
12
+ demo = gr.Interface(fn=predict, inputs="text", outputs="text")
13
+ demo.launch()
pipeline.py CHANGED
@@ -35,15 +35,15 @@ def write_sentences_to_format(sentences: list[str], filename: str):
35
 
36
 
37
  class DiscoursePipeline(Pipeline):
38
- def __init__(self, model, tokenizer, config:dict, output_folder="./pipe_out",sat_model:str="sat-3l", **kwargs):
39
- auto_model = AutoModelForTokenClassification.from_pretrained(model)
40
  super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
41
- self.config = {"model_checkpoint": model, "sent_spliter":"sat","task":"seg","type":"tok","trace":False,"report_to":"none","sat_model":sat_model,"tok_config":{
42
  "padding":"max_length",
43
  "truncation":True,
44
  "max_length": 512
45
  }}
46
- self.model = model
47
  self.output_folder = output_folder
48
 
49
  def _sanitize_parameters(self, **kwargs):
 
35
 
36
 
37
  class DiscoursePipeline(Pipeline):
38
+ def __init__(self, model_id, tokenizer, output_folder="./pipe_out",sat_model:str="sat-3l", **kwargs):
39
+ auto_model = AutoModelForTokenClassification.from_pretrained(model_id)
40
  super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
41
+ self.config = {"model_checkpoint": model_id, "sent_spliter":"sat","task":"seg","type":"tok","trace":False,"report_to":"none","sat_model":sat_model,"tok_config":{
42
  "padding":"max_length",
43
  "truncation":True,
44
  "max_length": 512
45
  }}
46
+ self.model = model_id
47
  self.output_folder = output_folder
48
 
49
  def _sanitize_parameters(self, **kwargs):