commit 98680c95e85b49b7115a69ffb0ba738b350d1631 Author: Koichi Edagawa Date: Fri Aug 21 16:21:32 2020 +0900 Support of OAuth2.0 Supported OAuth2.0 as the part of the feature of branching VNFM and NFVO in Tacker. This includes Client/Basic certification. Implements: blueprint support-vnfm-operations Spec: https://specs.openstack.org/openstack/tacker-specs/specs/victoria/support-sol003-vnfm-operations.html Change-Id: I025579dcfe6f0b020bb00dcb90a7d33e49b7cac9 (cherry picked from commit a7dc3ab6b5a11f2f34178cab74ee6bafdeeab918) diff --git a/tacker/auth.py b/tacker/auth.py index 10ed4f6..04ac726 100644 --- a/tacker/auth.py +++ b/tacker/auth.py @@ -12,12 +12,16 @@ # License for the specific language governing permissions and limitations # under the License. +import abc from oslo_config import cfg from oslo_log import log as logging from oslo_middleware import base +import requests +import threading import webob.dec import webob.exc +from tacker.common import utils from tacker import context LOG = logging.getLogger(__name__) @@ -40,6 +44,248 @@ class TackerKeystoneContext(base.ConfigurableMiddleware): return self.application +class _BearerAuth(requests.auth.AuthBase): + """Attaches HTTP Bearer Authentication to the given Request object.""" + + def __init__(self, token): + self.token = token + + def __call__(self, r): + r.headers["Authorization"] = "Bearer " + self.token + return r + + +class _OAuth2GrantBase(metaclass=abc.ABCMeta): + """Base class that all OAuth2.0 grant type implementations derive from.""" + + grant_type = None + + @abc.abstractmethod + def get_accsess_token(self): + pass + + +class _ClientCredentialsGrant(_OAuth2GrantBase): + """OAuth2.0 grant type ClientCredentials implementation.""" + + grant_type = 'client_credentials' + + def __init__(self, token_endpoint, client_id, client_password): + super(_ClientCredentialsGrant, self).__init__() + self.token_endpoint = token_endpoint + self.client_id = client_id + self.client_password = client_password + + def 
get_accsess_token(self): + """Get access token. + + Returns: + dict: access token response. + """ + kwargs = { + 'headers': { + 'Connection': 'keep-alive', + 'Content-Type': 'application/x-www-form-urlencoded'}, + 'data': { + 'grant_type': self.grant_type}, + 'timeout': cfg.CONF.authentication.timeout} + + basic_auth_request = _BasicAuthSession( + self.client_id, self.client_password) + + LOG.info( + "Get Access Token, Connecting to {}".format( + self.token_endpoint)) + LOG.info("Request Headers={}".format(kwargs.get('headers'))) + LOG.info("Request Body={}".format(kwargs.get('data'))) + + response = basic_auth_request.get(self.token_endpoint, **kwargs) + response.raise_for_status() + + response_body = response.json() + LOG.info("[RES] Headers={}".format(response.headers)) + LOG.info("[RES] Body={}".format(response_body)) + + return response_body + + +class _OAuth2Session(requests.Session): + """Provides OAuth 2.0 authentication.""" + + def __init__(self, grant): + super(_OAuth2Session, self).__init__() + self.grant = grant + self.__access_token_info = {} + self.__lock = threading.RLock() + + def request(self, method, url, **kwargs): + """Override function.""" + kwargs['auth'] = _BearerAuth( + self.__access_token_info.get('access_token')) + + response = super().request(method, url, **kwargs) + if response.status_code == 401: + LOG.error( + 'Authentication error {}, details={}'.format( + response, response.text)) + self.apply_access_token_info() + + return response + + def apply_access_token_info(self): + """Get access token.""" + try: + self.__set_access_token_info(self.grant.get_accsess_token()) + self.schedule_refrash_token() + except requests.exceptions.RequestException as e: + if getattr(e, 'response', None) is not None: + LOG.error( + "Get Access Token, error details={}".format( + e.response.text)) + LOG.error(e) + + def __set_access_token_info(self, update_dict): + with self.__lock: + self.__access_token_info = update_dict + + def schedule_refrash_token(self): +
"""expires_in Scheduler registration at expiration.""" + if not ('expires_in' in self.__access_token_info): + LOG.debug("'expires_in' does not exist in the response body.") + return + + try: + expires_in = int(self.__access_token_info.get('expires_in')) + expires_in_timer = threading.Timer( + expires_in, self.apply_access_token_info) + expires_in_timer.start() + + LOG.info( + "expires_in=<{}> exist, scheduler regist.".format(expires_in)) + except (ValueError, TypeError): + pass + + +class _BasicAuthSession(requests.Session): + """Provide Basic authentication.""" + + def __init__(self, user_name, password): + super(_BasicAuthSession, self).__init__() + self.user_name = user_name + self.password = password + self.auth = requests.auth.HTTPBasicAuth( + self.user_name, self.password) + + def request(self, method, url, **kwargs): + """Override function.""" + kwargs['auth'] = self.auth + return super().request(method, url, **kwargs) + + +class _AuthManager: + + OPTS = [ + cfg.StrOpt('auth_type', + default=None, + choices=['BASIC', 'OAUTH2_CLIENT_CREDENTIALS'], + help="auth_type used for external connection"), + cfg.IntOpt('timeout', + default=20, + help="timeout used for external connection"), + cfg.StrOpt('token_endpoint', + default=None, + help="token_endpoint used to get the oauth2 token"), + cfg.StrOpt('client_id', + default=None, + help="client_id used to get the oauth2 token"), + cfg.StrOpt('client_password', + default=None, + help="client_password used to get the oauth2 token"), + cfg.StrOpt('user_name', + default=None, + help="user_name used in basic authentication"), + cfg.StrOpt('password', + default=None, + help="password used in basic authentication") + ] + cfg.CONF.register_opts(OPTS, group='authentication') + + __DEFAULT_CLIENT = requests.Session() + + def __init__(self): + self.__manages = {} + self.__lock = threading.RLock() + + # local auth setting. 
+ self.set_auth_client( + auth_type=cfg.CONF.authentication.auth_type, + auth_params={'client_id': cfg.CONF.authentication.client_id, + 'client_password': cfg.CONF.authentication.client_password, + 'token_endpoint': cfg.CONF.authentication.token_endpoint, + 'user_name': cfg.CONF.authentication.user_name, + 'password': cfg.CONF.authentication.password}) + + def __empty(self, val): + if val is None: + return True + elif isinstance(val, str): + return val.strip() == '' + + return len(val) == 0 + + def set_auth_client(self, id='local', auth_type=None, auth_params=None): + """Set up an Auth client. + + Args: + id (str, optional): Management ID + auth_type (str, optional): Authentication type. + auth_params (dict, optional): Authentication information. + """ + snakecase_auth_params = utils.convert_camelcase_to_snakecase( + auth_params) + if self.__empty(auth_type) or self.__empty(snakecase_auth_params): + return + + if id in self.__manages: + LOG.debug("Use cache, Auth Managed Id=<{}>".format(id)) + return + + client = self.__DEFAULT_CLIENT + if auth_type == 'BASIC': + client = _BasicAuthSession( + user_name=snakecase_auth_params.get('user_name'), + password=snakecase_auth_params.get('password')) + elif (auth_type == 'OAUTH2_CLIENT_CREDENTIALS' and + not self.__empty(snakecase_auth_params.get('token_endpoint'))): + grant = _ClientCredentialsGrant( + client_id=snakecase_auth_params.get('client_id'), + client_password=snakecase_auth_params.get('client_password'), + token_endpoint=snakecase_auth_params.get('token_endpoint')) + client = _OAuth2Session(grant) + client.apply_access_token_info() + + LOG.info( + "Add to Auth management, id=<{}>, type=<{}>, class=<{}>".format( + id, auth_type, client.__class__.__name__)) + + self.__add_manages(id, client) + + def __add_manages(self, id, client): + with self.__lock: + self.__manages[id] = client + + def get_auth_client(self, id="local"): + """Get the Auth client. 
+ + Args: + id (str, optional): Management ID + + Returns: + based on class. + """ + return self.__manages.get(id, self.__DEFAULT_CLIENT) + + def pipeline_factory(loader, global_conf, **local_conf): """Create a paste pipeline based on the 'auth_strategy' config option.""" pipeline = local_conf[cfg.CONF.auth_strategy] @@ -50,3 +296,6 @@ def pipeline_factory(loader, global_conf, **local_conf): for f in filters: app = f(app) return app + + +auth_manager = _AuthManager() diff --git a/tacker/conductor/conductor_server.py b/tacker/conductor/conductor_server.py index e74bb7a..7ca82dc 100644 --- a/tacker/conductor/conductor_server.py +++ b/tacker/conductor/conductor_server.py @@ -18,11 +18,11 @@ import functools import inspect import json import os -import requests import shutil import sys import time import traceback +import yaml from glance_store import exceptions as store_exceptions @@ -38,8 +38,8 @@ from oslo_utils import timeutils from oslo_utils import uuidutils from sqlalchemy import exc as sqlexc from sqlalchemy.orm import exc as orm_exc -import yaml +from tacker import auth from tacker.common import coordination from tacker.common import csar_utils from tacker.common import exceptions @@ -620,9 +620,13 @@ class Conductor(manager.Manager): notification['timeStamp'] = datetime.datetime.utcnow( ).isoformat() try: + self.__set_auth_subscription(line) + for num in range(CONF.vnf_lcm.retry_num): LOG.warn("send notify[%s]" % json.dumps(notification)) - response = requests.post( + auth_client = auth.auth_manager.get_auth_client( + notification['subscriptionId']) + response = auth_client.post( line.callback_uri.decode(), data=json.dumps(notification)) if response.status_code == 204: @@ -654,7 +658,7 @@ class Conductor(manager.Manager): except Exception as e: LOG.warn("Internal Sever Error[%s]" % str(e)) LOG.warn(traceback.format_exc()) - return -2 + return 99 return 0 @coordination.synchronized('{vnf_instance[id]}') @@ -858,6 +862,37 @@ class Conductor(manager.Manager): 
error=str(ex) ) + def __set_auth_subscription(self, vnf_lcm_subscription): + def decode(val): + return val if isinstance(val, str) else val.decode() + + if not vnf_lcm_subscription.subscription_authentication: + return + + subscription_authentication = decode( + vnf_lcm_subscription.subscription_authentication) + + authentication = utils.convert_camelcase_to_snakecase( + json.loads(subscription_authentication)) + + if not authentication: + return + + auth_params = {} + auth_type = None + if 'params_basic' in authentication: + auth_params = authentication.get('params_basic') + auth_type = 'BASIC' + elif 'params_oauth2_client_credentials' in authentication: + auth_params = authentication.get( + 'params_oauth2_client_credentials') + auth_type = 'OAUTH2_CLIENT_CREDENTIALS' + + auth.auth_manager.set_auth_client( + id=decode(vnf_lcm_subscription.id), + auth_type=auth_type, + auth_params=auth_params) + def init(args, **kwargs): CONF(args=args, project='tacker', diff --git a/tacker/objects/vnf_lcm_subscriptions.py b/tacker/objects/vnf_lcm_subscriptions.py index f18e03d..10afd65 100644 --- a/tacker/objects/vnf_lcm_subscriptions.py +++ b/tacker/objects/vnf_lcm_subscriptions.py @@ -111,13 +111,16 @@ def _vnf_lcm_subscriptions_show(context, subscriptionId): "where t1.id = t2.subscription_uuid " "and deleted = 0 " "and t1.id = :subsc_id") + result_line = "" try: result = context.session.execute(sql, {'subsc_id': subscriptionId}) + for line in result: + result_line = line except exceptions.NotFound: return '' except Exception as e: raise e - return result + return result_line @db_api.context_manager.reader @@ -193,7 +196,8 @@ def _vnf_lcm_subscriptions_id_get(context, try: result = context.session.execute(sql) - return result + for line in result: + return line except exceptions.NotFound: return '' @@ -287,6 +291,11 @@ class LccnSubscriptionRequest(base.TackerObject, base.TackerPersistentObject): updates = self.obj_clone() db_vnf_lcm_subscriptions = 
_vnf_lcm_subscriptions_create( self._context, updates, filter) + + LOG.debug( + 'db_vnf_lcm_subscriptions: %s', + db_vnf_lcm_subscriptions) + return db_vnf_lcm_subscriptions @base.remotable_classmethod diff --git a/tacker/tests/unit/conductor/test_conductor_server.py b/tacker/tests/unit/conductor/test_conductor_server.py index aeebb9a..43c848f 100644 --- a/tacker/tests/unit/conductor/test_conductor_server.py +++ b/tacker/tests/unit/conductor/test_conductor_server.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 +import json import os import shutil import sys @@ -27,6 +29,7 @@ from six.moves import urllib import six.moves.urllib.error as urlerr import yaml +from tacker import auth from tacker.common import coordination from tacker.common import csar_utils from tacker.common import exceptions @@ -36,11 +39,14 @@ from tacker import context from tacker.glance_store import store as glance_store from tacker import objects from tacker.objects import fields +from tacker.tests.unit import base as unit_base from tacker.tests.unit.conductor import fakes from tacker.tests.unit.db.base import SqlTestCase from tacker.tests.unit.db import utils as db_utils from tacker.tests.unit.objects import fakes as fake_obj from tacker.tests.unit.vnflcm import fakes as vnflcm_fakes +from tacker.tests.unit.vnfm.infra_drivers.openstack.fixture_data import client +import tacker.tests.unit.vnfm.test_nfvo_client as nfvo_client from tacker.tests import utils from tacker.tests import uuidsentinel @@ -55,7 +61,9 @@ class FakeVNFMPlugin(mock.Mock): pass -class TestConductor(SqlTestCase): +class TestConductor(SqlTestCase, unit_base.FixturedTestCase): + client_fixture_class = client.ClientFixture + sdk_connection_fixure_class = client.SdkConnectionFixture def setUp(self): super(TestConductor, self).setUp() @@ -91,6 +99,41 @@ class TestConductor(SqlTestCase): def _create_vnf_package_vnfd(self): return
fakes.get_vnf_package_vnfd() + def _create_subscriptions(self, auth_params=None): + class DummyLcmSubscription: + def __init__(self, auth_params=None): + if auth_params: + self.subscription_authentication = json.dumps( + auth_params).encode() + + self.id = uuidsentinel.lcm_subscription_id.encode() + self.callback_uri = 'https://localhost/callback'.encode() + + def __getattr__(self, name): + try: + return object.__getattr__(self, name) + except AttributeError: + return None + + return [DummyLcmSubscription(auth_params)] + + def assert_auth_basic( + self, + acutual_request, + expected_user_name, + expected_password): + actual_auth = acutual_request._request.headers.get("Authorization") + expected_auth = base64.b64encode( + '{}:{}'.format( + expected_user_name, + expected_password).encode('utf-8')).decode() + self.assertEqual("Basic " + expected_auth, actual_auth) + + def assert_auth_client_credentials(self, acutual_request, expected_token): + actual_auth = acutual_request._request.headers.get( + "Authorization") + self.assertEqual("Bearer " + expected_token, actual_auth) + @mock.patch.object(conductor_server.Conductor, '_onboard_vnf_package') @mock.patch.object(conductor_server, 'revert_upload_vnf_package') @mock.patch.object(csar_utils, 'load_csar_data') @@ -143,9 +186,9 @@ class TestConductor(SqlTestCase): def test_get_vnf_package_vnfd_with_tosca_meta_file_in_csar(self): fake_csar = fakes.create_fake_csar_dir(self.vnf_package.id, self.temp_dir) + expected_data = fakes.get_expected_vnfd_data() result = self.conductor.get_vnf_package_vnfd(self.context, self.vnf_package) - expected_data = fakes.get_expected_vnfd_data() self.assertEqual(expected_data, result) shutil.rmtree(fake_csar) @@ -620,20 +663,11 @@ class TestConductor(SqlTestCase): password=password) self.assertEqual('CREATED', self.vnf_package.onboarding_state) - def test_send_notification_not_found_vnfd(self): - notification = {'vnfInstanceId': 'Test'} - - result = 
self.conductor.send_notification(self.context, notification) - - self.assertEqual(result, -2) - @mock.patch.object(objects.LccnSubscriptionRequest, 'vnf_lcm_subscriptions_get') - def test_send_notification_not_found_subscription( - self, - mock_vnf_lcm_subscriptions_get): - - mock_vnf_lcm_subscriptions_get.return_value = None + def test_sendNotification_notFoundSubscription(self, + mock_subscriptions_get): + mock_subscriptions_get.return_value = None notification = { 'vnfInstanceId': 'Test', 'notificationType': 'VnfLcmOperationOccurrenceNotification'} @@ -641,24 +675,18 @@ class TestConductor(SqlTestCase): result = self.conductor.send_notification(self.context, notification) self.assertEqual(result, -1) - mock_vnf_lcm_subscriptions_get.assert_called() + mock_subscriptions_get.assert_called() @mock.patch.object(objects.LccnSubscriptionRequest, 'vnf_lcm_subscriptions_get') - @mock.patch('requests.post') - def test_send_notification_vnf_lcm_operation_occurrence( - self, - mock_post, - mock_vnf_lcm_subscriptions_get): - - response = mock.Mock() - response.status_code = 204 - mock_post.return_value = response - - m_vnf_lcm_subscriptions = \ - [mock.MagicMock(**fakes.get_vnf_lcm_subscriptions())] - mock_vnf_lcm_subscriptions_get.return_value = \ - m_vnf_lcm_subscriptions + def test_sendNotification_vnfLcmOperationOccurrence(self, + mock_subscriptions_get): + self.requests_mock.register_uri('POST', + "https://localhost/callback", + headers={'Content-Type': 'application/json'}, + status_code=204) + + mock_subscriptions_get.return_value = self._create_subscriptions() notification = { 'vnfInstanceId': 'Test', 'notificationType': 'VnfLcmOperationOccurrenceNotification', @@ -669,25 +697,53 @@ class TestConductor(SqlTestCase): result = self.conductor.send_notification(self.context, notification) self.assertEqual(result, 0) - mock_vnf_lcm_subscriptions_get.assert_called() - mock_post.assert_called() + mock_subscriptions_get.assert_called() + + history = 
self.requests_mock.request_history + req_count = nfvo_client._count_mock_history( + history, "https://localhost") + self.assertEqual(1, req_count) @mock.patch.object(objects.LccnSubscriptionRequest, - 'vnf_lcm_subscriptions_get') - @mock.patch('requests.post') - def test_send_notification_vnf_identifier_creation( - self, - mock_post, - mock_vnf_lcm_subscriptions_get): + 'vnf_lcm_subscriptions_get') + def test_sendNotification_vnfIdentifierCreation(self, + mock_subscriptions_get): + self.requests_mock.register_uri('POST', + "https://localhost/callback", + headers={'Content-Type': 'application/json'}, + status_code=204) + + mock_subscriptions_get.return_value = self._create_subscriptions() + notification = { + 'vnfInstanceId': 'Test', + 'notificationType': 'VnfIdentifierCreationNotification', + 'links': {}} - response = mock.Mock() - response.status_code = 204 - mock_post.return_value = response + result = self.conductor.send_notification(self.context, notification) + + self.assertEqual(result, 0) + mock_subscriptions_get.assert_called() + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history( + history, "https://localhost") + self.assertEqual(1, req_count) + + @mock.patch.object(objects.LccnSubscriptionRequest, + 'vnf_lcm_subscriptions_get') + def test_sendNotification_with_auth_basic(self, mock_subscriptions_get): + self.requests_mock.register_uri('POST', + "https://localhost/callback", + headers={'Content-Type': 'application/json'}, + status_code=204) + + auth_user_name = 'test_user' + auth_password = 'test_password' + mock_subscriptions_get.return_value = self._create_subscriptions( + {'authType': ['BASIC'], + 'paramsBasic': {'userName': auth_user_name, + 'password': auth_password}}) - m_vnf_lcm_subscriptions = \ - [mock.MagicMock(**fakes.get_vnf_lcm_subscriptions())] - mock_vnf_lcm_subscriptions_get.return_value = \ - m_vnf_lcm_subscriptions notification = { 'vnfInstanceId': 'Test', 'notificationType': 
'VnfIdentifierCreationNotification', @@ -696,25 +752,43 @@ class TestConductor(SqlTestCase): result = self.conductor.send_notification(self.context, notification) self.assertEqual(result, 0) - mock_vnf_lcm_subscriptions_get.assert_called() - mock_post.assert_called() + mock_subscriptions_get.assert_called() - @mock.patch.object(objects.LccnSubscriptionRequest, - 'vnf_lcm_subscriptions_get') - @mock.patch('requests.post') - def test_send_notification_retry_notification( - self, - mock_post, - mock_vnf_lcm_subscriptions_get): + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history( + history, "https://localhost") + self.assertEqual(1, req_count) + self.assert_auth_basic( + history[0], + auth_user_name, + auth_password) - response = mock.Mock() - response.status_code = 400 - mock_post.return_value = response + @mock.patch.object(objects.LccnSubscriptionRequest, + 'vnf_lcm_subscriptions_get') + def test_sendNotification_with_auth_client_credentials( + self, mock_subscriptions_get): + auth.auth_manager = auth._AuthManager() + self.requests_mock.register_uri('POST', + "https://localhost/callback", + headers={'Content-Type': 'application/json'}, + status_code=204) + + auth_user_name = 'test_user' + auth_password = 'test_password' + token_endpoint = 'https://oauth2/tokens' + self.requests_mock.register_uri('GET', + token_endpoint, + json={'access_token': 'test_token', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + mock_subscriptions_get.return_value = self._create_subscriptions( + {'authType': ['OAUTH2_CLIENT_CREDENTIALS'], + 'paramsOauth2ClientCredentials': { + 'clientId': auth_user_name, + 'clientPassword': auth_password, + 'tokenEndpoint': token_endpoint}}) - m_vnf_lcm_subscriptions = \ - [mock.MagicMock(**fakes.get_vnf_lcm_subscriptions())] - mock_vnf_lcm_subscriptions_get.return_value = \ - m_vnf_lcm_subscriptions notification = { 'vnfInstanceId': 'Test', 'notificationType': 
'VnfIdentifierCreationNotification', @@ -723,22 +797,25 @@ class TestConductor(SqlTestCase): result = self.conductor.send_notification(self.context, notification) self.assertEqual(result, 0) - mock_vnf_lcm_subscriptions_get.assert_called() - mock_post.assert_called() - self.assertEqual(mock_post.call_count, 3) + mock_subscriptions_get.assert_called() + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history( + history, "https://localhost", 'https://oauth2') + self.assertEqual(2, req_count) + self.assert_auth_basic(history[0], auth_user_name, auth_password) + self.assert_auth_client_credentials(history[1], "test_token") @mock.patch.object(objects.LccnSubscriptionRequest, - 'vnf_lcm_subscriptions_get') - @mock.patch('requests.post') - def test_send_notification_send_error(self, - mock_post, - mock_vnf_lcm_subscriptions_get): - mock_post.side_effect = \ - requests.exceptions.HTTPError("MockException") - m_vnf_lcm_subscriptions = \ - [mock.MagicMock(**fakes.get_vnf_lcm_subscriptions())] - mock_vnf_lcm_subscriptions_get.return_value = \ - m_vnf_lcm_subscriptions + 'vnf_lcm_subscriptions_get') + def test_sendNotification_retyNotification(self, + mock_subscriptions_get): + self.requests_mock.register_uri('POST', + "https://localhost/callback", + headers={'Content-Type': 'application/json'}, + status_code=400) + + mock_subscriptions_get.return_value = self._create_subscriptions() notification = { 'vnfInstanceId': 'Test', 'notificationType': 'VnfIdentifierCreationNotification', @@ -747,16 +824,22 @@ class TestConductor(SqlTestCase): result = self.conductor.send_notification(self.context, notification) self.assertEqual(result, 0) - mock_vnf_lcm_subscriptions_get.assert_called() - mock_post.assert_called() - self.assertEqual(mock_post.call_count, 1) + mock_subscriptions_get.assert_called() + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history( + history, "https://localhost") + self.assertEqual(3, 
req_count) @mock.patch.object(objects.LccnSubscriptionRequest, 'vnf_lcm_subscriptions_get') - def test_send_notification_internal_server_error(self, - mock_vnf_lcm_subscriptions_get): - mock_vnf_lcm_subscriptions_get.side_effect = Exception( - "MockException") + def test_sendNotification_sendError(self, + mock_subscriptions_get): + self.requests_mock.register_uri('POST', + "https://localhost/callback", + exc=requests.exceptions.HTTPError("MockException")) + + mock_subscriptions_get.return_value = self._create_subscriptions() notification = { 'vnfInstanceId': 'Test', 'notificationType': 'VnfIdentifierCreationNotification', @@ -764,4 +847,25 @@ class TestConductor(SqlTestCase): result = self.conductor.send_notification(self.context, notification) - self.assertEqual(result, -2) + self.assertEqual(result, 0) + mock_subscriptions_get.assert_called() + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history( + history, "https://localhost") + self.assertEqual(1, req_count) + + @mock.patch.object(objects.LccnSubscriptionRequest, + 'vnf_lcm_subscriptions_get') + def test_sendNotification_internalServerError( + self, mock_subscriptions_get): + mock_subscriptions_get.side_effect = Exception("MockException") + notification = { + 'vnfInstanceId': 'Test', + 'notificationType': 'VnfIdentifierCreationNotification', + 'links': {}} + + result = self.conductor.send_notification(self.context, notification) + + self.assertEqual(result, 99) + mock_subscriptions_get.assert_called() diff --git a/tacker/tests/unit/test_auth.py b/tacker/tests/unit/test_auth.py index 4e93e39..1b59e44 100644 --- a/tacker/tests/unit/test_auth.py +++ b/tacker/tests/unit/test_auth.py @@ -13,11 +13,24 @@ # License for the specific language governing permissions and limitations # under the License.
+import ddt +from oslo_config import cfg from oslo_middleware import request_id -import webob - +import requests +from requests_mock.contrib import fixture as requests_mock_fixture from tacker import auth from tacker.tests import base +import tacker.tests.unit.vnfm.test_nfvo_client as nfvo_client + +import threading + +from tacker.tests import uuidsentinel + +from oslo_log import log as logging +from unittest import mock +import webob + +LOG = logging.getLogger(__name__) class TackerKeystoneContextTestCase(base.BaseTestCase): @@ -110,3 +123,564 @@ class TackerKeystoneContextTestCase(base.BaseTestCase): del self.request.headers['X_AUTH_TOKEN'] self.request.get_response(self.middleware) self.assertIsNone(self.context.auth_token) + + +@ddt.ddt +class TestAuthManager(base.BaseTestCase): + + def setUp(self): + super(TestAuthManager, self).setUp() + self.token_endpoint_url = 'https://oauth2/tokens' + self.oauth_url = 'https://oauth2' + self.user_name = 'test_user' + self.password = 'test_password' + auth.auth_manager = auth._AuthManager() + self.requests_mock = self.useFixture(requests_mock_fixture.Fixture()) + + def tearDown(self): + super(TestAuthManager, self).tearDown() + self.addCleanup(mock.patch.stopall) + + def test_init(self): + self.assertEqual(None, cfg.CONF.authentication.auth_type) + self.assertEqual(20, cfg.CONF.authentication.timeout) + self.assertEqual(None, cfg.CONF.authentication.token_endpoint) + self.assertEqual(None, cfg.CONF.authentication.client_id) + self.assertEqual(None, cfg.CONF.authentication.client_password) + self.assertEqual(None, cfg.CONF.authentication.user_name) + self.assertEqual(None, cfg.CONF.authentication.password) + + def test_get_auth_client_oauth2_client_credentials_with_local(self): + cfg.CONF.set_override('auth_type', 'OAUTH2_CLIENT_CREDENTIALS', + group='authentication') + cfg.CONF.set_override('token_endpoint', self.token_endpoint_url, + group='authentication') + cfg.CONF.set_override('client_id', self.user_name, + 
group='authentication') + cfg.CONF.set_override('client_password', self.password, + group='authentication') + + self.requests_mock.register_uri('GET', + self.token_endpoint_url, + json={'access_token': 'test_token3', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + auth.auth_manager = auth._AuthManager() + client = auth.auth_manager.get_auth_client() + + self.assertIsInstance(client, auth._OAuth2Session) + self.assertEqual( + self.user_name, + client.grant.client_id) + self.assertEqual( + self.password, + client.grant.client_password) + self.assertEqual( + self.token_endpoint_url, + client.grant.token_endpoint) + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, self.oauth_url) + self.assertEqual(1, req_count) + + def test_get_auth_client_basic_with_local(self): + cfg.CONF.set_override('auth_type', 'BASIC', + group='authentication') + cfg.CONF.set_override('user_name', self.user_name, + group='authentication') + cfg.CONF.set_override('password', self.password, + group='authentication') + + auth.auth_manager = auth._AuthManager() + client = auth.auth_manager.get_auth_client() + + self.assertIsInstance(client, auth._BasicAuthSession) + self.assertEqual(self.user_name, client.user_name) + self.assertEqual(self.password, client.password) + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, self.oauth_url) + self.assertEqual(0, req_count) + + def test_get_auth_client_noauth_with_local(self): + cfg.CONF.set_override('auth_type', None, + group='authentication') + + client = auth.auth_manager.get_auth_client() + self.assertIsInstance(client, requests.Session) + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, self.oauth_url) + self.assertEqual(0, req_count) + + def test_get_auth_client_oauth2_client_credentials_with_subscription(self): + 
self.requests_mock.register_uri('GET', + self.token_endpoint_url, + json={'access_token': 'test_token', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + params_oauth2_client_credentials = { + 'clientId': self.user_name, + 'clientPassword': self.password, + 'tokenEndpoint': self.token_endpoint_url} + + auth.auth_manager.set_auth_client( + id=uuidsentinel.subscription_id, + auth_type='OAUTH2_CLIENT_CREDENTIALS', + auth_params=params_oauth2_client_credentials) + client = auth.auth_manager.get_auth_client( + id=uuidsentinel.subscription_id) + + self.assertIsInstance(client, auth._OAuth2Session) + self.assertEqual( + self.user_name, + client.grant.client_id) + self.assertEqual( + self.password, + client.grant.client_password) + self.assertEqual( + self.token_endpoint_url, + client.grant.token_endpoint) + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, self.oauth_url) + self.assertEqual(1, req_count) + + def test_get_auth_client_basic_with_subscription(self): + params_basic = { + 'userName': self.user_name, + 'password': self.password} + + auth.auth_manager.set_auth_client( + id=uuidsentinel.subscription_id, + auth_type='BASIC', + auth_params=params_basic) + client = auth.auth_manager.get_auth_client( + id=uuidsentinel.subscription_id) + + self.assertIsInstance(client, auth._BasicAuthSession) + self.assertEqual(self.user_name, client.user_name) + self.assertEqual(self.password, client.password) + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, self.oauth_url) + self.assertEqual(0, req_count) + + def test_set_auth_client_noauth(self): + auth.auth_manager.set_auth_client( + id=uuidsentinel.subscription_id, + auth_type=None, + auth_params={}) + + manages = auth.auth_manager._AuthManager__manages + self.assertNotIn(uuidsentinel.subscription_id, manages) + + def test_set_auth_client_basic(self): + params_basic = { + 
'userName': self.user_name, + 'password': self.password} + + auth.auth_manager.set_auth_client( + id=uuidsentinel.subscription_id, + auth_type='BASIC', + auth_params=params_basic) + + manages = auth.auth_manager._AuthManager__manages + self.assertIn(uuidsentinel.subscription_id, manages) + + client = manages.get(uuidsentinel.subscription_id) + self.assertIsInstance(client, auth._BasicAuthSession) + self.assertEqual(self.user_name, client.user_name) + self.assertEqual(self.password, client.password) + + def test_set_auth_client_oauth2_client_credentials(self): + self.requests_mock.register_uri( + 'GET', self.token_endpoint_url, + json={ + 'access_token': 'test_token', 'token_type': 'bearer'}, + headers={ + 'Content-Type': 'application/json'}, + status_code=200) + + params_oauth2_client_credentials = { + 'clientId': self.user_name, + 'clientPassword': self.password, + 'tokenEndpoint': self.token_endpoint_url} + + auth.auth_manager.set_auth_client( + id=uuidsentinel.subscription_id, + auth_type='OAUTH2_CLIENT_CREDENTIALS', + auth_params=params_oauth2_client_credentials) + + manages = auth.auth_manager._AuthManager__manages + self.assertIn(uuidsentinel.subscription_id, manages) + + client = manages.get(uuidsentinel.subscription_id) + self.assertIsInstance(client, auth._OAuth2Session) + self.assertEqual( + self.user_name, + client.grant.client_id) + self.assertEqual( + self.password, + client.grant.client_password) + self.assertEqual( + self.token_endpoint_url, + client.grant.token_endpoint) + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, self.oauth_url) + self.assertEqual(1, req_count) + + def test_set_auth_client_used_chahe(self): + params_basic = { + 'userName': self.user_name, + 'password': self.password} + + auth.auth_manager.set_auth_client( + id=uuidsentinel.subscription_id, + auth_type='BASIC', + auth_params=params_basic) + + params_oauth2_client_credentials = { + 'clientId': self.user_name, + 
'clientPassword': self.password, + 'tokenEndpoint': self.token_endpoint_url} + + auth.auth_manager.set_auth_client( + id=uuidsentinel.subscription_id, + auth_type='OAUTH2_CLIENT_CREDENTIALS', + auth_params=params_oauth2_client_credentials) + + manages = auth.auth_manager._AuthManager__manages + self.assertIn(uuidsentinel.subscription_id, manages) + + client = manages.get(uuidsentinel.subscription_id) + self.assertIsInstance(client, auth._BasicAuthSession) + self.assertEqual(self.user_name, client.user_name) + self.assertEqual(self.password, client.password) + + +@ddt.ddt +class TestBasicAuthSession(base.BaseTestCase): + + def setUp(self): + super(TestBasicAuthSession, self).setUp() + self.token_endpoint_url = 'https://oauth2/tokens' + self.nfvo_url = 'http://nfvo.co.jp' + self.user_name = 'test_user' + self.password = 'test_password' + self.requests_mock = self.useFixture(requests_mock_fixture.Fixture()) + + def tearDown(self): + super(TestBasicAuthSession, self).tearDown() + self.addCleanup(mock.patch.stopall) + + @ddt.data('GET', 'PUT', 'POST', 'DELETE', 'PATCH') + def test_request(self, http_method): + client = auth._BasicAuthSession( + user_name=self.user_name, + password=self.password) + + self.requests_mock.register_uri(http_method, + self.nfvo_url, + headers={'Content-Type': 'application/json'}, + status_code=200) + + if http_method == 'GET': + response = client.get( + self.nfvo_url, + params={ + 'sample_key': 'sample_value'}) + elif http_method == 'PUT': + response = client.put( + self.nfvo_url, + data={ + 'sample_key': 'sample_value'}) + elif http_method == 'POST': + response = client.post( + self.nfvo_url, + data={ + 'sample_key': 'sample_value'}) + elif http_method == 'DELETE': + response = client.delete( + self.nfvo_url, + params={ + 'sample_key': 'sample_value'}) + elif http_method == 'PATCH': + response = client.patch( + self.nfvo_url, + data={ + 'sample_key': 'sample_value'}) + + self.assertEqual(200, response.status_code) + history = 
self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, self.nfvo_url) + self.assertEqual(1, req_count) + + +@ddt.ddt +class TestOAuth2Session(base.BaseTestCase): + + class MockThread(threading.Timer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def start(self): + super().start() + super().join(60) + + def setUp(self): + super(TestOAuth2Session, self).setUp() + self.token_endpoint_url = 'https://oauth2/tokens' + self.oauth_url = 'https://oauth2' + self.user_name = 'test_user' + self.password = 'test_password' + self.requests_mock = self.useFixture(requests_mock_fixture.Fixture()) + + def tearDown(self): + super(TestOAuth2Session, self).tearDown() + self.addCleanup(mock.patch.stopall) + + def test_apply_access_token_info(self): + res_mock = { + 'json': { + 'access_token': 'test_token', + 'token_type': 'bearer', + 'expires_in': '1'}, + 'headers': {'Content-Type': 'application/json'}, + 'status_code': 200} + res_mock2 = { + 'json': { + 'access_token': 'test_token2', + 'token_type': 'bearer'}, + 'headers': {'Content-Type': 'application/json'}, + 'status_code': 200} + + self.requests_mock.register_uri( + 'GET', + self.token_endpoint_url, [res_mock, res_mock2]) + + grant = auth._ClientCredentialsGrant( + client_id=self.user_name, + client_password=self.password, + token_endpoint=self.token_endpoint_url) + + with mock.patch("threading.Timer", side_effect=self.MockThread) as m: + client = auth._OAuth2Session(grant) + client.apply_access_token_info() + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, + self.oauth_url) + self.assertEqual(2, req_count) + self.assertEqual(1, m.call_count) + + def test_apply_access_token_info_fail_error_response(self): + error_description = """ + Either your username or password is incorrect + or you are not an active user. + Please try again or contact your administrator. 
+ """ + self.requests_mock.register_uri( + 'GET', + self.token_endpoint_url, + headers={ + 'Content-Type': 'application/json;charset=UTF-8', + 'Cache-Control': 'no-store', + 'Pragma': 'no-store', + 'WWW-Authenticate': 'Basic realm="example"'}, + json={ + 'error': 'invalid_client', + 'error_description': error_description}, + status_code=401) + + grant = auth._ClientCredentialsGrant( + client_id=self.user_name, + client_password=self.password, + token_endpoint=self.token_endpoint_url) + + with mock.patch("threading.Timer", side_effect=self.MockThread) as m: + try: + client = auth._OAuth2Session(grant) + client.apply_access_token_info() + except requests.exceptions.RequestException as e: + self.assertEqual(401, e.response.status_code) + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, + self.oauth_url) + self.assertEqual(1, req_count) + self.assertEqual(0, m.call_count) + + def test_apply_access_token_info_fail_timeout(self): + self.requests_mock.register_uri( + 'GET', + self.token_endpoint_url, + exc=requests.exceptions.ConnectTimeout) + + grant = auth._ClientCredentialsGrant( + client_id=self.user_name, + client_password=self.password, + token_endpoint=self.token_endpoint_url) + + with mock.patch("threading.Timer", side_effect=self.MockThread) as m: + try: + client = auth._OAuth2Session(grant) + client.apply_access_token_info() + except requests.exceptions.RequestException as e: + self.assertIsNone(e.response) + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, + self.oauth_url) + self.assertEqual(1, req_count) + self.assertEqual(0, m.call_count) + + def test_schedule_refrash_token_expaire(self): + self.requests_mock.register_uri( + 'GET', + self.token_endpoint_url, + headers={'Content-Type': 'application/json'}, + json={ + 'access_token': 'test_token', + 'token_type': 'bearer'}, + status_code=200) + + grant = auth._ClientCredentialsGrant( + 
client_id=self.user_name, + client_password=self.password, + token_endpoint=self.token_endpoint_url) + + with mock.patch("threading.Timer", side_effect=self.MockThread) as m: + client = auth._OAuth2Session(grant) + client._OAuth2Session__access_token_info.update({ + 'access_token': 'test_token', + 'token_type': 'bearer', + 'expires_in': '1'}) + client.schedule_refrash_token() + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, + self.oauth_url) + self.assertEqual(1, req_count) + self.assertEqual(1, m.call_count) + + def test_schedule_refrash_token_non_expaire(self): + grant = auth._ClientCredentialsGrant( + client_id=self.user_name, + client_password=self.password, + token_endpoint=self.token_endpoint_url) + + with mock.patch("threading.Timer", side_effect=self.MockThread) as m: + client = auth._OAuth2Session(grant) + client._OAuth2Session__access_token_info.update({ + 'access_token': 'test_token', + 'token_type': 'bearer'}) + client.schedule_refrash_token() + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, + self.oauth_url) + self.assertEqual(0, req_count) + self.assertEqual(0, m.call_count) + + @ddt.data(None, "") + def test_schedule_refrash_token_invalid_value(self, invalid_value): + grant = auth._ClientCredentialsGrant( + client_id=self.user_name, + client_password=self.password, + token_endpoint=self.token_endpoint_url) + + with mock.patch("threading.Timer", side_effect=self.MockThread) as m: + client = auth._OAuth2Session(grant) + client._OAuth2Session__access_token_info.update({ + 'access_token': 'test_token', + 'token_type': 'bearer', + 'expires_in': invalid_value}) + client.schedule_refrash_token() + + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, + self.oauth_url) + self.assertEqual(0, req_count) + self.assertEqual(0, m.call_count) + + @ddt.data('GET', 'PUT', 'POST', 'DELETE', 'PATCH') + def 
test_request_client_credentials(self, http_method): + self.requests_mock.register_uri('GET', + self.token_endpoint_url, + json={'access_token': 'test_token3', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + grant = auth._ClientCredentialsGrant( + client_id=self.user_name, + client_password=self.password, + token_endpoint=self.token_endpoint_url) + client = auth._OAuth2Session(grant) + client.apply_access_token_info() + + self.requests_mock.register_uri(http_method, + self.oauth_url, + headers={'Content-Type': 'application/json'}, + status_code=200) + + if http_method == 'GET': + response = client.get( + self.oauth_url, + params={ + 'sample_key': 'sample_value'}) + elif http_method == 'PUT': + response = client.put( + self.oauth_url, + data={ + 'sample_key': 'sample_value'}) + elif http_method == 'POST': + response = client.post( + self.oauth_url, + data={ + 'sample_key': 'sample_value'}) + elif http_method == 'DELETE': + response = client.delete( + self.oauth_url, + params={ + 'sample_key': 'sample_value'}) + elif http_method == 'PATCH': + response = client.patch( + self.oauth_url, + data={ + 'sample_key': 'sample_value'}) + + self.assertEqual(200, response.status_code) + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history(history, self.oauth_url) + self.assertEqual(2, req_count) + + def test_request_client_credentials_auth_error(self): + self.requests_mock.register_uri('GET', + self.token_endpoint_url, + json={'access_token': 'test_token3', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + self.requests_mock.register_uri('GET', + "https://nfvo.co.jp", + text="error.", + status_code=401) + + grant = auth._ClientCredentialsGrant( + client_id=self.user_name, + client_password=self.password, + token_endpoint=self.token_endpoint_url) + client = auth._OAuth2Session(grant) + client.apply_access_token_info() + + response = 
client.get('https://nfvo.co.jp') + + self.assertEqual(401, response.status_code) + history = self.requests_mock.request_history + req_count = nfvo_client._count_mock_history( + history, self.oauth_url, 'https://nfvo.co.jp') + self.assertEqual(3, req_count) diff --git a/tacker/tests/unit/vnfm/test_nfvo_client.py b/tacker/tests/unit/vnfm/test_nfvo_client.py index b6e2030..dd8a438 100644 --- a/tacker/tests/unit/vnfm/test_nfvo_client.py +++ b/tacker/tests/unit/vnfm/test_nfvo_client.py @@ -10,11 +10,13 @@ # License for the specific language governing permissions and limitations # under the License. +import base64 import datetime import hashlib import io import json import os +import requests import shutil import tempfile import uuid @@ -22,9 +24,10 @@ import zipfile import ddt from oslo_config import cfg -import requests +from requests_mock.contrib import fixture as requests_mock_fixture +from tacker import auth +from tacker.tests import base -from tacker.tests.unit import base from tacker.tests.unit.vnfm.infra_drivers.openstack.fixture_data import client from tacker.tests.unit.vnfpkgm import fakes from tacker.tests import utils @@ -44,22 +47,47 @@ def _count_mock_history(history, *url): @ddt.ddt -class TestVnfPackageRequest(base.FixturedTestCase): +class TestVnfPackageRequest(base.BaseTestCase): client_fixture_class = client.ClientFixture sdk_connection_fixure_class = client.SdkConnectionFixture def setUp(self): super(TestVnfPackageRequest, self).setUp() + self.requests_mock = self.useFixture(requests_mock_fixture.Fixture()) self.url = "http://nfvo.co.jp/vnfpkgm/v1/vnf_packages" self.nfvo_url = "http://nfvo.co.jp" self.test_package_dir = 'tacker/tests/unit/vnfm/' self.headers = {'Content-Type': 'application/json'} + self.token_endpoint = 'https://oauth2/tokens' + self.oauth_url = 'https://oauth2' + self.auth_user_name = 'test_user' + self.auth_password = 'test_password' + + cfg.CONF.set_override('auth_type', None, + group='authentication') + auth.auth_manager = 
auth._AuthManager() + nfvo_client.VnfPackageRequest._connector = nfvo_client._Connect( + 2, 1, 20) + def tearDown(self): super(TestVnfPackageRequest, self).tearDown() self.addCleanup(mock.patch.stopall) + def assert_auth_basic(self, acutual_request): + actual_auth = acutual_request._request.headers.get("Authorization") + expected_auth = base64.b64encode( + '{}:{}'.format( + self.auth_user_name, + self.auth_password).encode('utf-8')).decode() + self.assertEqual("Basic " + expected_auth, actual_auth) + + def assert_auth_client_credentials(self, acutual_request, expected_token): + actual_auth = acutual_request._request.headers.get( + "Authorization") + self.assertEqual("Bearer " + expected_token, actual_auth) + def assert_zipfile( self, actual_zip, @@ -243,6 +271,133 @@ class TestVnfPackageRequest(base.FixturedTestCase): self.assertEqual(expected_connect_cnt, req_count) + def test_download_vnf_packages_with_auth_basic(self): + cfg.CONF.set_override("base_url", self.url, + group='connect_vnf_packages') + + cfg.CONF.set_override('auth_type', 'BASIC', + group='authentication') + cfg.CONF.set_override('user_name', self.auth_user_name, + group='authentication') + cfg.CONF.set_override('password', self.auth_password, + group='authentication') + auth.auth_manager = auth._AuthManager() + + expected_connect_cnt = \ + self._download_vnf_packages_all_pipeline_with_assert() + history = self.requests_mock.request_history + req_count = _count_mock_history(history, self.nfvo_url) + self.assertEqual(expected_connect_cnt, req_count) + for h in history: + self.assert_auth_basic(h) + + def test_download_vnf_packages_with_auth_client_credentials(self): + cfg.CONF.set_override("base_url", self.url, + group='connect_vnf_packages') + + cfg.CONF.set_override('auth_type', 'OAUTH2_CLIENT_CREDENTIALS', + group='authentication') + cfg.CONF.set_override('token_endpoint', self.token_endpoint, + group='authentication') + cfg.CONF.set_override('client_id', self.auth_user_name, + 
group='authentication') + cfg.CONF.set_override('client_password', self.auth_password, + group='authentication') + + expected_connect_cnt = 1 + self.requests_mock.register_uri('GET', + self.token_endpoint, + json={'access_token': 'test_token', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + auth.auth_manager = auth._AuthManager() + + expected_connect_cnt += \ + self._download_vnf_packages_all_pipeline_with_assert() + history = self.requests_mock.request_history + req_count = _count_mock_history(history, self.nfvo_url, self.oauth_url) + self.assertEqual(expected_connect_cnt, req_count) + self.assert_auth_basic(history[0]) + for h in history[1:]: + self.assert_auth_client_credentials(h, "test_token") + + def _download_vnf_packages_all_pipeline_with_assert(self): + fetch_base_url = os.path.join(self.url, uuidsentinel.vnf_pkg_id) + expected_connect_cnt = 0 + pipelines = [] + + content = 'vnfpkgm1' + expected_connect_cnt += 1 + pipelines.append('package_content') + path = self._make_zip_file_from_sample(content) + with open(path, 'rb') as test_package_content_zip_obj: + expected_package_content_zip = zipfile.ZipFile( + io.BytesIO(test_package_content_zip_obj.read())) + test_package_content_zip_obj.seek(0) + self.requests_mock.register_uri( + 'GET', + os.path.join( + fetch_base_url, + 'package_content'), + content=test_package_content_zip_obj.read(), + headers={ + 'Content-Type': 'application/zip'}, + status_code=200) + + vnfd = 'vnfpkgm2' + expected_connect_cnt += 1 + pipelines.append('vnfd') + path = self._make_zip_file_from_sample(vnfd, read_vnfd_only=True) + with open(path, 'rb') as test_vnfd_zip_obj: + expected_vnfd_zip = zipfile.ZipFile( + io.BytesIO(test_vnfd_zip_obj.read())) + test_vnfd_zip_obj.seek(0) + self.requests_mock.register_uri( + 'GET', + os.path.join( + fetch_base_url, + 'vnfd'), + content=test_vnfd_zip_obj.read(), + headers={ + 'Content-Type': 'application/zip'}, + status_code=200) + + artifacts = 
["vnfd_lcm_user_data.yaml"] + pipelines.append('artifacts') + artifacts = [os.path.join("tacker/tests/etc/samples", p) + for p in artifacts] + for artifact_path in artifacts: + expected_connect_cnt += 1 + with open(artifact_path, 'rb') as artifact_path_obj: + self.requests_mock.register_uri( + 'GET', + os.path.join( + fetch_base_url, + 'artifacts', + artifact_path), + headers={ + 'Content-Type': 'application/octet-stream'}, + status_code=200, + content=artifact_path_obj.read()) + + cfg.CONF.set_default( + name='pipeline', + group='connect_vnf_packages', + default=pipelines) + + res = nfvo_client.VnfPackageRequest.download_vnf_packages( + uuidsentinel.vnf_pkg_id, artifacts) + self.assertIsInstance(res, io.BytesIO) + + actual_zip = zipfile.ZipFile(res) + self.assert_zipfile( + actual_zip, [ + expected_package_content_zip, + expected_vnfd_zip], artifacts) + + return expected_connect_cnt + def test_download_vnf_packages_content_disposition(self): cfg.CONF.set_override("base_url", self.url, group='connect_vnf_packages') @@ -313,6 +468,9 @@ class TestVnfPackageRequest(base.FixturedTestCase): self.assertEqual(1, req_count) def test_download_vnf_packages_with_retry_raise_not_found(self): + # TODO(Edagawa) fix duplicated lines + # (cfg.CONF.set_override and cfg.CONF.set_default) with below + # two functions. 
cfg.CONF.set_override("base_url", self.url, group='connect_vnf_packages') cfg.CONF.set_default( @@ -436,6 +594,74 @@ class TestVnfPackageRequest(base.FixturedTestCase): req_count = _count_mock_history(history, self.nfvo_url) self.assertEqual(1, req_count) + def test_index_with_auth_basic(self): + cfg.CONF.set_override("base_url", self.url, + group='connect_vnf_packages') + + cfg.CONF.set_override('auth_type', 'BASIC', + group='authentication') + cfg.CONF.set_override('user_name', self.auth_user_name, + group='authentication') + cfg.CONF.set_override('password', self.auth_password, + group='authentication') + auth.auth_manager = auth._AuthManager() + + response_body = self.json_serial_date_to_dict( + [fakes.VNFPACKAGE_RESPONSE, fakes.VNFPACKAGE_RESPONSE]) + self.requests_mock.register_uri( + 'GET', self.url, headers=self.headers, json=response_body) + + res = nfvo_client.VnfPackageRequest.index() + self.assertEqual(200, res.status_code) + self.assertIsInstance(res.json(), list) + self.assertEqual(response_body, res.json()) + self.assertEqual(2, len(res.json())) + self.assertEqual(response_body, json.loads(res.text)) + + history = self.requests_mock.request_history + req_count = _count_mock_history(history, self.nfvo_url) + self.assertEqual(1, req_count) + self.assert_auth_basic(history[0]) + + def test_index_with_auth_client_credentials(self): + cfg.CONF.set_override("base_url", self.url, + group='connect_vnf_packages') + + cfg.CONF.set_override('auth_type', 'OAUTH2_CLIENT_CREDENTIALS', + group='authentication') + cfg.CONF.set_override('token_endpoint', self.token_endpoint, + group='authentication') + cfg.CONF.set_override('client_id', self.auth_user_name, + group='authentication') + cfg.CONF.set_override('client_password', self.auth_password, + group='authentication') + + self.requests_mock.register_uri('GET', + self.token_endpoint, + json={'access_token': 'test_token', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + 
auth.auth_manager = auth._AuthManager() + + response_body = self.json_serial_date_to_dict( + [fakes.VNFPACKAGE_RESPONSE, fakes.VNFPACKAGE_RESPONSE]) + self.requests_mock.register_uri( + 'GET', self.url, headers=self.headers, json=response_body) + + res = nfvo_client.VnfPackageRequest.index() + self.assertEqual(200, res.status_code) + self.assertIsInstance(res.json(), list) + self.assertEqual(response_body, res.json()) + self.assertEqual(2, len(res.json())) + self.assertEqual(response_body, json.loads(res.text)) + + history = self.requests_mock.request_history + req_count = _count_mock_history(history, self.nfvo_url, self.oauth_url) + self.assertEqual(2, req_count) + self.assert_auth_basic(history[0]) + self.assert_auth_client_credentials(history[1], "test_token") + def test_index_raise_not_found(self): cfg.CONF.set_override("base_url", self.url, group='connect_vnf_packages') @@ -484,6 +710,82 @@ class TestVnfPackageRequest(base.FixturedTestCase): req_count = _count_mock_history(history, self.nfvo_url) self.assertEqual(1, req_count) + def test_show_with_auth_basic(self): + cfg.CONF.set_override("base_url", self.url, + group='connect_vnf_packages') + + cfg.CONF.set_override('auth_type', 'BASIC', + group='authentication') + cfg.CONF.set_override('user_name', self.auth_user_name, + group='authentication') + cfg.CONF.set_override('password', self.auth_password, + group='authentication') + auth.auth_manager = auth._AuthManager() + + response_body = self.json_serial_date_to_dict( + fakes.VNFPACKAGE_RESPONSE) + self.requests_mock.register_uri( + 'GET', + os.path.join( + self.url, + uuidsentinel.vnf_pkg_id), + headers=self.headers, + json=response_body) + + res = nfvo_client.VnfPackageRequest.show(uuidsentinel.vnf_pkg_id) + self.assertEqual(200, res.status_code) + self.assertIsInstance(res.json(), dict) + self.assertEqual(response_body, res.json()) + self.assertEqual(response_body, json.loads(res.text)) + + history = self.requests_mock.request_history + req_count = 
_count_mock_history(history, self.nfvo_url) + self.assertEqual(1, req_count) + self.assert_auth_basic(history[0]) + + def test_show_with_auth_client_credentials(self): + cfg.CONF.set_override("base_url", self.url, + group='connect_vnf_packages') + + cfg.CONF.set_override('auth_type', 'OAUTH2_CLIENT_CREDENTIALS', + group='authentication') + cfg.CONF.set_override('token_endpoint', self.token_endpoint, + group='authentication') + cfg.CONF.set_override('client_id', self.auth_user_name, + group='authentication') + cfg.CONF.set_override('client_password', self.auth_password, + group='authentication') + + self.requests_mock.register_uri('GET', + self.token_endpoint, + json={'access_token': 'test_token', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + auth.auth_manager = auth._AuthManager() + + response_body = self.json_serial_date_to_dict( + fakes.VNFPACKAGE_RESPONSE) + self.requests_mock.register_uri( + 'GET', + os.path.join( + self.url, + uuidsentinel.vnf_pkg_id), + headers=self.headers, + json=response_body) + + res = nfvo_client.VnfPackageRequest.show(uuidsentinel.vnf_pkg_id) + self.assertEqual(200, res.status_code) + self.assertIsInstance(res.json(), dict) + self.assertEqual(response_body, res.json()) + self.assertEqual(response_body, json.loads(res.text)) + + history = self.requests_mock.request_history + req_count = _count_mock_history(history, self.nfvo_url, self.oauth_url) + self.assertEqual(2, req_count) + self.assert_auth_basic(history[0]) + self.assert_auth_client_credentials(history[1], "test_token") + def test_show_raise_not_found(self): cfg.CONF.set_override("base_url", self.url, group='connect_vnf_packages') @@ -516,20 +818,43 @@ class TestVnfPackageRequest(base.FixturedTestCase): @ddt.ddt -class TestGrantRequest(base.FixturedTestCase): - client_fixture_class = client.ClientFixture - sdk_connection_fixure_class = client.SdkConnectionFixture +class TestGrantRequest(base.BaseTestCase): def setUp(self): 
super(TestGrantRequest, self).setUp() + self.requests_mock = self.useFixture(requests_mock_fixture.Fixture()) self.url = "http://nfvo.co.jp/grant/v1/grants" self.nfvo_url = 'http://nfvo.co.jp' self.headers = {'content-type': 'application/json'} + self.token_endpoint = 'https://oauth2/tokens' + self.nfvo_url = 'http://nfvo.co.jp' + self.oauth_url = 'https://oauth2' + self.auth_user_name = 'test_user' + self.auth_password = 'test_password' + + cfg.CONF.set_override('auth_type', None, + group='authentication') + auth.auth_manager = auth._AuthManager() + nfvo_client.GrantRequest._connector = nfvo_client._Connect(2, 1, 20) + def tearDown(self): super(TestGrantRequest, self).tearDown() self.addCleanup(mock.patch.stopall) + def assert_auth_basic(self, acutual_request): + actual_auth = acutual_request._request.headers.get("Authorization") + expected_auth = base64.b64encode( + '{}:{}'.format( + self.auth_user_name, + self.auth_password).encode('utf-8')).decode() + self.assertEqual("Basic " + expected_auth, actual_auth) + + def assert_auth_client_credentials(self, acutual_request, expected_token): + actual_auth = acutual_request._request.headers.get( + "Authorization") + self.assertEqual("Bearer " + expected_token, actual_auth) + def create_request_body(self): return { "vnfInstanceId": uuidsentinel.vnf_instance_id, @@ -630,3 +955,71 @@ class TestGrantRequest(base.FixturedTestCase): self.assertRaises(nfvo_client.UndefinedExternalSettingException, nfvo_client.GrantRequest.grants, data={"test": "value1"}) + + def test_grants_with_auth_basic(self): + cfg.CONF.set_override("base_url", self.url, group='connect_grant') + + cfg.CONF.set_override('auth_type', 'BASIC', + group='authentication') + cfg.CONF.set_override('user_name', self.auth_user_name, + group='authentication') + cfg.CONF.set_override('password', self.auth_password, + group='authentication') + auth.auth_manager = auth._AuthManager() + + response_body = self.fake_response_body() + self.requests_mock.register_uri( + 
'POST', + self.url, + json=response_body, + headers=self.headers, + status_code=201) + + request_body = self.create_request_body() + res = nfvo_client.GrantRequest.grants(data=request_body) + self.assertEqual(response_body, json.loads(res.text)) + self.assertEqual(response_body, res.json()) + + history = self.requests_mock.request_history + req_count = _count_mock_history(history, self.nfvo_url) + self.assertEqual(1, req_count) + self.assert_auth_basic(history[0]) + + def test_grants_with_auth_client_credentials(self): + cfg.CONF.set_override("base_url", self.url, group='connect_grant') + + cfg.CONF.set_override('auth_type', 'OAUTH2_CLIENT_CREDENTIALS', + group='authentication') + cfg.CONF.set_override('token_endpoint', self.token_endpoint, + group='authentication') + cfg.CONF.set_override('client_id', self.auth_user_name, + group='authentication') + cfg.CONF.set_override('client_password', self.auth_password, + group='authentication') + + self.requests_mock.register_uri('GET', + self.token_endpoint, + json={'access_token': 'test_token', 'token_type': 'bearer'}, + headers={'Content-Type': 'application/json'}, + status_code=200) + + auth.auth_manager = auth._AuthManager() + + response_body = self.fake_response_body() + self.requests_mock.register_uri( + 'POST', + self.url, + json=response_body, + headers=self.headers, + status_code=201) + + request_body = self.create_request_body() + res = nfvo_client.GrantRequest.grants(data=request_body) + self.assertEqual(response_body, json.loads(res.text)) + self.assertEqual(response_body, res.json()) + + history = self.requests_mock.request_history + req_count = _count_mock_history(history, self.nfvo_url, self.oauth_url) + self.assertEqual(2, req_count) + self.assert_auth_basic(history[0]) + self.assert_auth_client_credentials(history[1], "test_token") diff --git a/tacker/vnfm/nfvo_client.py b/tacker/vnfm/nfvo_client.py index 65b82b9..f3e3b2c 100644 --- a/tacker/vnfm/nfvo_client.py +++ b/tacker/vnfm/nfvo_client.py @@ -13,6 
+13,7 @@ import io import os import requests +from tacker import auth import time import zipfile @@ -46,7 +47,7 @@ class _Connect: def request(self, *args, **kwargs): return self.__request( - requests.Session().request, + auth.auth_manager.get_auth_client().request, *args, timeout=self.timeout, **kwargs)