Update README.md
Browse files
README.md
CHANGED
@@ -51,32 +51,26 @@ import pandas as pd, numpy as np, warnings, torch, re
|
|
51 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
52 |
from bs4 import BeautifulSoup
|
53 |
warnings.filterwarnings("ignore", category=UserWarning, module='bs4')
|
54 |
-
|
55 |
# Helper Functions
|
56 |
def clean_and_parse_tweet(tweet):
    """Normalize a raw tweet for classification.

    Steps: replace URLs with a " URL " token, strip HTML via BeautifulSoup,
    collapse literal "\\n" sequences and real newlines into spaces, trim a
    leading run of '.'/':' characters, and squeeze repeated spaces.

    Returns the cleaned string, or ``None`` when the soup's string form
    contains "filename" (presumably guarding against bs4's "looks like a
    filename" case — TODO confirm) or when the extracted text is empty.
    """
    tweet = re.sub(r"https?://\S+|www\.\S+", " URL ", tweet)
    # Parse once and reuse the soup (the original built the soup twice).
    soup = BeautifulSoup(tweet, "html.parser")
    if "filename" in str(soup):
        return None
    parsed = soup.get_text()
    if not parsed:
        # Original returned None for empty extracted text as well.
        return None
    text = re.sub(r"\\n+|\n+", " ", parsed)   # escaped and real newlines -> space
    text = re.sub(r'^[.:]+', '', text).strip()  # drop leading dots/colons
    return re.sub(r" +", " ", text)             # squeeze repeated spaces
|
60 |
-
|
61 |
def predict_tweet(tweet, model, tokenizer, device, threshold=0.5):
    """Score a cleaned tweet with the multi-label classifier.

    Tokenizes the text, runs the model, applies an element-wise sigmoid,
    and returns ``(probs, labels)`` where ``labels`` are the marketing-mix
    categories (Product/Place/Price/Promotion) whose probability is at
    least ``threshold``. Uses the module-level ``id2label`` mapping.
    """
    encoded = tokenizer(
        tweet,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(device)
    logits = model(**encoded).logits
    probs = torch.sigmoid(logits).detach().cpu().numpy()[0]

    marketing_mix = {'Product', 'Place', 'Price', 'Promotion'}
    selected = []
    for idx, prob in enumerate(probs):
        name = id2label[idx]
        if name in marketing_mix and prob >= threshold:
            selected.append(name)
    return probs, selected
|
65 |
-
|
66 |
# Setup: choose the best available device (Apple MPS > CUDA > CPU) and load
# the fine-tuned multi-label classifier + tokenizer from the Hugging Face Hub.
device = "mps" if torch.backends.mps.is_built() and torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
synxp = "dmr76/mmx_classifier_microblog_ENv02"
model = AutoModelForSequenceClassification.from_pretrained(synxp).to(device)
tokenizer = AutoTokenizer.from_pretrained(synxp)
id2label = model.config.id2label  # index -> label-name mapping from the model config

# ---->>> Define your Tweet <<<----
# NOTE(review): the emoji below look mojibake'd (UTF-8 bytes decoded as the
# wrong codepage) — confirm against the original README before relying on them.
tweet = "Best cushioning ever!!! ๐ค๐ค๐ค my zoom vomeros are the bomb๐๐ฝโโ๏ธ๐จ!!! \n @nike #run #training https://randomurl.ai"

# Clean and Predict
cleaned_tweet = clean_and_parse_tweet(tweet)
probs, labels = predict_tweet(cleaned_tweet, model, tokenizer, device)

# Print Labels and Probabilities
# Fixed typo in the citation request: "in you use" -> "if you use".
print("Please don't forget to cite the paper: https://ssrn.com/abstract=4542949 if you use this code")
print(labels, probs)
|
|
|
51 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
52 |
from bs4 import BeautifulSoup
|
53 |
warnings.filterwarnings("ignore", category=UserWarning, module='bs4')
|
|
|
54 |
# Helper Functions
|
55 |
def clean_and_parse_tweet(tweet):
    """Normalize a raw tweet for classification.

    Steps: replace URLs with a " URL " token, strip HTML via BeautifulSoup,
    collapse literal "\\n" sequences and real newlines into spaces, trim a
    leading run of '.'/':' characters, and squeeze repeated spaces.

    Returns the cleaned string, or ``None`` when the soup's string form
    contains "filename" (presumably guarding against bs4's "looks like a
    filename" case — TODO confirm) or when the extracted text is empty.
    """
    tweet = re.sub(r"https?://\S+|www\.\S+", " URL ", tweet)
    # Parse once and reuse the soup (the original built the soup twice).
    soup = BeautifulSoup(tweet, "html.parser")
    if "filename" in str(soup):
        return None
    parsed = soup.get_text()
    if not parsed:
        # Original returned None for empty extracted text as well.
        return None
    text = re.sub(r"\\n+|\n+", " ", parsed)   # escaped and real newlines -> space
    text = re.sub(r'^[.:]+', '', text).strip()  # drop leading dots/colons
    return re.sub(r" +", " ", text)             # squeeze repeated spaces
|
|
|
59 |
def predict_tweet(tweet, model, tokenizer, device, threshold=0.5):
    """Score a cleaned tweet with the multi-label classifier.

    Tokenizes the text, runs the model, applies an element-wise sigmoid,
    and returns ``(probs, labels)`` where ``labels`` are the marketing-mix
    categories (Product/Place/Price/Promotion) whose probability is at
    least ``threshold``. Uses the module-level ``id2label`` mapping.
    """
    encoded = tokenizer(
        tweet,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(device)
    logits = model(**encoded).logits
    probs = torch.sigmoid(logits).detach().cpu().numpy()[0]

    marketing_mix = {'Product', 'Place', 'Price', 'Promotion'}
    selected = []
    for idx, prob in enumerate(probs):
        name = id2label[idx]
        if name in marketing_mix and prob >= threshold:
            selected.append(name)
    return probs, selected
|
|
|
63 |
# Setup: choose the best available device (Apple MPS > CUDA > CPU) and load
# the fine-tuned multi-label classifier + tokenizer from the Hugging Face Hub.
device = "mps" if torch.backends.mps.is_built() and torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
synxp = "dmr76/mmx_classifier_microblog_ENv02"
model = AutoModelForSequenceClassification.from_pretrained(synxp).to(device)
tokenizer = AutoTokenizer.from_pretrained(synxp)
id2label = model.config.id2label  # index -> label-name mapping from the model config

# ---->>> Define your Tweet <<<----
# NOTE(review): the emoji below look mojibake'd (UTF-8 bytes decoded as the
# wrong codepage) — confirm against the original README before relying on them.
tweet = "Best cushioning ever!!! ๐ค๐ค๐ค my zoom vomeros are the bomb๐๐ฝโโ๏ธ๐จ!!! \n @nike #run #training https://randomurl.ai"

# Clean and Predict
cleaned_tweet = clean_and_parse_tweet(tweet)
probs, labels = predict_tweet(cleaned_tweet, model, tokenizer, device)

# Print Labels and Probabilities
# Fixed typo in the citation request: "in you use" -> "if you use".
print("Please don't forget to cite the paper: https://ssrn.com/abstract=4542949 if you use this code")
print(labels, probs)
|