feat(auth-google,auth-github): Allow passing a custom callbackUrl to … (#10829)

* feat(auth-google,auth-github): Allow passing a custom callbackUrl to oauth providers

* feat: Add state management in auth providers

* chore: Changes based on PR review
This commit is contained in:
Stevche Radevski
2025-01-06 17:33:29 +01:00
committed by GitHub
parent 9490c265b2
commit fde73dbfae
8 changed files with 289 additions and 68 deletions

View File

@@ -1,4 +1,4 @@
import { generateJwtToken, MedusaError } from "@medusajs/framework/utils"
import { MedusaError } from "@medusajs/framework/utils"
import { GithubAuthService } from "../../src/services/github"
import { http, HttpResponse } from "msw"
import { setupServer } from "msw/node"
@@ -20,6 +20,22 @@ const sampleIdPayload = {
}
const baseUrl = "https://someurl.com"
const callbackUrl = encodeURIComponent(
"https://someurl.com/auth/github/callback"
)
let state = {}
const defaultSpies = {
retrieve: jest.fn(),
create: jest.fn(),
update: jest.fn(),
setState: jest.fn().mockImplementation((key, value) => {
state[key] = value
}),
getState: jest.fn().mockImplementation((key) => {
return Promise.resolve(state[key])
}),
}
// This is just a network-layer mocking, it doesn't start an actual server
const server = setupServer(
@@ -29,7 +45,7 @@ const server = setupServer(
const url = request.url
if (
url ===
"https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback"
`https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=${callbackUrl}`
) {
return new HttpResponse(null, {
status: 401,
@@ -39,7 +55,7 @@ const server = setupServer(
if (
url ===
"https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=valid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback"
`https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=valid-code&redirect_uri=${callbackUrl}`
) {
return new HttpResponse(
JSON.stringify({
@@ -91,6 +107,7 @@ describe("Github auth provider", () => {
afterEach(() => {
server.resetHandlers()
jest.restoreAllMocks()
state = {}
})
afterAll(() => server.close())
@@ -110,11 +127,27 @@ describe("Github auth provider", () => {
})
it("returns a redirect URL on authenticate", async () => {
const res = await githubService.authenticate({})
const res = await githubService.authenticate({}, defaultSpies)
expect(res).toEqual({
success: true,
location:
"https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback&client_id=test&response_type=code",
location: `https://github.com/login/oauth/authorize?redirect_uri=${callbackUrl}&client_id=test&response_type=code&state=${
Object.keys(state)[0]
}`,
})
})
it("returns a custom redirect_uri on authenticate", async () => {
const res = await githubService.authenticate(
{
body: { callback_url: "https://someotherurl.com" },
},
defaultSpies
)
expect(res).toEqual({
success: true,
location: `https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fsomeotherurl.com&client_id=test&response_type=code&state=${
Object.keys(state)[0]
}`,
})
})
@@ -123,7 +156,7 @@ describe("Github auth provider", () => {
{
query: {},
},
{} as any
defaultSpies
)
expect(res).toEqual({
success: false,
@@ -131,14 +164,51 @@ describe("Github auth provider", () => {
})
})
it("validate callback should return an error on missing state", async () => {
const res = await githubService.validateCallback(
{
query: {
code: "valid-code",
},
},
defaultSpies
)
expect(res).toEqual({
success: false,
error: "No state provided, or session expired",
})
})
it("validate callback should return an error on expired/invalid state", async () => {
const res = await githubService.validateCallback(
{
query: {
code: "valid-code",
state: "somekey",
},
},
defaultSpies
)
expect(res).toEqual({
success: false,
error: "No state provided, or session expired",
})
})
it("validate callback should return on a missing access token for code", async () => {
state = {
somekey: {
callback_url: callbackUrl,
},
}
const res = await githubService.validateCallback(
{
query: {
code: "invalid-code",
state: "somekey",
},
},
{} as any
defaultSpies
)
expect(res).toEqual({
@@ -149,6 +219,7 @@ describe("Github auth provider", () => {
it("validate callback should return successfully on a correct code for a new user", async () => {
const authServiceSpies = {
...defaultSpies,
retrieve: jest.fn().mockImplementation(() => {
throw new MedusaError(MedusaError.Types.NOT_FOUND, "Not found")
}),
@@ -167,10 +238,17 @@ describe("Github auth provider", () => {
}),
}
state = {
somekey: {
callback_url: callbackUrl,
},
}
const res = await githubService.validateCallback(
{
query: {
code: "valid-code",
state: "somekey",
},
},
authServiceSpies
@@ -191,6 +269,7 @@ describe("Github auth provider", () => {
it("validate callback should return successfully on a correct code for an existing user", async () => {
const authServiceSpies = {
...defaultSpies,
retrieve: jest.fn().mockImplementation(() => {
return {
provider_identities: [
@@ -219,10 +298,17 @@ describe("Github auth provider", () => {
}),
}
state = {
somekey: {
callback_url: callbackUrl,
},
}
const res = await githubService.validateCallback(
{
query: {
code: "valid-code",
state: "somekey",
},
},
authServiceSpies

View File

@@ -1,3 +1,4 @@
import crypto from "crypto"
import {
AuthenticationInput,
AuthenticationResponse,
@@ -16,7 +17,6 @@ type InjectedDependencies = {
interface LocalServiceConfig extends GithubAuthProviderOptions {}
// TODO: Add state param that is stored in Redis, to prevent CSRF attacks
export class GithubAuthService extends AbstractAuthModuleProvider {
static identifier = "github"
static DISPLAY_NAME = "Github Authentication"
@@ -56,37 +56,53 @@ export class GithubAuthService extends AbstractAuthModuleProvider {
}
async authenticate(
req: AuthenticationInput
req: AuthenticationInput,
authIdentityService: AuthIdentityProviderService
): Promise<AuthenticationResponse> {
if (req.query?.error) {
const query: Record<string, string> = req.query ?? {}
const body: Record<string, string> = req.body ?? {}
if (query.error) {
return {
success: false,
error: `${req.query.error_description}, read more at: ${req.query.error_uri}`,
error: `${query.error_description}, read more at: ${query.error_uri}`,
}
}
return this.getRedirect(this.config_)
const stateKey = crypto.randomBytes(32).toString("hex")
const state = {
callback_url: body?.callback_url ?? this.config_.callbackUrl,
}
await authIdentityService.setState(stateKey, state)
return this.getRedirect(this.config_.clientId, state.callback_url, stateKey)
}
async validateCallback(
req: AuthenticationInput,
authIdentityService: AuthIdentityProviderService
): Promise<AuthenticationResponse> {
if (req.query && req.query.error) {
const query: Record<string, string> = req.query ?? {}
const body: Record<string, string> = req.body ?? {}
if (query.error) {
return {
success: false,
error: `${req.query.error_description}, read more at: ${req.query.error_uri}`,
error: `${query.error_description}, read more at: ${query.error_uri}`,
}
}
const code = req.query?.code ?? req.body?.code
const code = query?.code ?? body?.code
if (!code) {
return { success: false, error: "No code provided" }
}
const params = `client_id=${this.config_.clientId}&client_secret=${
this.config_.clientSecret
}&code=${code}&redirect_uri=${encodeURIComponent(this.config_.callbackUrl)}`
const state = await authIdentityService.getState(query?.state as string)
if (!state) {
return { success: false, error: "No state provided, or session expired" }
}
const params = `client_id=${this.config_.clientId}&client_secret=${this.config_.clientSecret}&code=${code}&redirect_uri=${state.callback_url}`
const exchangeTokenUrl = new URL(
`https://github.com/login/oauth/access_token?${params}`
@@ -192,18 +208,12 @@ export class GithubAuthService extends AbstractAuthModuleProvider {
}
}
private getRedirect({ clientId, callbackUrl }: LocalServiceConfig) {
const redirectUrlParam = `redirect_uri=${encodeURIComponent(callbackUrl)}`
const clientIdParam = `client_id=${clientId}`
const responseTypeParam = "response_type=code"
const authUrl = new URL(
`https://github.com/login/oauth/authorize?${[
redirectUrlParam,
clientIdParam,
responseTypeParam,
].join("&")}`
)
private getRedirect(clientId: string, callbackUrl: string, stateKey: string) {
const authUrl = new URL(`https://github.com/login/oauth/authorize`)
authUrl.searchParams.set("redirect_uri", callbackUrl)
authUrl.searchParams.set("client_id", clientId)
authUrl.searchParams.set("response_type", "code")
authUrl.searchParams.set("state", stateKey)
return { success: true, location: authUrl.toString() }
}