"""Gmail source."""
import base64
import dataclasses
import os.path
import random
import re
from datetime import datetime
from time import sleep
from typing import TYPE_CHECKING, Any, Iterable, Optional
from pydantic import Field as PydanticField
from typing_extensions import override
from ..env import data_path
from ..schema import Item, field
from ..utils import log
from .source import Source, SourceSchema
if TYPE_CHECKING:
from google.oauth2.credentials import Credentials
# If modifying these scopes, delete the token json file.
_SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']
_GMAIL_CONFIG_DIR = os.path.join(data_path(), '.gmail')
_TOKEN_FILENAME = 'token.json'
_CREDS_FILENAME = 'credentials.json'
_NUM_RETRIES = 10
_MAX_NUM_THREADS = 30_000
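# Joins lines that were hard-wrapped mid-sentence: a newline between two non-whitespace
# characters becomes a single space.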
_UNWRAP_PATTERN = re.compile(r'(\S)\n(\S)')
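# Matches http(s):// URLs so they can be stripped from email bodies.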
_HTTP_PATTERN = re.compile(r'https?://[^\s]+')
class GmailSource(Source):
"""Connects to your Gmail and loads the text of your emails.
**One time setup**
Download the OAuth credentials file from the
[Google Cloud Console](https://console.cloud.google.com/apis/credentials) and save it to the
correct location. See
[guide](https://developers.google.com/gmail/api/quickstart/python#authorize_credentials_for_a_desktop_application)
for details.
"""
name = 'gmail'
credentials_file: str = PydanticField(
description='Path to the OAuth credentials file.',
default=os.path.join(_GMAIL_CONFIG_DIR, _CREDS_FILENAME))
_creds: Optional['Credentials'] = None
class Config:
    # `credentials_file` is required even though it has a default value.
schema_extra = {'required': ['credentials_file']}
@override
def setup(self) -> None:
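    """Authenticate with the Gmail API, running the OAuth flow if there is no cached token."""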
try:
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
except ImportError:
      raise ImportError('Could not import dependencies for the "gmail" source. '
                        'Please install them with `pip install lilacai[gmail]`.')
# The token file stores the user's access and refresh tokens, and is created automatically when
# the authorization flow completes for the first time.
token_filepath = os.path.join(_GMAIL_CONFIG_DIR, _TOKEN_FILENAME)
if os.path.exists(token_filepath):
self._creds = Credentials.from_authorized_user_file(token_filepath, _SCOPES)
# If there are no (valid) credentials available, let the user log in.
if not self._creds or not self._creds.valid:
if self._creds and self._creds.expired and self._creds.refresh_token:
self._creds.refresh(Request())
else:
if not os.path.exists(self.credentials_file):
raise ValueError(
f'Could not find the OAuth credentials file at "{self.credentials_file}". Make sure to '
'download it from the Google Cloud Console and save it to the correct location.')
flow = InstalledAppFlow.from_client_secrets_file(self.credentials_file, _SCOPES)
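        # Opens a browser window for the user to grant access; blocks until the flow completes.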
self._creds = flow.run_local_server()
os.makedirs(os.path.dirname(token_filepath), exist_ok=True)
# Save the token for the next run.
with open(token_filepath, 'w') as token:
token.write(self._creds.to_json())
@override
def source_schema(self) -> SourceSchema:
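    """Return the schema of the items yielded by this source.

    Each item looks like (illustrative values):
      {'body': '...', 'snippet': '...', 'dates': ['2023-06-01 12:00:00'], 'subject': 'Hello'}
    """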
return SourceSchema(
fields={
'body': field('string'),
'snippet': field('string'),
'dates': field(fields=['string']),
'subject': field('string'),
})
@override
def process(self) -> Iterable[Item]:
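    """Fetch up to _MAX_NUM_THREADS Gmail threads and yield one item per thread."""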
try:
from email_reply_parser import EmailReplyParser
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
except ImportError:
      raise ImportError('Could not import dependencies for the "gmail" source. '
                        'Please install them with `pip install lilacai[gmail]`.')
# Call the Gmail API
service = build('gmail', 'v1', credentials=self._creds)
# threads.list API
threads_resource = service.users().threads()
thread_batch: list[Item] = []
retry_batch: set[str] = set()
num_retries = 0
num_threads_fetched = 0
def _thread_fetched(request_id: str, response: Any, exception: Optional[HttpError]) -> None:
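      """Batch callback: parse one thread into an item, or queue its id for retry on error."""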
if exception is not None:
retry_batch.add(request_id)
return
replies: list[str] = []
dates: list[str] = []
snippets: list[str] = []
subject: Optional[str] = None
for msg in response['messages']:
epoch_sec = int(msg['internalDate']) / 1000.
date = datetime.fromtimestamp(epoch_sec).strftime('%Y-%m-%d %H:%M:%S')
dates.append(date)
if 'snippet' in msg:
snippets.append(msg['snippet'])
email_info = _parse_payload(msg['payload'])
subject = subject or email_info.subject
parsed_parts: list[str] = []
for body in email_info.parts:
if not body:
continue
text = base64.urlsafe_b64decode(body).decode('utf-8')
text = EmailReplyParser.parse_reply(text)
# Unwrap text.
          text = _UNWRAP_PATTERN.sub(r'\1 \2', text)
# Remove URLs.
          text = _HTTP_PATTERN.sub('', text)
if text:
parsed_parts.append(text)
if email_info.sender and parsed_parts:
parsed_parts = [
f'--------------------{email_info.sender}--------------------', *parsed_parts
]
if parsed_parts:
replies.append('\n'.join(parsed_parts))
if replies:
thread_batch.append({
'body': '\n\n'.join(replies),
'snippet': '\n'.join(snippets) if snippets else None,
'dates': dates,
'subject': subject,
})
      retry_batch.discard(request_id)
# First request.
    thread_list_req = threads_resource.list(userId='me', includeSpamTrash=False)
    thread_list = thread_list_req.execute(num_retries=_NUM_RETRIES)
while (num_threads_fetched < _MAX_NUM_THREADS and thread_list and thread_list_req):
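      # Fetch a full page of threads with a single batched HTTP request.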
batch = service.new_batch_http_request(callback=_thread_fetched)
      threads = thread_list.get('threads', [])
for gmail_thread in threads:
        thread_id = gmail_thread.get('id')
if not thread_id:
continue
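        # On a retry pass, only re-request the threads that previously failed.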
if not retry_batch or (thread_id in retry_batch):
batch.add(
service.users().threads().get(userId='me', id=thread_id, format='full'),
request_id=thread_id)
batch.execute()
num_threads_fetched += len(thread_batch)
yield from thread_batch
thread_batch = []
      if retry_batch:
        log(f'Failed to fetch {len(retry_batch)} threads. Retrying...')
        num_retries += 1
        if num_retries > _NUM_RETRIES:
          raise RuntimeError(
            f'Failed to fetch {len(retry_batch)} threads after {_NUM_RETRIES} retries.')
        # Exponential backoff with jitter before retrying the failed threads.
        timeout = 2**(num_retries - 1) + random.uniform(0, 1)
        sleep(timeout)
      else:
        num_retries = 0
        # Fetch the next page only after every thread on this page succeeded; advancing the page
        # while `retry_batch` is non-empty would silently drop the failed threads.
        thread_list_req = threads_resource.list_next(thread_list_req, thread_list)
        thread_list = thread_list_req.execute(num_retries=_NUM_RETRIES) if thread_list_req else None
@dataclasses.dataclass
class EmailInfo:
"""Stores parsed information about an email."""
sender: Optional[str] = None
subject: Optional[str] = None
parts: list[bytes] = dataclasses.field(default_factory=list)
def _get_header(payload: Any, name: str) -> Optional[str]:
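  """Return the first header value whose name matches case-insensitively, if any."""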
if 'headers' not in payload:
return None
values = [h['value'] for h in payload['headers'] if h['name'].lower().strip() == name]
return values[0] if values else None
def _parse_payload(payload: Any) -> EmailInfo:
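  """Recursively extract the sender, subject and base64url-encoded text/plain parts."""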
sender = _get_header(payload, 'from')
subject = _get_header(payload, 'subject')
parts: list[bytes] = []
# Process the message body.
if 'mimeType' in payload and 'text/plain' in payload['mimeType']:
if 'body' in payload and 'data' in payload['body']:
parts.append(payload['body']['data'].encode('ascii'))
# Process the message parts.
for part in payload.get('parts', []):
email_info = _parse_payload(part)
sender = sender or email_info.sender
subject = subject or email_info.subject
parts.extend(email_info.parts)
return EmailInfo(sender, subject, parts)
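
# A minimal usage sketch (assumes the gmail extras are installed via
# `pip install lilacai[gmail]` and that the OAuth credentials file exists at the
# default path):
#
#   source = GmailSource()
#   source.setup()                  # Runs the OAuth flow once and caches token.json.
#   for item in source.process():
#     print(item['subject'])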