Project

General

Profile

Download (8.25 KB) Statistics
| Branch: | Tag: | Revision:

oidc / ckanext / ozwillo_pyoidc / oidc.py @ c8204b73

1
from oic.utils.http_util import Redirect
2
from oic.exception import MissingAttribute
3
from oic import oic
4
from oic.oauth2 import rndstr, ErrorResponse
5
from oic.oic import ProviderConfigurationResponse, AuthorizationResponse
6
from oic.oic import RegistrationResponse
7
from oic.oic import AuthorizationRequest
8
from oic.utils.authn.client import CLIENT_AUTHN_METHOD
9

    
10
__author__ = 'roland'
11

    
12
import logging
13

    
14
logger = logging.getLogger(__name__)
15

    
16

    
17
class OIDCError(Exception):
18
    pass
19

    
20

    
21
class Client(oic.Client):
22
    def __init__(self, client_id=None, ca_certs=None,
23
                 client_prefs=None, client_authn_method=None, keyjar=None,
24
                 verify_ssl=True, behaviour=None):
25
        oic.Client.__init__(self, client_id, ca_certs, client_prefs,
26
                            client_authn_method, keyjar, verify_ssl)
27
        if behaviour:
28
            self.behaviour = behaviour
29

    
30
    def create_authn_request(self, session, acr_value=None):
31
        session["state"] = rndstr()
32
        session["nonce"] = rndstr()
33
        request_args = {
34
            "response_type": self.behaviour["response_type"],
35
            "scope": self.behaviour["scope"],
36
            "state": session["state"],
37
            # "nonce": session["nonce"],
38
            "redirect_uri": self.registration_response["redirect_uris"][0]
39
        }
40

    
41
        if acr_value is not None:
42
            request_args["acr_values"] = acr_value
43

    
44
        cis = self.construct_AuthorizationRequest(request_args=request_args)
45
        logger.debug("request: %s" % cis)
46

    
47
        url, body, ht_args, cis = self.uri_and_body(AuthorizationRequest, cis,
48
                                                    method="GET",
49
                                                    request_args=request_args)
50

    
51
        logger.debug("body: %s" % body)
52
        logger.info("URL: %s" % url)
53
        logger.debug("ht_args: %s" % ht_args)
54

    
55
        return str(url), ht_args
56

    
57
    def callback(self, response):
58
        """
59
        This is the method that should be called when an AuthN response has been
60
        received from the OP.
61

    
62
        :param response: The URL returned by the OP
63
        :return:
64
        """
65
        authresp = self.parse_response(AuthorizationResponse, response,
66
                                       sformat="dict", keyjar=self.keyjar)
67

    
68
        if isinstance(authresp, ErrorResponse):
69
            return OIDCError("Access denied")
70

    
71
        try:
72
            self.id_token[authresp["state"]] = authresp["id_token"]
73
        except KeyError:
74
            pass
75

    
76
        if self.behaviour["response_type"] == "code":
77
            # get the access token
78
            try:
79
                args = {
80
                    "grant_type": "authorization_code",
81
                    "code": authresp["code"],
82
                    "redirect_uri": self.registration_response[
83
                        "redirect_uris"][0],
84
                    "client_id": self.client_id,
85
                    "client_secret": self.client_secret
86
                }
87

    
88
                atresp = self.do_access_token_request(
89
                    scope="openid", state=authresp["state"], request_args=args,
90
                    authn_method=self.registration_response["token_endpoint_auth_method"])
91
            except Exception as err:
92
                logger.error("%s" % err)
93
                raise
94

    
95
            if isinstance(atresp, ErrorResponse):
96
                raise OIDCError("Invalid response %s." % atresp["error"])
97

    
98
        inforesp = self.do_user_info_request(state=authresp["state"])
99

    
100
        if isinstance(inforesp, ErrorResponse):
101
            raise OIDCError("Invalid response %s." % inforesp["error"])
102

    
103
        userinfo = inforesp.to_dict()
104

    
105
        logger.debug("UserInfo: %s" % inforesp)
106

    
107
        return userinfo
108

    
109

    
110
class OIDCClients(object):
111
    def __init__(self, config):
112
        """
113

    
114
        :param config: Imported configuration module
115
        :return:
116
        """
117
        self.client = {}
118
        self.client_cls = Client
119
        self.config = config
120

    
121
        for key, val in config.CLIENTS.items():
122
            if key == "":
123
                continue
124
            else:
125
                self.client[key] = self.create_client(**val)
126

    
127
    def create_client(self, userid="", **kwargs):
128
        """
129
        Do an instantiation of a client instance
130

    
131
        :param userid: An identifier of the user
132
        :param: Keyword arguments
133
            Keys are ["srv_discovery_url", "client_info", "client_registration",
134
            "provider_info"]
135
        :return: client instance
136
        """
137

    
138
        _key_set = set(kwargs.keys())
139
        args = {}
140
        for param in ["verify_ssl"]:
141
            try:
142
                args[param] = kwargs[param]
143
            except KeyError:
144
                pass
145
            else:
146
                _key_set.discard(param)
147

    
148
        client = self.client_cls(client_authn_method=CLIENT_AUTHN_METHOD,
149
                                 behaviour=kwargs["behaviour"], verify_ssl=self.config.VERIFY_SSL, **args)
150

    
151
        # The behaviour parameter is not significant for the election process
152
        _key_set.discard("behaviour")
153
        for param in ["allow"]:
154
            try:
155
                setattr(client, param, kwargs[param])
156
            except KeyError:
157
                pass
158
            else:
159
                _key_set.discard(param)
160

    
161
        if _key_set == set(["client_info"]):  # Everything dynamic
162
            # There has to be a userid
163
            if not userid:
164
                raise MissingAttribute("Missing userid specification")
165

    
166
            # Find the service that provides information about the OP
167
            issuer = client.wf.discovery_query(userid)
168
            # Gather OP information
169
            _ = client.provider_config(issuer)
170
            # register the client
171
            _ = client.register(client.provider_info["registration_endpoint"],
172
                                **kwargs["client_info"])
173
        elif _key_set == set(["client_info", "srv_discovery_url"]):
174
            # Ship the webfinger part
175
            # Gather OP information
176
            _ = client.provider_config(kwargs["srv_discovery_url"])
177
            # register the client
178
            _ = client.register(client.provider_info["registration_endpoint"],
179
                                **kwargs["client_info"])
180
        elif _key_set == set(["provider_info", "client_info"]):
181
            client.handle_provider_config(
182
                ProviderConfigurationResponse(**kwargs["provider_info"]),
183
                kwargs["provider_info"]["issuer"])
184
            _ = client.register(client.provider_info["registration_endpoint"],
185
                                **kwargs["client_info"])
186
        elif _key_set == set(["provider_info", "client_registration"]):
187
            client.handle_provider_config(
188
                ProviderConfigurationResponse(**kwargs["provider_info"]),
189
                kwargs["provider_info"]["issuer"])
190
            client.store_registration_info(RegistrationResponse(
191
                **kwargs["client_registration"]))
192
        elif _key_set == set(["srv_discovery_url", "client_registration"]):
193
            _ = client.provider_config(kwargs["srv_discovery_url"])
194
            client.store_registration_info(RegistrationResponse(
195
                **kwargs["client_registration"]))
196
        else:
197
            raise Exception("Configuration error ?")
198

    
199
        return client
200

    
201
    def dynamic_client(self, userid):
202
        client = self.client_cls(client_authn_method=CLIENT_AUTHN_METHOD,
203
                                 verify_ssl=self.config.VERIFY_SSL)
204

    
205
        issuer = client.wf.discovery_query(userid)
206
        if issuer in self.client:
207
            return self.client[issuer]
208
        else:
209
            # Gather OP information
210
            _pcr = client.provider_config(issuer)
211
            # register the client
212
            _ = client.register(_pcr["registration_endpoint"],
213
                                **self.config.CLIENTS[""]["client_info"])
214
            try:
215
                client.behaviour.update(**self.config.CLIENTS[""]["behaviour"])
216
            except KeyError:
217
                pass
218

    
219
            self.client[issuer] = client
220
            return client
221

    
222
    def __getitem__(self, item):
223
        """
224
        Given a service or user identifier return a suitable client
225
        :param item:
226
        :return:
227
        """
228
        try:
229
            return self.client[item]
230
        except KeyError:
231
            return self.dynamic_client(item)
232

    
233
    def keys(self):
234
        return self.client.keys()
(3-3/4)