Skip to content

Commit

Permalink
Merge pull request #101 from ptmminh/fix/auth_err
Browse files Browse the repository at this point in the history
fix: meaningful error when authentication fails
  • Loading branch information
Olen authored Apr 20, 2024
2 parents f05b059 + 8dde6af commit 1aaf7ff
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions spond/spond.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import aiohttp


class AuthenticationError(Exception):
pass


class Spond:
def __init__(self, username, password):
self.username = username
Expand All @@ -31,7 +35,10 @@ async def login(self):
data = {"email": self.username, "password": self.password}
async with self.clientsession.post(login_url, json=data) as r:
login_result = await r.json()
self.token = login_result["loginToken"]
self.token = login_result.get("loginToken", None)
if self.token is None:
err_msg = f"Login failed. Response received: {login_result}"
raise AuthenticationError(err_msg)

api_chat_url = f"{self.api_url}chat"
headers = {
Expand All @@ -43,6 +50,19 @@ async def login(self):
self.chat_url = result["url"]
self.auth = result["auth"]

def require_authentication(func: callable):
async def wrapper(self, *args, **kwargs):
if not self.token:
try:
await self.login()
except AuthenticationError as e:
await self.clientsession.close()
raise e
return await func(self, *args, **kwargs)

return wrapper

@require_authentication
async def get_groups(self):
"""
Get all groups.
Expand All @@ -53,13 +73,12 @@ async def get_groups(self):
list of dict
Groups; each group is a dict.
"""
if not self.token:
await self.login()
url = f"{self.api_url}groups/"
async with self.clientsession.get(url, headers=self.auth_headers) as r:
self.groups = await r.json()
return self.groups

@require_authentication
async def get_group(self, uid) -> dict:
"""
Get a group by unique ID.
Expand All @@ -74,15 +93,15 @@ async def get_group(self, uid) -> dict:
-------
Details of the group.
"""
if not self.token:
await self.login()

if not self.groups:
await self.get_groups()
for group in self.groups:
if group["id"] == uid:
return group
raise IndexError

@require_authentication
async def get_person(self, user) -> dict:
"""
Get a member or guardian by matching various identifiers.
Expand All @@ -98,8 +117,6 @@ async def get_person(self, user) -> dict:
-------
Member or guardian's details.
"""
if not self.token:
await self.login()
if not self.groups:
await self.get_groups()
for group in self.groups:
Expand All @@ -126,14 +143,14 @@ async def get_person(self, user) -> dict:
return guardian
raise IndexError

@require_authentication
async def get_messages(self):
if not self.token:
await self.login()
url = f"{self.chat_url}/chats/?max=10"
headers = {"auth": self.auth}
async with self.clientsession.get(url, headers=headers) as r:
return await r.json()

@require_authentication
async def _continue_chat(self, chat_id, text):
"""
Send a given text in an existing given chat.
Expand All @@ -152,14 +169,13 @@ async def _continue_chat(self, chat_id, text):
dict
Result of the sending.
"""
if not self.token:
await self.login()
url = f"{self.chat_url}/messages"
data = {"chatId": chat_id, "text": text, "type": "TEXT"}
headers = {"auth": self.auth}
r = await self.clientsession.post(url, json=data, headers=headers)
return await r.json()

@require_authentication
async def send_message(self, text, user=None, group_uid=None, chat_id=None):
"""
Start a new chat or continue an existing one.
Expand Down Expand Up @@ -192,8 +208,6 @@ async def send_message(self, text, user=None, group_uid=None, chat_id=None):
"error": "wrong usage, group_id and user_id needed or continue chat with chat_id"
}

if not self.token:
await self.login()
user_obj = await self.get_person(user)
if user_obj:
user_uid = user_obj["profile"]["id"]
Expand All @@ -210,6 +224,7 @@ async def send_message(self, text, user=None, group_uid=None, chat_id=None):
r = await self.clientsession.post(url, json=data, headers=headers)
return await r.json()

@require_authentication
async def get_events(
self,
group_id: Optional[str] = None,
Expand Down Expand Up @@ -259,8 +274,6 @@ async def get_events(
list of dict
Events; each event is a dict.
"""
if not self.token:
await self.login()
url = (
f"{self.api_url}sponds/?"
f"max={max_events}"
Expand All @@ -281,6 +294,7 @@ async def get_events(
self.events = await r.json()
return self.events

@require_authentication
async def get_event(self, uid) -> dict:
"""
Get an event by unique ID.
Expand All @@ -295,15 +309,14 @@ async def get_event(self, uid) -> dict:
-------
Details of the event.
"""
if not self.token:
await self.login()
if not self.events:
await self.get_events()
for event in self.events:
if event["id"] == uid:
return event
raise IndexError

@require_authentication
async def update_event(self, uid, updates: dict):
"""
Updates an existing event.
Expand All @@ -320,8 +333,6 @@ async def update_event(self, uid, updates: dict):
json results of post command
"""
if not self.token:
await self.login()
if not self.events:
await self.get_events()
for event in self.events:
Expand Down

0 comments on commit 1aaf7ff

Please sign in to comment.