Move token auto-refresh

This takes advantage of existing expiry checks and works more like if
OAuth2Session.request() was used.

Probably won't refresh revoked tokens as that would be noticed in the
request itself which is outside the control of this plugin.
1 files changed, 45 insertions(+), 28 deletions(-)

M httpie_oauth2.py
M httpie_oauth2.py +45 -28
@@ 113,6 113,42 @@ class OAuth2Client(WebApplicationClient)
         )
 
 
+class AutoRefreshingOAuth2(OAuth2):
+    session: OAuth2Session = None
+
+    def __init__(self, session, **kwargs):
+        self.session = session
+        super().__init__(**kwargs)
+
+    def __call__(self, r):
+        try:
+            if oauth2_arguments["token_refresh"] == "always":
+                raise TokenExpiredError
+            return super().__call__(r)
+        except TokenExpiredError:
+            if not self.session.auto_refresh_url:
+                raise
+
+            try:
+                token_info = self.session.refresh_token(
+                    token_url=self.session.auto_refresh_url,
+                    auth=self._client.get_auth(),
+                    headers={  # XXX workaround for https://github.com/requests/requests-oauthlib/issues/437
+                        "Accept": "application/json",
+                        "Content-Type": "application/x-www-form-urlencoded",
+                    },
+                )
+                self._client.token = token_info
+                if self.session.token_updater:
+                    self.session.token_updater(token_info)
+                return super().__call__(r)
+            except InvalidGrantError:
+                if self.session.token_updater:
+                    self.session.token_updater(None)
+                raise
+            raise
+
+
 class OAuth2RedirectHandler(BaseHTTPRequestHandler):
     def do_GET(self):
         url = urlparse(self.path)

          
@@ 326,40 362,19 @@ class OAuth2Plugin(AuthPlugin):
 
         client = authz_server.get_client()
 
+        def token_updater(token):
+            authz_server["token_info"] = token
+            authz_server.save()
+
         session = OAuth2Session(
             client=client,
             scope=oauth2_arguments.get("scope"),
             redirect_uri=OAUTH2_REDIRECT_URI,
+            auto_refresh_url=authz_meta["token_endpoint"],
+            token_updater=token_updater,
             token=authz_server["token_info"],
         )
 
-        if authz_server["token_info"] and "expires_at" in authz_server["token_info"]:
-            if (
-                time() > authz_server["token_info"]["expires_at"]
-                or oauth2_arguments["token_refresh"] == "always"
-            ):
-                if not "refresh_token" in authz_server["token_info"]:
-                    # Can't refresh without a refresh token
-                    raise TokenExpiredError()
-
-                try:
-                    token_info = session.refresh_token(
-                        token_url=authz_meta["token_endpoint"],
-                        auth=client.get_auth(),
-                        headers={  # XXX workaround for https://github.com/requests/requests-oauthlib/issues/437
-                            "Accept": "application/json",
-                            "Content-Type": "application/x-www-form-urlencoded",
-                        },
-                    )
-                    authz_server["token_info"] = token_info
-                    authz_server.save()
-                except InvalidGrantError:
-                    # the grant could be revoked or expired
-                    authz_server["token_info"] = None
-                    authz_server.save()
-
-        # FIXME if the token is revoked before it expires we would only see this as 401/403 when
-
         if not authz_server["token_info"]:
             challenge_method = "S256"
             verifier = client.create_code_verifier(43)

          
@@ 400,4 415,6 @@ class OAuth2Plugin(AuthPlugin):
             httpd.shutdown()
             httpt.join()
 
-        return OAuth2(client_id=client.client_id, token=authz_server["token_info"])
+        return AutoRefreshingOAuth2(
+            client=client, session=session, token=authz_server["token_info"]
+        )