feat(auth-google,auth-github): Allow passing a custom callbackUrl to … (#10829)
* feat(auth-google,auth-github): Allow passing a custom callbackUrl to oauth providers * feat: Add state management in auth providers * chore: Changes based on PR review
This commit is contained in:
@@ -92,7 +92,7 @@ export const ModulesDefinition: {
|
||||
label: upperCaseFirst(Modules.AUTH),
|
||||
isRequired: false,
|
||||
isQueryable: true,
|
||||
dependencies: [ContainerRegistrationKeys.LOGGER],
|
||||
dependencies: [ContainerRegistrationKeys.LOGGER, Modules.CACHE],
|
||||
defaultModuleDeclaration: {
|
||||
scope: MODULE_SCOPE.INTERNAL,
|
||||
},
|
||||
|
||||
@@ -59,6 +59,10 @@ export type AuthenticationInput = {
|
||||
|
||||
/**
|
||||
* Body of the incoming authentication request.
|
||||
*
|
||||
* One of the arguments that is suggested to be treated in a standard manner is a `callback_url` field.
|
||||
* The field specifies where the user is redirected to after a successful authentication in the case of Oauth auhentication.
|
||||
* If not passed, the provider will fallback to the callback_url provided in the provider options.
|
||||
*/
|
||||
body?: Record<string, string>
|
||||
|
||||
|
||||
@@ -20,6 +20,9 @@ export interface AuthIdentityProviderService {
|
||||
user_metadata?: Record<string, unknown>
|
||||
}
|
||||
) => Promise<AuthIdentityDTO>
|
||||
// These methods are used for OAuth providers to store and retrieve state
|
||||
setState: (key: string, value: Record<string, unknown>) => Promise<void>
|
||||
getState: (key: string) => Promise<Record<string, unknown> | null>
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
AuthTypes,
|
||||
Context,
|
||||
DAL,
|
||||
ICacheService,
|
||||
InferEntityType,
|
||||
InternalModuleDeclaration,
|
||||
Logger,
|
||||
@@ -27,6 +28,7 @@ type InjectedDependencies = {
|
||||
providerIdentityService: ModulesSdkTypes.IMedusaInternalService<any>
|
||||
authProviderService: AuthProviderService
|
||||
logger?: Logger
|
||||
cache?: ICacheService
|
||||
}
|
||||
export default class AuthModuleService
|
||||
extends MedusaService<{
|
||||
@@ -43,13 +45,14 @@ export default class AuthModuleService
|
||||
InferEntityType<typeof ProviderIdentity>
|
||||
>
|
||||
protected readonly authProviderService_: AuthProviderService
|
||||
|
||||
protected readonly cache_: ICacheService | undefined
|
||||
constructor(
|
||||
{
|
||||
authIdentityService,
|
||||
providerIdentityService,
|
||||
authProviderService,
|
||||
baseRepository,
|
||||
cache,
|
||||
}: InjectedDependencies,
|
||||
protected readonly moduleDeclaration: InternalModuleDeclaration
|
||||
) {
|
||||
@@ -60,6 +63,7 @@ export default class AuthModuleService
|
||||
this.authIdentityService_ = authIdentityService
|
||||
this.authProviderService_ = authProviderService
|
||||
this.providerIdentityService_ = providerIdentityService
|
||||
this.cache_ = cache
|
||||
}
|
||||
|
||||
__joinerConfig(): ModuleJoinerConfig {
|
||||
@@ -372,6 +376,27 @@ export default class AuthModuleService
|
||||
|
||||
return serializedResponse
|
||||
},
|
||||
setState: async (key: string, value: Record<string, unknown>) => {
|
||||
if (!this.cache_) {
|
||||
throw new MedusaError(
|
||||
MedusaError.Types.INVALID_ARGUMENT,
|
||||
"Cache module dependency is required when using OAuth providers that require state"
|
||||
)
|
||||
}
|
||||
|
||||
// 20 minutes. Can be made configurable if necessary, but this is a good default.
|
||||
this.cache_.set(key, value, 1200)
|
||||
},
|
||||
getState: async (key: string) => {
|
||||
if (!this.cache_) {
|
||||
throw new MedusaError(
|
||||
MedusaError.Types.INVALID_ARGUMENT,
|
||||
"Cache module dependency is required when using OAuth providers that require state"
|
||||
)
|
||||
}
|
||||
|
||||
return await this.cache_.get(key)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { generateJwtToken, MedusaError } from "@medusajs/framework/utils"
|
||||
import { MedusaError } from "@medusajs/framework/utils"
|
||||
import { GithubAuthService } from "../../src/services/github"
|
||||
import { http, HttpResponse } from "msw"
|
||||
import { setupServer } from "msw/node"
|
||||
@@ -20,6 +20,22 @@ const sampleIdPayload = {
|
||||
}
|
||||
|
||||
const baseUrl = "https://someurl.com"
|
||||
const callbackUrl = encodeURIComponent(
|
||||
"https://someurl.com/auth/github/callback"
|
||||
)
|
||||
|
||||
let state = {}
|
||||
const defaultSpies = {
|
||||
retrieve: jest.fn(),
|
||||
create: jest.fn(),
|
||||
update: jest.fn(),
|
||||
setState: jest.fn().mockImplementation((key, value) => {
|
||||
state[key] = value
|
||||
}),
|
||||
getState: jest.fn().mockImplementation((key) => {
|
||||
return Promise.resolve(state[key])
|
||||
}),
|
||||
}
|
||||
|
||||
// This is just a network-layer mocking, it doesn't start an actual server
|
||||
const server = setupServer(
|
||||
@@ -29,7 +45,7 @@ const server = setupServer(
|
||||
const url = request.url
|
||||
if (
|
||||
url ===
|
||||
"https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback"
|
||||
`https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=${callbackUrl}`
|
||||
) {
|
||||
return new HttpResponse(null, {
|
||||
status: 401,
|
||||
@@ -39,7 +55,7 @@ const server = setupServer(
|
||||
|
||||
if (
|
||||
url ===
|
||||
"https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=valid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback"
|
||||
`https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=valid-code&redirect_uri=${callbackUrl}`
|
||||
) {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
@@ -91,6 +107,7 @@ describe("Github auth provider", () => {
|
||||
afterEach(() => {
|
||||
server.resetHandlers()
|
||||
jest.restoreAllMocks()
|
||||
state = {}
|
||||
})
|
||||
|
||||
afterAll(() => server.close())
|
||||
@@ -110,11 +127,27 @@ describe("Github auth provider", () => {
|
||||
})
|
||||
|
||||
it("returns a redirect URL on authenticate", async () => {
|
||||
const res = await githubService.authenticate({})
|
||||
const res = await githubService.authenticate({}, defaultSpies)
|
||||
expect(res).toEqual({
|
||||
success: true,
|
||||
location:
|
||||
"https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback&client_id=test&response_type=code",
|
||||
location: `https://github.com/login/oauth/authorize?redirect_uri=${callbackUrl}&client_id=test&response_type=code&state=${
|
||||
Object.keys(state)[0]
|
||||
}`,
|
||||
})
|
||||
})
|
||||
|
||||
it("returns a custom redirect_uri on authenticate", async () => {
|
||||
const res = await githubService.authenticate(
|
||||
{
|
||||
body: { callback_url: "https://someotherurl.com" },
|
||||
},
|
||||
defaultSpies
|
||||
)
|
||||
expect(res).toEqual({
|
||||
success: true,
|
||||
location: `https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fsomeotherurl.com&client_id=test&response_type=code&state=${
|
||||
Object.keys(state)[0]
|
||||
}`,
|
||||
})
|
||||
})
|
||||
|
||||
@@ -123,7 +156,7 @@ describe("Github auth provider", () => {
|
||||
{
|
||||
query: {},
|
||||
},
|
||||
{} as any
|
||||
defaultSpies
|
||||
)
|
||||
expect(res).toEqual({
|
||||
success: false,
|
||||
@@ -131,14 +164,51 @@ describe("Github auth provider", () => {
|
||||
})
|
||||
})
|
||||
|
||||
it("validate callback should return an error on missing state", async () => {
|
||||
const res = await githubService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "valid-code",
|
||||
},
|
||||
},
|
||||
defaultSpies
|
||||
)
|
||||
expect(res).toEqual({
|
||||
success: false,
|
||||
error: "No state provided, or session expired",
|
||||
})
|
||||
})
|
||||
|
||||
it("validate callback should return an error on expired/invalid state", async () => {
|
||||
const res = await githubService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "valid-code",
|
||||
state: "somekey",
|
||||
},
|
||||
},
|
||||
defaultSpies
|
||||
)
|
||||
expect(res).toEqual({
|
||||
success: false,
|
||||
error: "No state provided, or session expired",
|
||||
})
|
||||
})
|
||||
|
||||
it("validate callback should return on a missing access token for code", async () => {
|
||||
state = {
|
||||
somekey: {
|
||||
callback_url: callbackUrl,
|
||||
},
|
||||
}
|
||||
const res = await githubService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "invalid-code",
|
||||
state: "somekey",
|
||||
},
|
||||
},
|
||||
{} as any
|
||||
defaultSpies
|
||||
)
|
||||
|
||||
expect(res).toEqual({
|
||||
@@ -149,6 +219,7 @@ describe("Github auth provider", () => {
|
||||
|
||||
it("validate callback should return successfully on a correct code for a new user", async () => {
|
||||
const authServiceSpies = {
|
||||
...defaultSpies,
|
||||
retrieve: jest.fn().mockImplementation(() => {
|
||||
throw new MedusaError(MedusaError.Types.NOT_FOUND, "Not found")
|
||||
}),
|
||||
@@ -167,10 +238,17 @@ describe("Github auth provider", () => {
|
||||
}),
|
||||
}
|
||||
|
||||
state = {
|
||||
somekey: {
|
||||
callback_url: callbackUrl,
|
||||
},
|
||||
}
|
||||
|
||||
const res = await githubService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "valid-code",
|
||||
state: "somekey",
|
||||
},
|
||||
},
|
||||
authServiceSpies
|
||||
@@ -191,6 +269,7 @@ describe("Github auth provider", () => {
|
||||
|
||||
it("validate callback should return successfully on a correct code for an existing user", async () => {
|
||||
const authServiceSpies = {
|
||||
...defaultSpies,
|
||||
retrieve: jest.fn().mockImplementation(() => {
|
||||
return {
|
||||
provider_identities: [
|
||||
@@ -219,10 +298,17 @@ describe("Github auth provider", () => {
|
||||
}),
|
||||
}
|
||||
|
||||
state = {
|
||||
somekey: {
|
||||
callback_url: callbackUrl,
|
||||
},
|
||||
}
|
||||
|
||||
const res = await githubService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "valid-code",
|
||||
state: "somekey",
|
||||
},
|
||||
},
|
||||
authServiceSpies
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import crypto from "crypto"
|
||||
import {
|
||||
AuthenticationInput,
|
||||
AuthenticationResponse,
|
||||
@@ -16,7 +17,6 @@ type InjectedDependencies = {
|
||||
|
||||
interface LocalServiceConfig extends GithubAuthProviderOptions {}
|
||||
|
||||
// TODO: Add state param that is stored in Redis, to prevent CSRF attacks
|
||||
export class GithubAuthService extends AbstractAuthModuleProvider {
|
||||
static identifier = "github"
|
||||
static DISPLAY_NAME = "Github Authentication"
|
||||
@@ -56,37 +56,53 @@ export class GithubAuthService extends AbstractAuthModuleProvider {
|
||||
}
|
||||
|
||||
async authenticate(
|
||||
req: AuthenticationInput
|
||||
req: AuthenticationInput,
|
||||
authIdentityService: AuthIdentityProviderService
|
||||
): Promise<AuthenticationResponse> {
|
||||
if (req.query?.error) {
|
||||
const query: Record<string, string> = req.query ?? {}
|
||||
const body: Record<string, string> = req.body ?? {}
|
||||
|
||||
if (query.error) {
|
||||
return {
|
||||
success: false,
|
||||
error: `${req.query.error_description}, read more at: ${req.query.error_uri}`,
|
||||
error: `${query.error_description}, read more at: ${query.error_uri}`,
|
||||
}
|
||||
}
|
||||
|
||||
return this.getRedirect(this.config_)
|
||||
const stateKey = crypto.randomBytes(32).toString("hex")
|
||||
const state = {
|
||||
callback_url: body?.callback_url ?? this.config_.callbackUrl,
|
||||
}
|
||||
|
||||
await authIdentityService.setState(stateKey, state)
|
||||
return this.getRedirect(this.config_.clientId, state.callback_url, stateKey)
|
||||
}
|
||||
|
||||
async validateCallback(
|
||||
req: AuthenticationInput,
|
||||
authIdentityService: AuthIdentityProviderService
|
||||
): Promise<AuthenticationResponse> {
|
||||
if (req.query && req.query.error) {
|
||||
const query: Record<string, string> = req.query ?? {}
|
||||
const body: Record<string, string> = req.body ?? {}
|
||||
|
||||
if (query.error) {
|
||||
return {
|
||||
success: false,
|
||||
error: `${req.query.error_description}, read more at: ${req.query.error_uri}`,
|
||||
error: `${query.error_description}, read more at: ${query.error_uri}`,
|
||||
}
|
||||
}
|
||||
|
||||
const code = req.query?.code ?? req.body?.code
|
||||
const code = query?.code ?? body?.code
|
||||
if (!code) {
|
||||
return { success: false, error: "No code provided" }
|
||||
}
|
||||
|
||||
const params = `client_id=${this.config_.clientId}&client_secret=${
|
||||
this.config_.clientSecret
|
||||
}&code=${code}&redirect_uri=${encodeURIComponent(this.config_.callbackUrl)}`
|
||||
const state = await authIdentityService.getState(query?.state as string)
|
||||
if (!state) {
|
||||
return { success: false, error: "No state provided, or session expired" }
|
||||
}
|
||||
|
||||
const params = `client_id=${this.config_.clientId}&client_secret=${this.config_.clientSecret}&code=${code}&redirect_uri=${state.callback_url}`
|
||||
|
||||
const exchangeTokenUrl = new URL(
|
||||
`https://github.com/login/oauth/access_token?${params}`
|
||||
@@ -192,18 +208,12 @@ export class GithubAuthService extends AbstractAuthModuleProvider {
|
||||
}
|
||||
}
|
||||
|
||||
private getRedirect({ clientId, callbackUrl }: LocalServiceConfig) {
|
||||
const redirectUrlParam = `redirect_uri=${encodeURIComponent(callbackUrl)}`
|
||||
const clientIdParam = `client_id=${clientId}`
|
||||
const responseTypeParam = "response_type=code"
|
||||
|
||||
const authUrl = new URL(
|
||||
`https://github.com/login/oauth/authorize?${[
|
||||
redirectUrlParam,
|
||||
clientIdParam,
|
||||
responseTypeParam,
|
||||
].join("&")}`
|
||||
)
|
||||
private getRedirect(clientId: string, callbackUrl: string, stateKey: string) {
|
||||
const authUrl = new URL(`https://github.com/login/oauth/authorize`)
|
||||
authUrl.searchParams.set("redirect_uri", callbackUrl)
|
||||
authUrl.searchParams.set("client_id", clientId)
|
||||
authUrl.searchParams.set("response_type", "code")
|
||||
authUrl.searchParams.set("state", stateKey)
|
||||
|
||||
return { success: true, location: authUrl.toString() }
|
||||
}
|
||||
|
||||
@@ -28,6 +28,22 @@ const encodedIdToken = generateJwtToken(sampleIdPayload, {
|
||||
})
|
||||
|
||||
const baseUrl = "https://someurl.com"
|
||||
const callbackUrl = encodeURIComponent(
|
||||
"https://someurl.com/auth/google/callback"
|
||||
)
|
||||
|
||||
let state = {}
|
||||
const defaultSpies = {
|
||||
retrieve: jest.fn(),
|
||||
create: jest.fn(),
|
||||
update: jest.fn(),
|
||||
setState: jest.fn().mockImplementation((key, value) => {
|
||||
state[key] = value
|
||||
}),
|
||||
getState: jest.fn().mockImplementation((key) => {
|
||||
return Promise.resolve(state[key])
|
||||
}),
|
||||
}
|
||||
|
||||
// This is just a network-layer mocking, it doesn't start an actual server
|
||||
const server = setupServer(
|
||||
@@ -37,7 +53,7 @@ const server = setupServer(
|
||||
const url = request.url
|
||||
if (
|
||||
url ===
|
||||
"https://oauth2.googleapis.com/token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgoogle%2Fcallback&grant_type=authorization_code"
|
||||
`https://oauth2.googleapis.com/token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=${callbackUrl}&grant_type=authorization_code`
|
||||
) {
|
||||
return new HttpResponse(null, {
|
||||
status: 401,
|
||||
@@ -47,7 +63,7 @@ const server = setupServer(
|
||||
|
||||
if (
|
||||
url ===
|
||||
"https://oauth2.googleapis.com/token?client_id=test&client_secret=test&code=valid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgoogle%2Fcallback&grant_type=authorization_code"
|
||||
`https://oauth2.googleapis.com/token?client_id=test&client_secret=test&code=valid-code&redirect_uri=${callbackUrl}&grant_type=authorization_code`
|
||||
) {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
@@ -90,6 +106,7 @@ describe("Google auth provider", () => {
|
||||
afterEach(() => {
|
||||
server.resetHandlers()
|
||||
jest.restoreAllMocks()
|
||||
state = {}
|
||||
})
|
||||
|
||||
afterAll(() => server.close())
|
||||
@@ -109,11 +126,27 @@ describe("Google auth provider", () => {
|
||||
})
|
||||
|
||||
it("returns a redirect URL on authenticate", async () => {
|
||||
const res = await googleService.authenticate({})
|
||||
const res = await googleService.authenticate({}, defaultSpies)
|
||||
expect(res).toEqual({
|
||||
success: true,
|
||||
location:
|
||||
"https://accounts.google.com/o/oauth2/v2/auth?redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgoogle%2Fcallback&client_id=test&response_type=code&scope=email+profile+openid",
|
||||
location: `https://accounts.google.com/o/oauth2/v2/auth?redirect_uri=${callbackUrl}&client_id=test&response_type=code&scope=email+profile+openid&state=${
|
||||
Object.keys(state)[0]
|
||||
}`,
|
||||
})
|
||||
})
|
||||
|
||||
it("returns a custom redirect_uri on authenticate", async () => {
|
||||
const res = await googleService.authenticate(
|
||||
{
|
||||
body: { callback_url: "https://someotherurl.com" },
|
||||
},
|
||||
defaultSpies
|
||||
)
|
||||
expect(res).toEqual({
|
||||
success: true,
|
||||
location: `https://accounts.google.com/o/oauth2/v2/auth?redirect_uri=https%3A%2F%2Fsomeotherurl.com&client_id=test&response_type=code&scope=email+profile+openid&state=${
|
||||
Object.keys(state)[0]
|
||||
}`,
|
||||
})
|
||||
})
|
||||
|
||||
@@ -122,7 +155,7 @@ describe("Google auth provider", () => {
|
||||
{
|
||||
query: {},
|
||||
},
|
||||
{} as any
|
||||
defaultSpies
|
||||
)
|
||||
expect(res).toEqual({
|
||||
success: false,
|
||||
@@ -130,14 +163,52 @@ describe("Google auth provider", () => {
|
||||
})
|
||||
})
|
||||
|
||||
it("validate callback should return an error on missing state", async () => {
|
||||
const res = await googleService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "valid-code",
|
||||
},
|
||||
},
|
||||
defaultSpies
|
||||
)
|
||||
expect(res).toEqual({
|
||||
success: false,
|
||||
error: "No state provided, or session expired",
|
||||
})
|
||||
})
|
||||
|
||||
it("validate callback should return an error on expired/invalid state", async () => {
|
||||
const res = await googleService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "valid-code",
|
||||
state: "somekey",
|
||||
},
|
||||
},
|
||||
defaultSpies
|
||||
)
|
||||
expect(res).toEqual({
|
||||
success: false,
|
||||
error: "No state provided, or session expired",
|
||||
})
|
||||
})
|
||||
|
||||
it("validate callback should return on a missing access token for code", async () => {
|
||||
state = {
|
||||
somekey: {
|
||||
callback_url: callbackUrl,
|
||||
},
|
||||
}
|
||||
|
||||
const res = await googleService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "invalid-code",
|
||||
state: "somekey",
|
||||
},
|
||||
},
|
||||
{} as any
|
||||
defaultSpies
|
||||
)
|
||||
|
||||
expect(res).toEqual({
|
||||
@@ -148,6 +219,7 @@ describe("Google auth provider", () => {
|
||||
|
||||
it("validate callback should return successfully on a correct code for a new user", async () => {
|
||||
const authServiceSpies = {
|
||||
...defaultSpies,
|
||||
retrieve: jest.fn().mockImplementation(() => {
|
||||
throw new MedusaError(MedusaError.Types.NOT_FOUND, "Not found")
|
||||
}),
|
||||
@@ -166,10 +238,17 @@ describe("Google auth provider", () => {
|
||||
}),
|
||||
}
|
||||
|
||||
state = {
|
||||
somekey: {
|
||||
callback_url: callbackUrl,
|
||||
},
|
||||
}
|
||||
|
||||
const res = await googleService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "valid-code",
|
||||
state: "somekey",
|
||||
},
|
||||
},
|
||||
authServiceSpies
|
||||
@@ -190,6 +269,7 @@ describe("Google auth provider", () => {
|
||||
|
||||
it("validate callback should return successfully on a correct code for an existing user", async () => {
|
||||
const authServiceSpies = {
|
||||
...defaultSpies,
|
||||
retrieve: jest.fn().mockImplementation(() => {
|
||||
return {
|
||||
provider_identities: [
|
||||
@@ -208,10 +288,17 @@ describe("Google auth provider", () => {
|
||||
}),
|
||||
}
|
||||
|
||||
state = {
|
||||
somekey: {
|
||||
callback_url: callbackUrl,
|
||||
},
|
||||
}
|
||||
|
||||
const res = await googleService.validateCallback(
|
||||
{
|
||||
query: {
|
||||
code: "valid-code",
|
||||
state: "somekey",
|
||||
},
|
||||
},
|
||||
authServiceSpies
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import crypto from "crypto"
|
||||
import {
|
||||
AuthenticationInput,
|
||||
AuthenticationResponse,
|
||||
@@ -16,8 +17,6 @@ type InjectedDependencies = {
|
||||
}
|
||||
|
||||
interface LocalServiceConfig extends GoogleAuthProviderOptions {}
|
||||
|
||||
// TODO: Add state param that is stored in Redis, to prevent CSRF attacks
|
||||
export class GoogleAuthService extends AbstractAuthModuleProvider {
|
||||
static identifier = "google"
|
||||
static DISPLAY_NAME = "Google Authentication"
|
||||
@@ -57,39 +56,53 @@ export class GoogleAuthService extends AbstractAuthModuleProvider {
|
||||
}
|
||||
|
||||
async authenticate(
|
||||
req: AuthenticationInput
|
||||
req: AuthenticationInput,
|
||||
authIdentityService: AuthIdentityProviderService
|
||||
): Promise<AuthenticationResponse> {
|
||||
if (req.query?.error) {
|
||||
const query: Record<string, string> = req.query ?? {}
|
||||
const body: Record<string, string> = req.body ?? {}
|
||||
|
||||
if (query.error) {
|
||||
return {
|
||||
success: false,
|
||||
error: `${req.query.error_description}, read more at: ${req.query.error_uri}`,
|
||||
error: `${query.error_description}, read more at: ${query.error_uri}`,
|
||||
}
|
||||
}
|
||||
|
||||
return this.getRedirect(this.config_)
|
||||
const stateKey = crypto.randomBytes(32).toString("hex")
|
||||
const state = {
|
||||
callback_url: body?.callback_url ?? this.config_.callbackUrl,
|
||||
}
|
||||
|
||||
await authIdentityService.setState(stateKey, state)
|
||||
return this.getRedirect(this.config_.clientId, state.callback_url, stateKey)
|
||||
}
|
||||
|
||||
async validateCallback(
|
||||
req: AuthenticationInput,
|
||||
authIdentityService: AuthIdentityProviderService
|
||||
): Promise<AuthenticationResponse> {
|
||||
if (req.query && req.query.error) {
|
||||
const query: Record<string, string> = req.query ?? {}
|
||||
const body: Record<string, string> = req.body ?? {}
|
||||
|
||||
if (query.error) {
|
||||
return {
|
||||
success: false,
|
||||
error: `${req.query.error_description}, read more at: ${req.query.error_uri}`,
|
||||
error: `${query.error_description}, read more at: ${query.error_uri}`,
|
||||
}
|
||||
}
|
||||
|
||||
const code = req.query?.code ?? req.body?.code
|
||||
const code = query?.code ?? body?.code
|
||||
if (!code) {
|
||||
return { success: false, error: "No code provided" }
|
||||
}
|
||||
|
||||
const params = `client_id=${this.config_.clientId}&client_secret=${
|
||||
this.config_.clientSecret
|
||||
}&code=${code}&redirect_uri=${encodeURIComponent(
|
||||
this.config_.callbackUrl
|
||||
)}&grant_type=authorization_code`
|
||||
const state = await authIdentityService.getState(query?.state as string)
|
||||
if (!state) {
|
||||
return { success: false, error: "No state provided, or session expired" }
|
||||
}
|
||||
|
||||
const params = `client_id=${this.config_.clientId}&client_secret=${this.config_.clientSecret}&code=${code}&redirect_uri=${state.callback_url}&grant_type=authorization_code`
|
||||
const exchangeTokenUrl = new URL(
|
||||
`https://oauth2.googleapis.com/token?${params}`
|
||||
)
|
||||
@@ -175,20 +188,13 @@ export class GoogleAuthService extends AbstractAuthModuleProvider {
|
||||
}
|
||||
}
|
||||
|
||||
private getRedirect({ clientId, callbackUrl }: LocalServiceConfig) {
|
||||
const redirectUrlParam = `redirect_uri=${encodeURIComponent(callbackUrl)}`
|
||||
const clientIdParam = `client_id=${clientId}`
|
||||
const responseTypeParam = "response_type=code"
|
||||
const scopeParam = "scope=email+profile+openid"
|
||||
|
||||
const authUrl = new URL(
|
||||
`https://accounts.google.com/o/oauth2/v2/auth?${[
|
||||
redirectUrlParam,
|
||||
clientIdParam,
|
||||
responseTypeParam,
|
||||
scopeParam,
|
||||
].join("&")}`
|
||||
)
|
||||
private getRedirect(clientId: string, callbackUrl: string, stateKey: string) {
|
||||
const authUrl = new URL(`https://accounts.google.com/o/oauth2/v2/auth`)
|
||||
authUrl.searchParams.set("redirect_uri", callbackUrl)
|
||||
authUrl.searchParams.set("client_id", clientId)
|
||||
authUrl.searchParams.set("response_type", "code")
|
||||
authUrl.searchParams.set("scope", "email profile openid")
|
||||
authUrl.searchParams.set("state", stateKey)
|
||||
|
||||
return { success: true, location: authUrl.toString() }
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user