# app.py by louiseclon3 — last commit 3fd0b52 ("Update app.py").
import nltk
nltk.download('punkt_tab')
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import torch
# Hugging Face Hub repository that provides both the tokenizer and the
# sequence-classification model used by ``label``. Loaded once at import time.
_MODEL_REPO = "louiseclon3/labeling-legal"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_REPO)
# Clause names for the model's class indices 1..40, in index order.
# NOTE(review): index 7's name ("...- Answer") looks like a dataset-column
# artifact rather than a distinct clause type — confirm before relying on it.
_CLAUSE_NAMES = (
    'Parties',
    'Agreement Date',
    'Effective Date',
    'Expiration Date',
    'Renewal Term',
    'Notice Period To Terminate Renewal',
    'Notice Period To Terminate Renewal- Answer',
    'Governing Law',
    'Most Favored Nation',
    'Competitive Restriction Exception',
    'Non-Compete',
    'Exclusivity',
    'No-Solicit Of Customers',
    'No-Solicit Of Employees',
    'Non-Disparagement',
    'Termination For Convenience',
    'Rofr/Rofo/Rofn',
    'Change Of Control',
    'Anti-Assignment',
    'Revenue/Profit Sharing',
    'Price Restrictions',
    'Minimum Commitment',
    'Volume Restriction',
    'Ip Ownership Assignment',
    'Joint Ip Ownership',
    'License Grant',
    'Non-Transferable License',
    'Affiliate License-Licensor',
    'Affiliate License-Licensee',
    'Unlimited/All-You-Can-Eat-License',
    'Irrevocable Or Perpetual License',
    'Post-Termination Services',
    'Audit Rights',
    'Uncapped Liability',
    'Cap On Liability',
    'Liquidated Damages',
    'Warranty Duration',
    'Insurance',
    'Covenant Not To Sue',
    'Third Party Beneficiary',
)

# Predicted class index -> clause name. Class 0 maps to 'other'; it is added
# last so the dict keeps its original insertion order (1..40, then 0).
id_to_label = dict(enumerate(_CLAUSE_NAMES, start=1))
id_to_label[0] = 'other'
def label(contract_text):
    """Label each sentence of a contract with its predicted clause type.

    The text is split into sentences with NLTK's ``sent_tokenize``. Each
    sentence is classified independently by ``model``, and the arg-max class
    index is mapped to a name through ``id_to_label``. Sentences whose
    predicted label is ``"other"`` are left out of the result.

    Args:
        contract_text: Raw contract text. ``None`` or an empty string yields
            an empty list.

    Returns:
        A list of ``{"text": sentence, "label": clause_name}`` dicts, in the
        order the sentences appear in the input, excluding ``"other"``.
    """
    # sent_tokenize expects a string; treat a missing input as "no sentences"
    # instead of crashing.
    if not contract_text:
        return []

    labeled = []
    # Pure inference: disabling autograd avoids building a gradient graph
    # (and keeping its activations alive) for every sentence.
    with torch.no_grad():
        for sentence in sent_tokenize(contract_text):
            inputs = tokenizer(
                sentence,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,  # longer sentences are truncated to 512 tokens
            )
            logits = model(**inputs).logits
            predicted_label = id_to_label[torch.argmax(logits, dim=-1).item()]
            # Filter in the same pass instead of building and re-scanning a
            # second list.
            if predicted_label != "other":
                labeled.append({"text": sentence, "label": predicted_label})
    return labeled
# Gradio UI: a plain text input wired to ``label``; the returned list of
# dicts is rendered with Gradio's JSON output component.
interface = gr.Interface(
    fn=label,
    inputs="text",
    outputs="json",
    title="Contract clause labeling",
    description="Input contract for labeling.",
)

if __name__ == "__main__":
    # Launch only when executed as a script, not when the module is imported.
    interface.launch(share=True, debug=True, show_error=True)