Project

General

Profile

Download (6.75 KB) Statistics
| Branch: | Tag: | Revision:

oidc / ckanext / ozwillo_pyoidc / oidc.py @ 2bfcb919

1
from oic.exception import MissingAttribute
2
from oic import oic
3
from oic.oauth2 import rndstr, ErrorResponse
4
from oic.oic import ProviderConfigurationResponse, AuthorizationResponse
5
from oic.oic import RegistrationResponse
6
from oic.oic import AuthorizationRequest
7
from oic.utils.authn.client import CLIENT_AUTHN_METHOD
8

    
9
import logging
10

    
11
logger = logging.getLogger(__name__)
12

    
13
import conf
14

    
15
class OIDCError(Exception):
16
    pass
17

    
18

    
19
class Client(oic.Client):
20
    def __init__(self, client_id=None, client_secret=None, ca_certs=None,
21
                 client_prefs=None, client_authn_method=None, keyjar=None,
22
                 verify_ssl=True, behaviour=None):
23
        oic.Client.__init__(self, client_id, client_secret, ca_certs,
24
                            client_prefs, client_authn_method,
25
                            keyjar, verify_ssl)
26
        if behaviour:
27
            self.behaviour = behaviour
28

    
29
    def create_authn_request(self, acr_value=None):
30
        self.state = rndstr()
31
        nonce = rndstr()
32
        request_args = {
33
            "response_type": self.behaviour["response_type"],
34
            "scope": self.behaviour["scope"],
35
            "state": self.state,
36
            "nonce": nonce,
37
            "redirect_uri": self.registration_response["redirect_uris"][0]
38
        }
39

    
40
        if acr_value is not None:
41
            request_args["acr_values"] = acr_value
42

    
43
        cis = self.construct_AuthorizationRequest(request_args=request_args)
44
        logger.debug("request: %s" % cis)
45

    
46
        url, body, ht_args, cis = self.uri_and_body(AuthorizationRequest, cis,
47
                                                    method="GET",
48
                                                    request_args=request_args)
49

    
50
        logger.debug("body: %s" % body)
51
        logger.info("URL: %s" % url)
52
        logger.debug("ht_args: %s" % ht_args)
53

    
54
        return str(url), ht_args
55

    
56
    def callback(self, response):
57
        """
58
        This is the method that should be called when an AuthN response has been
59
        received from the OP.
60

    
61
        :param response: The URL returned by the OP
62
        :return:
63
        """
64
        authresp = self.parse_response(AuthorizationResponse, response,
65
                                       sformat="dict", keyjar=self.keyjar)
66

    
67
        if self.state != authresp['state']:
68
            raise OIDCError("Invalid state %s." % authresp["state"])
69

    
70
        if isinstance(authresp, ErrorResponse):
71
            return OIDCError("Access denied")
72

    
73
        try:
74
            self.id_token[authresp["state"]] = authresp["id_token"]
75
        except KeyError:
76
            pass
77

    
78
        if self.behaviour["response_type"] == "code":
79
            # get the access token
80
            try:
81
                args = {
82
                    "grant_type": "authorization_code",
83
                    "code": authresp["code"],
84
                    "redirect_uri": self.registration_response[
85
                        "redirect_uris"][0],
86
                    "client_id": self.client_id,
87
                    "client_secret": self.client_secret
88
                }
89

    
90
                atresp = self.do_access_token_request(
91
                    scope="openid", state=authresp["state"], request_args=args,
92
                    authn_method=self.registration_response["token_endpoint_auth_method"])
93
                id_token = atresp['id_token']
94
                self.app_admin = 'app_admin' in id_token and id_token['app_admin']
95
                self.app_user = 'app_user' in id_token  and id_token['app_user']
96
            except Exception as err:
97
                logger.error("%s" % err)
98
                raise
99

    
100
            if isinstance(atresp, ErrorResponse):
101
                raise OIDCError("Invalid response %s." % atresp["error"])
102

    
103
        inforesp = self.do_user_info_request(state=authresp["state"],
104
                                             behavior='use_authorization_header')
105

    
106
        if isinstance(inforesp, ErrorResponse):
107
            raise OIDCError("Invalid response %s." % inforesp["error"])
108

    
109
        userinfo = inforesp.to_dict()
110

    
111
        logger.debug("UserInfo: %s" % inforesp)
112

    
113
        return userinfo
114

    
115
def create_client(**kwargs):
116
    """
117
    kwargs = config.CLIENT.iteritems
118
    """
119
    _key_set = set(kwargs.keys())
120
    args = {}
121
    for param in ["verify_ssl", "client_id", "client_secret"]:
122
        try:
123
            args[param] = kwargs[param]
124
        except KeyError:
125
            try:
126
                args[param] = kwargs['client_registration'][param]
127
            except KeyError:
128
                pass
129
        else:
130
            _key_set.discard(param)
131

    
132
    client = Client(client_authn_method=CLIENT_AUTHN_METHOD,
133
                    behaviour=kwargs["behaviour"],
134
                    verify_ssl=conf.VERIFY_SSL, **args)
135

    
136
    # The behaviour parameter is not significant for the election process
137
    _key_set.discard("behaviour")
138
    for param in ["allow"]:
139
        try:
140
            setattr(client, param, kwargs[param])
141
        except KeyError:
142
            pass
143
        else:
144
            _key_set.discard(param)
145

    
146
    if _key_set == set(["client_info"]):  # Everything dynamic
147
        # There has to be a userid
148
        if not userid:
149
            raise MissingAttribute("Missing userid specification")
150

    
151
        # Find the service that provides information about the OP
152
        issuer = client.wf.discovery_query(userid)
153
        # Gather OP information
154
        _ = client.provider_config(issuer)
155
        # register the client
156
        _ = client.register(client.provider_info["registration_endpoint"],
157
                            **kwargs["client_info"])
158
    elif _key_set == set(["client_info", "srv_discovery_url"]):
159
        # Ship the webfinger part
160
        # Gather OP information
161
        _ = client.provider_config(kwargs["srv_discovery_url"])
162
        # register the client
163
        _ = client.register(client.provider_info["registration_endpoint"],
164
                            **kwargs["client_info"])
165
    elif _key_set == set(["provider_info", "client_info"]):
166
        client.handle_provider_config(
167
            ProviderConfigurationResponse(**kwargs["provider_info"]),
168
            kwargs["provider_info"]["issuer"])
169
        _ = client.register(client.provider_info["registration_endpoint"],
170
                            **kwargs["client_info"])
171
    elif _key_set == set(["provider_info", "client_registration"]):
172
        client.handle_provider_config(
173
            ProviderConfigurationResponse(**kwargs["provider_info"]),
174
            kwargs["provider_info"]["issuer"])
175
        client.store_registration_info(RegistrationResponse(
176
            **kwargs["client_registration"]))
177
    elif _key_set == set(["srv_discovery_url", "client_registration"]):
178
        _ = client.provider_config(kwargs["srv_discovery_url"])
179
        client.store_registration_info(RegistrationResponse(
180
            **kwargs["client_registration"]))
181
    else:
182
        raise Exception("Configuration error ?")
183

    
184
    return client
(3-3/4)