Skip to content
This repository has been archived by the owner on Feb 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #138 from justinjeffery-ipf/get_current_user
Browse files Browse the repository at this point in the history
Updates for logged in user
  • Loading branch information
Justin Jeffery authored Aug 17, 2022
2 parents 842f15c + f0d0dbe commit f3ce991
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 17 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## 5.0.4 (2022-08-17)
### Fix
* Fix in Users model
* Added get_user to get information about logged in user during init
* Added count function to tables

## 5.0.3 (2022-08-11)
### Fix
* Deprecate IPF v4.X from package
Expand Down
16 changes: 13 additions & 3 deletions ipfabric/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseSettings

from ipfabric import models
from ipfabric.settings.user_mgmt import User

logger = logging.getLogger()
DEFAULT_ID = "$last"
Expand Down Expand Up @@ -74,11 +75,20 @@ def __init__(
raise RuntimeError("IP Fabric Token or Username/Password not provided.")

self.auth = HeaderApiKey(token) if token else PasswordCredentials(base_url, username, password)

# Get Snapshots, by doing that we are also ensuring the token is valid
# Get Current User, by doing that we are also ensuring the token is valid
self.user = self.get_user()
self.snapshots = self.get_snapshots()
self.snapshot_id = snapshot_id

def get_user(self):
"""
Gets current logged in user information.
:return: User: User model of logged in user
"""
resp = self.get("users/me")
resp.raise_for_status()
return User(**resp.json())

def check_version(self, api_version, base_url):
"""
Checks API Version and returns the version to use in the URL and the OS Version
Expand All @@ -88,7 +98,7 @@ def check_version(self, api_version, base_url):
"""
if api_version == "v1":
raise RuntimeError("IP Fabric Version < 5.0 support has been dropped, please use ipfabric==4.4.3")
dist_ver = get_distribution("ipfabric").version.split('.')
dist_ver = get_distribution("ipfabric").version.split(".")
api_version = parse_version(api_version) if api_version else parse_version(f"{dist_ver[0]}.{dist_ver[1]}")

resp = self.get(urljoin(base_url, "api/version"), headers={"Content-Type": "application/json"})
Expand Down
11 changes: 6 additions & 5 deletions ipfabric/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def wrapper(self, url, *args, **kwargs):
kwargs["filters"] = loads(kwargs["filters"])
path = urlparse(url or kwargs["url"]).path
r = re.search(r"(api/)?v\d(\.\d)?/", path)
url = path[r.end():] if r else path
url = url[1:] if url[0] == '/' else url
url = path[r.end() :] if r else path
url = url[1:] if url[0] == "/" else url
return func(self, url, *args, **kwargs)

return wrapper
Expand Down Expand Up @@ -187,9 +187,10 @@ def _ipf_pager(
self._ipf_pager(url, payload, data, limit=limit, start=start + limit)
return data

def get_count(self, url: str, snapshot_id: Optional[str] = None):
payload = dict(columns=["id"], pagination=dict(limit=1, start=0))
payload["snapshot"] = snapshot_id or self.snapshot_id
def get_count(self, url: str, filters: Optional[Union[dict, str]] = None, snapshot_id: Optional[str] = None):
payload = dict(columns=["id"], pagination=dict(limit=1, start=0), snapshot=snapshot_id or self.snapshot_id)
if filters:
payload["filters"] = filters
res = self.post(url, json=payload)
res.raise_for_status()
return res.json()["_meta"]["count"]
45 changes: 45 additions & 0 deletions ipfabric/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,38 @@ class Table(BaseModel):
def name(self):
return self.endpoint.split("/")[-1]

def fetch(
self,
columns: list = None,
filters: Optional[dict] = None,
snapshot_id: Optional[str] = None,
reports: Optional[str] = None,
sort: Optional[dict] = None,
limit: Optional[int] = 1000,
start: Optional[int] = 0,
):
"""
Gets all data from corresponding endpoint
:param columns: list: Optional columns to return, default is all
:param filters: dict: Optional filters
:param snapshot_id: str: Optional snapshot ID to override class
:param reports: str: String of frontend URL where the reports are displayed
:param sort: dict: Dictionary to apply sorting: {"order": "desc", "column": "lastChange"}
:param limit: int: Default to 1,000 rows
:param start: int: Starts at 0
:return: list: List of Dictionaries
"""
return self.client.fetch(
self.endpoint,
columns=columns,
filters=filters,
snapshot_id=snapshot_id,
reports=reports,
sort=sort,
limit=limit,
start=start,
)

def all(
self,
columns: list = None,
Expand All @@ -117,6 +149,19 @@ def all(
sort=sort,
)

def count(self, filters: Optional[dict] = None, snapshot_id: Optional[str] = None):
"""
Gets count of table
:param filters: dict: Optional filters
:param snapshot_id: str: Optional snapshot ID to override class
:return: int: Count
"""
return self.client.get_count(
self.endpoint,
filters=filters,
snapshot_id=snapshot_id,
)


class Inventory(BaseModel):
client: Any
Expand Down
2 changes: 1 addition & 1 deletion ipfabric/settings/user_mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class User(BaseModel):
sso_provider: Optional[Any] = Field(alias="ssoProvider")
domains: Optional[Any] = Field(alias="domainSuffixes")
role_names: Optional[list] = Field(alias="roleNames", default_factory=list)
role_ids: list = Field(alias="roleNames")
role_ids: list = Field(alias="roleIds")
ldap_id: Any = Field(alias="ldapId")
timezone: str

Expand Down
9 changes: 3 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pydantic==1.9.2; python_full_version >= "3.6.1"
pyjwt==2.4.0; python_full_version >= "3.7.1" and python_full_version < "4.0.0" and python_version >= "3.6"
python-dateutil==2.8.2; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.3.0")
python-dotenv==0.20.0; python_version >= "3.5"
pytz==2022.1
pytz==2022.2.1
rfc3986==1.5.0; python_full_version >= "3.7.1" and python_full_version < "4.0.0" and python_version >= "3.7"
six==1.16.0; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.3.0"
sniffio==1.2.0; python_full_version >= "3.7.1" and python_full_version < "4.0.0" and python_version >= "3.7"
Expand Down
1 change: 1 addition & 0 deletions tests/unittests/settings/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def mock_pager(*args):
"domainSuffixes": "",
"email": "justin.jeffrey@ipfabric.io",
"roleNames": ["admin"],
"roleIds": ["admin"],
"timezone": "UTC"
}
]
Expand Down
14 changes: 13 additions & 1 deletion tests/unittests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ipfabric import IPFClient
from ipfabric.client import check_format
from ipfabric.models import Snapshot
from ipfabric.settings.user_mgmt import User


class Decorator(unittest.TestCase):
Expand Down Expand Up @@ -50,10 +51,11 @@ class Client(unittest.TestCase):
@patch("httpx.Client.__init__", return_value=None)
@patch("httpx.Client.headers")
@patch("httpx.Client.base_url")
@patch("ipfabric.IPFClient.get_user")
@patch("ipfabric.IPFClient.check_version")
@patch("ipfabric.IPFClient.get_snapshots")
@patch("ipfabric.models.Inventory")
def setUp(self, inventory, snaps, check_version, base_url, headers, mock_client):
def setUp(self, inventory, snaps, check_version, get_user, base_url, headers, mock_client):
snaps.return_value = {
"$last": Snapshot(
**{
Expand All @@ -77,6 +79,16 @@ def setUp(self, inventory, snaps, check_version, base_url, headers, mock_client)
check_version.return_value = "v5", parse_version("v5.0.1")
self.ipf = IPFClient(base_url="https://google.com", token='token')

@patch("httpx.Client.get")
def test_get_user(self, get):
get().is_error = None
get().json.return_value = {"email": "admin@ipfabric.io", "isLocal": True, "timezone": "UTC",
"username": "admin", "active": True, "ldapId": None, "id": "863",
"roleIds": ["admin"]}
user = self.ipf.get_user()
self.assertIsInstance(user, User)
self.assertEqual(user.username, 'admin')

@patch("httpx.Client.get")
def test_check_version(self, get):
get().is_error = None
Expand Down
12 changes: 12 additions & 0 deletions tests/unittests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ def test_table_all(self, MockClient):
MockClient.fetch_all.return_value = list()
self.assertEqual(table.all(), list())

@patch("ipfabric.IPFClient")
def test_table_fetch(self, MockClient):
table = models.Table(client=MockClient, endpoint="/network/ip")
MockClient.fetch.return_value = list()
self.assertEqual(table.fetch(), list())

@patch("ipfabric.IPFClient")
def test_table_count(self, MockClient):
table = models.Table(client=MockClient, endpoint="/network/ip")
MockClient.get_count.return_value = 1
self.assertEqual(table.count(), 1)

def test_inventory(self):
i = models.Inventory(client=MagicMock())
self.assertIsInstance(i.vendors, models.Table)
Expand Down

0 comments on commit f3ce991

Please sign in to comment.