bertugmirasyedi committed
Commit
3316ef5
1 Parent(s): 1ec7e79

Changed summarization model and added onnxruntime options

Files changed (3)
  1. .DS_Store +0 -0
  2. __pycache__/app.cpython-310.pyc +0 -0
  3. app.py +45 -45
.DS_Store ADDED
Binary file (6.15 kB)
 
__pycache__/app.cpython-310.pyc ADDED
Binary file (10.5 kB)
 
app.py CHANGED
@@ -21,7 +21,7 @@ def search(
     classification: bool = True,
     summarization: bool = True,
     similarity: bool = False,
-    add_chatgpt_results: bool = True,
+    add_chatgpt_results: bool = False,
     n_results: int = 10,
 ):
     import time
@@ -316,7 +316,7 @@ def search(
 
         return similar_books
 
-    def summarize(descriptions):
+    def summarize(descriptions, runtime="normal"):
         """
         Summarize the descriptions and return the results.
         """
@@ -325,10 +325,17 @@ def search(
             AutoModelForSeq2SeqLM,
             pipeline,
         )
+        from optimum.onnxruntime import ORTModelForSeq2SeqLM
+        from optimum.bettertransformer import BetterTransformer
 
         # Define the summarizer model and tokenizer
-        tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
-        model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6")
+        if runtime == "normal":
+            tokenizer = AutoTokenizer.from_pretrained("lidiya/bart-base-samsum")
+            model = AutoModelForSeq2SeqLM.from_pretrained("lidiya/bart-base-samsum")
+            model = BetterTransformer.transform(model)
+        elif runtime == "onnxruntime":
+            tokenizer = AutoTokenizer.from_pretrained("optimum/t5-small")
+            model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
 
         # Create the summarizer pipeline
         summarizer_pipe = pipeline(
@@ -349,7 +356,7 @@ def search(
 
         return summaries
 
-    def classify(combined_data, parallel=False):
+    def classify(combined_data, runtime="normal"):
         """
         Create classifier pipeline and return the results.
         """
@@ -358,15 +365,25 @@ def search(
             AutoModelForSequenceClassification,
             pipeline,
         )
+        from optimum.onnxruntime import ORTModelForSequenceClassification
+        from optimum.bettertransformer import BetterTransformer
 
-        # Define the zero-shot classifier
-        tokenizer = AutoTokenizer.from_pretrained(
-            "sileod/deberta-v3-base-tasksource-nli"
-        )
-
-        model = AutoModelForSequenceClassification.from_pretrained(
-            "sileod/deberta-v3-base-tasksource-nli"
-        )
+        if runtime == "normal":
+            # Define the zero-shot classifier
+            tokenizer = AutoTokenizer.from_pretrained(
+                "sileod/deberta-v3-base-tasksource-nli"
+            )
+            model = AutoModelForSequenceClassification.from_pretrained(
+                "sileod/deberta-v3-base-tasksource-nli"
+            )
+        elif runtime == "onnxruntime":
+            tokenizer = AutoTokenizer.from_pretrained(
+                "optimum/distilbert-base-uncased-mnli"
+            )
+            model = ORTModelForSequenceClassification.from_pretrained(
+                "optimum/distilbert-base-uncased-mnli"
+            )
+
         classifier_pipe = pipeline(
             "zero-shot-classification",
             model=model,
@@ -374,49 +391,30 @@ def search(
             hypothesis_template="This book is {}.",
             batch_size=1,
             device=-1,
-            multi_label=True,
+            multi_label=False,
         )
 
         # Define the candidate labels
-        candidate_labels = [
+        level = [
             "Introductory",
             "Advanced",
-            "Academic",
-            "Not Academic",
-            "Manual",
         ]
 
-        if parallel:
-            import ray
-            import psutil
-
-            # Define the number of cores to use
-            num_cores = psutil.cpu_count(logical=True)
-
-            # Initialize Ray
-            ray.init(num_cpus=num_cores, ignore_reinit_error=True)
-            classifier_id = ray.put(classifier_pipe)
-
-            # Define the function to be parallelized
-            @ray.remote
-            def classify_parallel(classifier_id, doc, candidate_labels):
-                classifier = ray.get(classifier_id)
-                return classifier(doc, candidate_labels)
-
-            # Get the predicted labels
-            classes = [
-                classify_parallel.remote(classifier_id, doc, candidate_labels)
-                for doc in combined_data
-            ]
-        else:
-            # Get the predicted labels
-            classes = [classifier_pipe(doc, candidate_labels) for doc in combined_data]
+        audience = ["Academic", "Not Academic", "Manual"]
+
+        classes = [
+            {
+                "audience": classifier_pipe(doc, audience),
+                "level": classifier_pipe(doc, level),
+            }
+            for doc in combined_data
+        ]
 
         return classes
 
     # If true then run the similarity, summarize, and classify functions
     if classification:
-        classes = classify(combined_data, parallel=False)
+        classes = classify(combined_data, runtime="normal")
    else:
        classes = [
            {"labels": ["No labels available."], "scores": [0]}
@@ -428,7 +426,7 @@ def search(
    classification_time = int(fourth_checkpoint - third_checkpoint)
 
    if summarization:
-        summaries = summarize(descriptions)
+        summaries = summarize(descriptions, runtime="normal")
    else:
        summaries = [
            [{"summary_text": description}]
@@ -467,8 +465,10 @@ def search(
            "author": authors[i],
            "publisher": publishers[i],
            "image_link": images[i],
-            "labels": classes[i]["labels"][0:2],
-            "label_confidences": classes[i]["scores"][0:2],
+            "audience": classes[i]["audience"]["labels"][0],
+            "audience_confidence": classes[i]["audience"]["scores"][0],
+            "level": classes[i]["level"]["labels"][0],
+            "level_confidence": classes[i]["level"]["scores"][0],
            "summary": summaries[i][0]["summary_text"],
            "similar_books": similar_books[i]["sorted_by_similarity"],
            "runtime": {