diff --git a/app/app/settings.py b/app/app/settings.py index cb1ce44..278fe67 100644 --- a/app/app/settings.py +++ b/app/app/settings.py @@ -38,8 +38,11 @@ 'django.contrib.sessions', 'django.contrib.messages', 'django.contrib.staticfiles', - + 'rest_framework', + 'rest_framework.authtoken', + 'drf_spectacular', 'core', + 'user', ] MIDDLEWARE = [ @@ -140,3 +143,8 @@ DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' AUTH_USER_MODEL = 'core.User' + +REST_FRAMEWORK = { + # Enables rest_framework to generate api schema using drf-spectacular + 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', +} diff --git a/app/app/urls.py b/app/app/urls.py index 9a83572..66ad877 100644 --- a/app/app/urls.py +++ b/app/app/urls.py @@ -13,9 +13,20 @@ 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ +from drf_spectacular.views import ( + SpectacularAPIView, + SpectacularSwaggerView, +) from django.contrib import admin -from django.urls import path +from django.urls import path, include urlpatterns = [ path('admin/', admin.site.urls), + path('api/schema/', SpectacularAPIView.as_view(), name='api-schema'), + path( + 'api/docs/', + SpectacularSwaggerView.as_view(url_name='api-schema'), + name='api-docs' + ), + path('api/user/', include('user.urls')), ] diff --git a/app/user/__init__.py b/app/user/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/user/apps.py b/app/user/apps.py new file mode 100644 index 0000000..36cce4c --- /dev/null +++ b/app/user/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class UserConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'user' diff --git a/app/user/serializers.py b/app/user/serializers.py new file mode 100644 index 0000000..2561122 --- /dev/null +++ b/app/user/serializers.py @@ -0,0 +1,61 @@ +""" +Serializers for the user API view +""" +from django.contrib.auth import (get_user_model, + authenticate) +from django.utils.translation import gettext as _ + +from rest_framework import serializers + + +class UserSerializer(serializers.ModelSerializer): + """Serializer for the user object""" + + class Meta: + model = get_user_model() + fields = ('email', 'password', 'name') + extra_kwargs = { + 'password': { + 'write_only': True, + 'min_length': 5, + } + } + + def create(self, validated_data): + """Create and return a user with encrypted password.""" + return get_user_model().objects.create_user(**validated_data) + + def update(self, instance, validated_data): + """Update and return user""" + password = validated_data.pop('password', None) + user = super().update(instance, validated_data) + + if password: + user.set_password(password) + user.save() + + return user + + +class AuthTokenSerializer(serializers.Serializer): + """Serializer for the user auth token""" + email = serializers.EmailField() + password = serializers.CharField( + style={'input_type': 'password'}, + trim_whitespace=False + ) + + def validate(self, attrs): + """Validate and authenticate the user.""" + email = attrs.get('email') + password = attrs.get('password') + user = authenticate( + request=self.context.get('request'), + username=email, + password=password, + ) + if not user: + msg = _('Unable to authenticate with provided credentials.') + raise serializers.ValidationError(msg, code='authorization') + attrs['user'] = user + return attrs diff --git a/app/user/tests/__init__.py b/app/user/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/user/tests/test_user_api.py b/app/user/tests/test_user_api.py new file mode 100644 index 0000000..2560632 --- /dev/null +++ b/app/user/tests/test_user_api.py @@ -0,0 +1,156 @@ +""" +Tests for the user API +""" +from django.test import TestCase +from django.contrib.auth import get_user_model +from django.urls import reverse + +from rest_framework import status +from rest_framework.test import APIClient + +CREATE_USER_URL = reverse('user:create') +TOKEN_URL = reverse('user:token') +ME_URL = reverse('user:me') + + +def create_user(**params): + """Create and return a new user""" + return get_user_model().objects.create_user(**params) + + +class PublicUserAPITests(TestCase): + """Tests the public features of the user API""" + + def setUp(self): + self.client = APIClient() + + def test_create_user_success(self): + """Test that creating a new user is successful""" + payload = { + 'email': 'test@example.com', + 'password': 'testpass123', + 'name': 'Test User', + } + res = self.client.post(CREATE_USER_URL, payload) + + self.assertEqual(res.status_code, status.HTTP_201_CREATED) + user = get_user_model().objects.get(email=payload.get('email')) + self.assertTrue(user.check_password(payload.get('password'))) + self.assertNotIn('password', res.data) + + def test_user_with_email_exist_error(self): + """Tests that error is returned when a user with email address exists + """ + payload = { + 'email': 'test@example.com', + 'password': 'testpass123', + 'name': 'Test User', + } + create_user(**payload) + res = self.client.post(CREATE_USER_URL, payload) + + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + def test_password_too_short_error(self): + """Tests that an error is returned if password < 5 characters""" + payload = { + 'email': 'test@example.com', + 'password': 'tes', + 'name': 'Test User', + } + res = self.client.post(CREATE_USER_URL, payload) + + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + user_exists = get_user_model().objects.filter( + email=payload.get('email') + ).exists() + self.assertFalse(user_exists) + + def test_create_token_for_user(self): + """Test generates tokens for valid credentials.""" + user_details = { + 'email': 'test@example.com', + 'password': 'testpass123', + 'name': 'Test User', + } + create_user(**user_details) + payload = { + 'email': user_details.get('email'), + 'password': user_details.get('password'), + } + res = self.client.post(TOKEN_URL, payload) + + self.assertIn('token', res.data) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + def test_create_token_bad_credentials(self): + """Test returns error is credentials invalid.""" + create_user(email='test@example.com', password='goodpass') + payload = { + 'email': 'test@example.com', + 'password': 'badpass', + } + res = self.client.post(TOKEN_URL, payload) + + self.assertNotIn('token', res.data) + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + def test_create_token_blank_password(self): + """Test passing blank password returns an error.""" + create_user(email='test@example.com', password='goodpass') + payload = { + 'email': 'test@example.com', + 'password': '', + } + res = self.client.post(TOKEN_URL, payload) + + self.assertNotIn('token', res.data) + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + def test_retrieve_user_unauthorized(self): + """Test authentication is required for users.""" + res = self.client.get(ME_URL) + + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + + +class PrivateUserAPITests(TestCase): + """Test API requests that require authentication.""" + + def setUp(self): + self.user = create_user( + email='test@example.com', + password='testpass123', + name='Test user', + ) + self.client = APIClient() + self.client.force_authenticate(user=self.user) + + def test_retrieve_profile_success(self): + """Tests retrieving profile for authenticated users""" + res = self.client.get(ME_URL) + + self.assertEqual(res.status_code, status.HTTP_200_OK) + self.assertEqual(res.data, { + 'name': self.user.name, + 'email': self.user.email, + }) + + def test_post_me_not_allowed(self): + """Test POST method is not allowed for the me/ endpoint.""" + res = self.client.post(ME_URL, {}) + + self.assertEqual(res.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + def test_update_user_profile(self): + """Tests updating the profile for authenticated user""" + payload = { + 'name': 'Updated name', + 'password': 'newpassword123', + } + res = self.client.patch(ME_URL, payload) + + self.user.refresh_from_db() + self.assertEqual(self.user.name, payload.get('name')) + self.assertTrue(self.user.check_password(payload.get('password'))) + self.assertEqual(res.status_code, status.HTTP_200_OK) diff --git a/app/user/urls.py b/app/user/urls.py new file mode 100644 index 0000000..01e9cf4 --- /dev/null +++ b/app/user/urls.py @@ -0,0 +1,14 @@ +""" +URL mappings for the user API +""" +from django.urls import path + +from user import views + +app_name = 'user' + +urlpatterns = ( + path('create/', views.CreateUserView.as_view(), name='create'), + path('token/', views.CreateTokenView.as_view(), name='token'), + path('me/', views.ManageUserView.as_view(), name='me'), +) \ No newline at end of file diff --git a/app/user/views.py b/app/user/views.py new file mode 100644 index 0000000..e02cb2a --- /dev/null +++ b/app/user/views.py @@ -0,0 +1,31 @@ +""" +Views for the user API. +""" +from rest_framework import generics, authentication, permissions +from rest_framework.authtoken.views import ObtainAuthToken +from rest_framework.settings import api_settings + +from user.serializers import (UserSerializer, + AuthTokenSerializer,) + + +class CreateUserView(generics.CreateAPIView): + """Create a new user in the system.""" + serializer_class = UserSerializer + + +class ManageUserView(generics.RetrieveUpdateAPIView): + """Manage the authenticated user.""" + serializer_class = UserSerializer + authentication_classes = (authentication.TokenAuthentication,) + permission_classes = (permissions.IsAuthenticated,) + + def get_object(self): + """Retrieve and return the authenticated user.""" + return self.request.user + + +class CreateTokenView(ObtainAuthToken): + """Create a new auth token for the user.""" + serializer_class = AuthTokenSerializer + renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8d29a59..8d07584 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ Django>=3.2.4,<3.3 djangorestframework>=3.12.4,<3.13 -psycopg2>=2.8.6,<2.9 \ No newline at end of file +psycopg2>=2.8.6,<2.9 +drf-spectacular>=0.15.1,<0.16 \ No newline at end of file