Fix: generalize OAuth flow

This commit is contained in:
Manoj K 2025-07-23 10:27:52 +05:30 committed by Harshith Mullapudi
parent 1916a58c1c
commit 6decfbcf1b
6 changed files with 378 additions and 68 deletions

View File

@ -89,17 +89,25 @@ export const action = async ({ request }: ActionFunctionArgs) => {
// Validate scopes // Validate scopes
const validScopes = [ const validScopes = [
// Authentication scopes (Google-style) // Authentication scopes (Google-style)
'profile', 'email', 'openid', "profile",
"email",
"openid",
// Integration scope // Integration scope
'integration' "integration",
]; ];
const requestedScopes = Array.isArray(allowedScopes) ? allowedScopes : [allowedScopes || 'read']; const requestedScopes = Array.isArray(allowedScopes)
const invalidScopes = requestedScopes.filter(scope => !validScopes.includes(scope)); ? allowedScopes
: [allowedScopes || "read"];
const invalidScopes = requestedScopes.filter(
(scope) => !validScopes.includes(scope),
);
if (invalidScopes.length > 0) { if (invalidScopes.length > 0) {
return json( return json(
{ error: `Invalid scopes: ${invalidScopes.join(', ')}. Valid scopes are: ${validScopes.join(', ')}` }, {
error: `Invalid scopes: ${invalidScopes.join(", ")}. Valid scopes are: ${validScopes.join(", ")}`,
},
{ status: 400 }, { status: 400 },
); );
} }
@ -114,6 +122,10 @@ export const action = async ({ request }: ActionFunctionArgs) => {
return json({ error: "No workspace found" }, { status: 404 }); return json({ error: "No workspace found" }, { status: 404 });
} }
if (!userRecord?.admin) {
return json({ error: "No access to create OAuth app" }, { status: 404 });
}
// Generate client credentials // Generate client credentials
const clientId = crypto.randomUUID(); const clientId = crypto.randomUUID();
const clientSecret = crypto.randomBytes(32).toString("hex"); const clientSecret = crypto.randomBytes(32).toString("hex");

View File

@ -10,26 +10,12 @@ import {
OAuth2Errors, OAuth2Errors,
type OAuth2AuthorizeRequest, type OAuth2AuthorizeRequest,
} from "~/services/oauth2.server"; } from "~/services/oauth2.server";
import { getIntegrationAccounts } from "~/services/integrationAccount.server";
import { Button } from "~/components/ui/button"; import { Button } from "~/components/ui/button";
import { Card, CardContent } from "~/components/ui/card"; import { Card, CardContent } from "~/components/ui/card";
import { Arrows } from "~/components/icons"; import { Arrows } from "~/components/icons";
import Logo from "~/components/logo/logo"; import Logo from "~/components/logo/logo";
import { AlignLeft, LayoutGrid, Pen, User, Mail, Shield, Database } from "lucide-react"; import { AlignLeft, LayoutGrid, Pen, User, Mail, Shield, Database } from "lucide-react";
// Helper function to convert integration definition IDs to account IDs
async function convertDefIdsToAccountIds(defIds: string[], userId: string): Promise<string[]> {
const integrationAccounts = await getIntegrationAccounts(userId);
const defToAccountMap = new Map(
integrationAccounts
.filter(acc => acc.isActive)
.map(acc => [acc.integrationDefinitionId, acc.id])
);
return defIds
.map(defId => defToAccountMap.get(defId))
.filter(Boolean) as string[];
}
export const loader = async ({ request }: LoaderFunctionArgs) => { export const loader = async ({ request }: LoaderFunctionArgs) => {
// Check if user is authenticated // Check if user is authenticated

View File

@ -0,0 +1,24 @@
import { type LoaderFunctionArgs, json } from "@remix-run/node";
import { oauth2Service } from "~/services/oauth2.server";
export const loader = async ({ request }: LoaderFunctionArgs) => {
const url = new URL(request.url);
const idToken = url.searchParams.get("id_token");
if (!idToken) {
return json(
{ error: "invalid_request", error_description: "Missing id_token parameter" },
{ status: 400 }
);
}
try {
const userInfo = await oauth2Service.getUserInfoFromIdToken(idToken);
return json(userInfo);
} catch (error) {
return json(
{ error: "invalid_token", error_description: "Invalid or expired ID token" },
{ status: 401 }
);
}
};

View File

@ -1,5 +1,7 @@
import { PrismaClient } from "@prisma/client"; import { PrismaClient } from "@prisma/client";
import crypto from "crypto"; import crypto from "crypto";
import { env } from "~/env.server";
import { type JWTPayload, jwtVerify, SignJWT } from "jose";
const prisma = new PrismaClient(); const prisma = new PrismaClient();
@ -28,6 +30,7 @@ export interface OAuth2TokenResponse {
expires_in: number; expires_in: number;
refresh_token?: string; refresh_token?: string;
scope?: string; scope?: string;
id_token?: string;
} }
export interface OAuth2ErrorResponse { export interface OAuth2ErrorResponse {
@ -37,6 +40,19 @@ export interface OAuth2ErrorResponse {
state?: string; state?: string;
} }
export interface IDTokenClaims {
iss: string; // Issuer
aud: string; // Audience (client_id)
sub: string; // Subject (user ID)
exp: number; // Expiration time
iat: number; // Issued at
email?: string;
email_verified?: boolean;
name?: string;
picture?: string;
installation_id?: string;
}
// OAuth2 Error types // OAuth2 Error types
export const OAuth2Errors = { export const OAuth2Errors = {
INVALID_REQUEST: "invalid_request", INVALID_REQUEST: "invalid_request",
@ -52,9 +68,149 @@ export const OAuth2Errors = {
} as const; } as const;
export class OAuth2Service { export class OAuth2Service {
// Generate secure random string private generateAccessToken(params: {
private generateSecureToken(length: number = 32): string { userId: string;
return crypto.randomBytes(length).toString("hex"); clientId: string;
workspaceId: string;
scope?: string;
}): string {
const payload = {
type: "access_token",
user_id: params.userId,
client_id: params.clientId,
workspace_id: params.workspaceId,
scope: params.scope,
jti: crypto.randomBytes(16).toString("hex"),
iat: Math.floor(Date.now() / 1000),
};
const encoded = Buffer.from(JSON.stringify(payload)).toString("base64url");
return `at_${encoded}`;
}
private generateRefreshToken(params: {
userId: string;
clientId: string;
workspaceId: string;
}): string {
const payload = {
type: "refresh_token",
user_id: params.userId,
client_id: params.clientId,
workspace_id: params.workspaceId,
jti: crypto.randomBytes(16).toString("hex"),
iat: Math.floor(Date.now() / 1000),
};
const encoded = Buffer.from(JSON.stringify(payload)).toString("base64url");
return `rt_${encoded}`;
}
private generateAuthorizationCode(params: {
clientId: string;
userId: string;
workspaceId: string;
}): string {
const payload = {
type: "authorization_code",
client_id: params.clientId,
user_id: params.userId,
workspace_id: params.workspaceId,
jti: crypto.randomBytes(12).toString("hex"),
iat: Math.floor(Date.now() / 1000),
};
const encoded = Buffer.from(JSON.stringify(payload)).toString("base64url");
return `ac_${encoded}`;
}
private async generateIdToken(params: {
userId: string;
clientId: string;
workspaceId: string;
email?: string;
name?: string;
avatarUrl?: string;
installationId?: string;
scopes?: string[];
}): Promise<string> {
const now = Math.floor(Date.now() / 1000);
const exp = now + 3600; // 1 hour
const claims: IDTokenClaims = {
iss: env.LOGIN_ORIGIN,
aud: params.clientId,
sub: params.userId,
exp,
iat: now,
};
// Add optional claims based on scopes
if (params.scopes?.includes("email") && params.email) {
claims.email = params.email;
claims.email_verified = true; // Assuming all CORE emails are verified
}
if (params.scopes?.includes("profile")) {
if (params.name) claims.name = params.name;
if (params.avatarUrl) claims.picture = params.avatarUrl;
}
if (params.installationId) {
claims.installation_id = params.installationId;
}
// Sign JWT with secret
const secret = new TextEncoder().encode(env.SESSION_SECRET);
return await new SignJWT(claims as JWTPayload)
.setProtectedHeader({ alg: "HS256" })
.sign(secret);
}
private extractTokenPayload(token: string): any {
try {
const parts = token.split("_");
if (parts.length !== 2) return null;
const encoded = parts[1];
const decoded = Buffer.from(encoded, "base64url").toString();
return JSON.parse(decoded);
} catch {
return null;
}
}
private validateTokenFormat(
token: string,
expectedType: "access_token" | "refresh_token" | "authorization_code",
): any {
try {
const prefixMap = {
access_token: "at_",
refresh_token: "rt_",
authorization_code: "ac_",
};
const expectedPrefix = prefixMap[expectedType];
if (!token.startsWith(expectedPrefix)) {
return null;
}
const payload = this.extractTokenPayload(token);
if (!payload || payload.type !== expectedType) {
return null;
}
return payload;
} catch {
return null;
}
} }
// Validate OAuth2 client // Validate OAuth2 client
@ -117,6 +273,16 @@ export class OAuth2Service {
return requestedScopeArray.every((scope) => allowedScopes.includes(scope)); return requestedScopeArray.every((scope) => allowedScopes.includes(scope));
} }
async verifyIdToken(idToken: string): Promise<IDTokenClaims> {
try {
const secret = new TextEncoder().encode(env.SESSION_SECRET);
const { payload } = await jwtVerify(idToken, secret);
return payload as IDTokenClaims;
} catch (error) {
throw new Error("Invalid ID token");
}
}
// Determine scope type for routing (simplified) // Determine scope type for routing (simplified)
getScopeType(scope: string): "auth" | "integration" | "mixed" { getScopeType(scope: string): "auth" | "integration" | "mixed" {
const scopes = scope.split(",").map((s) => s.trim()); const scopes = scope.split(",").map((s) => s.trim());
@ -178,35 +344,62 @@ export class OAuth2Service {
codeChallenge?: string; codeChallenge?: string;
codeChallengeMethod?: string; codeChallengeMethod?: string;
}): Promise<string> { }): Promise<string> {
const code = this.generateSecureToken(32); const code = this.generateAuthorizationCode(params);
const expiresAt = new Date(Date.now() + 10 * 60 * 1000); // 10 minutes const expiresAt = new Date(Date.now() + 10 * 60 * 1000); // 10 minutes
// Find the client to get the internal database ID // Find the client to get the internal database ID
const client = await prisma.oAuthClient.findUnique({ const client = await prisma.oAuthClient.findUnique({
where: { clientId: params.clientId }, where: { clientId: params.clientId },
select: { id: true },
}); });
if (!client) { if (!client) {
throw new Error(OAuth2Errors.INVALID_CLIENT); throw new Error(OAuth2Errors.INVALID_CLIENT);
} }
await prisma.oAuthAuthorizationCode.create({ try {
data: { await prisma.oAuthAuthorizationCode.create({
data: {
code,
clientId: client.id,
userId: params.userId,
redirectUri: params.redirectUri,
scope: params.scope,
state: params.state,
codeChallenge: params.codeChallenge,
codeChallengeMethod: params.codeChallengeMethod,
workspaceId: params.workspaceId,
expiresAt,
},
});
} catch (error) {
throw new Error("Failed to create authorization code");
}
return code;
}
async validateAuthorizationCode(code: string): Promise<any> {
const tokenPayload = this.validateTokenFormat(code, "authorization_code");
if (!tokenPayload) {
throw new Error("Invalid or expired token");
}
const authorizationCode = await prisma.oAuthAuthorizationCode.findFirst({
where: {
code, code,
clientId: client.id, // Use internal database ID workspaceId: tokenPayload.workspace_id,
userId: params.userId, expiresAt: { gt: new Date() },
redirectUri: params.redirectUri, },
scope: params.scope, include: {
state: params.state, client: true,
codeChallenge: params.codeChallenge, user: true,
codeChallengeMethod: params.codeChallengeMethod,
workspaceId: params.workspaceId,
expiresAt,
}, },
}); });
return code; if (!authorizationCode) {
throw new Error("Invalid or expired token");
}
return authorizationCode;
} }
// Exchange authorization code for tokens // Exchange authorization code for tokens
@ -219,27 +412,13 @@ export class OAuth2Service {
// Find the client first to get the internal database ID // Find the client first to get the internal database ID
const client = await prisma.oAuthClient.findUnique({ const client = await prisma.oAuthClient.findUnique({
where: { clientId: params.clientId }, where: { clientId: params.clientId },
select: { id: true },
}); });
if (!client) { if (!client) {
throw new Error(OAuth2Errors.INVALID_CLIENT); throw new Error(OAuth2Errors.INVALID_CLIENT);
} }
// Find and validate authorization code const authCode = await this.validateAuthorizationCode(params.code);
const authCode = await prisma.oAuthAuthorizationCode.findFirst({
where: {
code: params.code,
clientId: client.id, // Use internal database ID
redirectUri: params.redirectUri,
used: false,
expiresAt: { gt: new Date() },
},
include: {
client: true,
user: true,
},
});
if (!authCode) { if (!authCode) {
throw new Error(OAuth2Errors.INVALID_GRANT); throw new Error(OAuth2Errors.INVALID_GRANT);
@ -268,9 +447,20 @@ export class OAuth2Service {
}); });
// Generate access token // Generate access token
const accessToken = this.generateSecureToken(64); const accessToken = this.generateAccessToken({
const refreshToken = this.generateSecureToken(64); userId: authCode.userId,
const expiresIn = 3600; // 1 hour clientId: client.clientId,
workspaceId: authCode.workspaceId,
scope: authCode.scope || undefined,
});
const refreshToken = this.generateRefreshToken({
userId: authCode.userId,
clientId: client.clientId,
workspaceId: authCode.workspaceId,
});
const expiresIn = 86400; // 1 day
const accessTokenExpiresAt = new Date(Date.now() + expiresIn * 1000); const accessTokenExpiresAt = new Date(Date.now() + expiresIn * 1000);
const refreshTokenExpiresAt = new Date( const refreshTokenExpiresAt = new Date(
Date.now() + 30 * 24 * 60 * 60 * 1000, Date.now() + 30 * 24 * 60 * 60 * 1000,
@ -280,7 +470,7 @@ export class OAuth2Service {
await prisma.oAuthAccessToken.create({ await prisma.oAuthAccessToken.create({
data: { data: {
token: accessToken, token: accessToken,
clientId: client.id, // Use internal database ID clientId: client.id,
userId: authCode.userId, userId: authCode.userId,
scope: authCode.scope, scope: authCode.scope,
expiresAt: accessTokenExpiresAt, expiresAt: accessTokenExpiresAt,
@ -291,7 +481,7 @@ export class OAuth2Service {
await prisma.oAuthRefreshToken.create({ await prisma.oAuthRefreshToken.create({
data: { data: {
token: refreshToken, token: refreshToken,
clientId: client.id, // Use internal database ID clientId: client.id,
userId: authCode.userId, userId: authCode.userId,
scope: authCode.scope, scope: authCode.scope,
expiresAt: refreshTokenExpiresAt, expiresAt: refreshTokenExpiresAt,
@ -299,8 +489,21 @@ export class OAuth2Service {
}, },
}); });
await prisma.oAuthClientInstallation.create({ const installation = await prisma.oAuthClientInstallation.upsert({
data: { where: {
oauthClientId_workspaceId: {
oauthClientId: client.id,
workspaceId: authCode.workspaceId,
},
},
update: {
oauthClientId: client.id,
workspaceId: authCode.workspaceId,
installedById: authCode.userId,
isActive: true,
grantedScopes: authCode.scope,
},
create: {
oauthClientId: client.id, oauthClientId: client.id,
workspaceId: authCode.workspaceId, workspaceId: authCode.workspaceId,
installedById: authCode.userId, installedById: authCode.userId,
@ -309,22 +512,54 @@ export class OAuth2Service {
}, },
}); });
const idToken = await this.generateIdToken({
userId: authCode.userId,
clientId: client.clientId,
workspaceId: authCode.workspaceId,
email: authCode.user.email,
name: authCode.user.name || null,
avatarUrl: authCode.user.avatarUrl || null,
installationId: installation.id,
scopes: authCode.scope?.split(","),
});
return { return {
access_token: accessToken, access_token: accessToken,
token_type: "Bearer", token_type: "Bearer",
expires_in: expiresIn, expires_in: expiresIn,
refresh_token: refreshToken, refresh_token: refreshToken,
scope: authCode.scope || undefined, scope: authCode.scope || undefined,
id_token: idToken,
};
}
async getUserInfoFromIdToken(idToken: string): Promise<any> {
const claims = await this.verifyIdToken(idToken);
return {
sub: claims.sub,
email: claims.email,
email_verified: claims.email_verified,
name: claims.name,
picture: claims.picture,
installation_id: claims.installation_id,
}; };
} }
// Validate access token // Validate access token
async validateAccessToken(token: string, scopes?: string[]): Promise<any> { async validateAccessToken(token: string, scopes?: string[]): Promise<any> {
const tokenPayload = this.validateTokenFormat(token, "access_token");
if (!tokenPayload) {
throw new Error("Invalid or expired token");
}
const accessToken = await prisma.oAuthAccessToken.findFirst({ const accessToken = await prisma.oAuthAccessToken.findFirst({
where: { where: {
token, token,
revoked: false, revoked: false,
expiresAt: { gt: new Date() }, expiresAt: { gt: new Date() },
userId: tokenPayload.user_id,
workspaceId: tokenPayload.workspace_id,
...(scopes ? { scope: { contains: scopes.join(",") } } : {}), ...(scopes ? { scope: { contains: scopes.join(",") } } : {}),
}, },
include: { include: {
@ -350,10 +585,32 @@ export class OAuth2Service {
name: accessToken.user.name, name: accessToken.user.name,
display_name: accessToken.user.displayName, display_name: accessToken.user.displayName,
avatar_url: accessToken.user.avatarUrl, avatar_url: accessToken.user.avatarUrl,
email_verified: true, // Assuming email is verified if user exists email_verified: true,
}; };
} }
async validateRefreshToken(token: string): Promise<any> {
const tokenPayload = await this.validateTokenFormat(token, "refresh_token");
if (!tokenPayload) {
throw new Error("Invalid or expired token");
}
const refreshToken = await prisma.oAuthRefreshToken.findFirst({
where: {
token,
clientId: tokenPayload.client_id,
revoked: false,
expiresAt: { gt: new Date() },
},
});
if (!refreshToken) {
throw new Error("Invalid or expired token");
}
return refreshToken;
}
// Refresh access token // Refresh access token
async refreshAccessToken( async refreshAccessToken(
refreshToken: string, refreshToken: string,
@ -362,7 +619,6 @@ export class OAuth2Service {
// Find the client first to get the internal database ID // Find the client first to get the internal database ID
const client = await prisma.oAuthClient.findUnique({ const client = await prisma.oAuthClient.findUnique({
where: { clientId }, where: { clientId },
select: { id: true },
}); });
if (!client) { if (!client) {
@ -372,7 +628,7 @@ export class OAuth2Service {
const storedRefreshToken = await prisma.oAuthRefreshToken.findFirst({ const storedRefreshToken = await prisma.oAuthRefreshToken.findFirst({
where: { where: {
token: refreshToken, token: refreshToken,
clientId: client.id, // Use internal database ID clientId: client.id,
revoked: false, revoked: false,
expiresAt: { gt: new Date() }, expiresAt: { gt: new Date() },
}, },
@ -386,15 +642,40 @@ export class OAuth2Service {
throw new Error(OAuth2Errors.INVALID_GRANT); throw new Error(OAuth2Errors.INVALID_GRANT);
} }
const newRefreshToken = this.generateRefreshToken({
userId: storedRefreshToken.userId,
clientId: client.clientId,
workspaceId: storedRefreshToken.workspaceId,
});
// Generate new access token // Generate new access token
const accessToken = this.generateSecureToken(64); const accessToken = this.generateAccessToken({
userId: storedRefreshToken.userId,
clientId: client.clientId,
workspaceId: storedRefreshToken.workspaceId,
scope: storedRefreshToken.scope || undefined,
});
const expiresIn = 86400; // 1 day const expiresIn = 86400; // 1 day
const accessTokenExpiresAt = new Date(Date.now() + expiresIn * 1000); const accessTokenExpiresAt = new Date(Date.now() + expiresIn * 1000);
const newRefreshTokenExpiresAt = new Date(
Date.now() + 30 * 24 * 60 * 60 * 1000,
);
await prisma.oAuthRefreshToken.create({
data: {
token: newRefreshToken,
clientId: client.id,
userId: storedRefreshToken.userId,
scope: storedRefreshToken.scope,
expiresAt: newRefreshTokenExpiresAt,
workspaceId: storedRefreshToken.workspaceId,
},
});
await prisma.oAuthAccessToken.create({ await prisma.oAuthAccessToken.create({
data: { data: {
token: accessToken, token: accessToken,
clientId: client.id, // Use internal database ID clientId: client.id,
userId: storedRefreshToken.userId, userId: storedRefreshToken.userId,
scope: storedRefreshToken.scope, scope: storedRefreshToken.scope,
expiresAt: accessTokenExpiresAt, expiresAt: accessTokenExpiresAt,
@ -406,6 +687,7 @@ export class OAuth2Service {
access_token: accessToken, access_token: accessToken,
token_type: "Bearer", token_type: "Bearer",
expires_in: expiresIn, expires_in: expiresIn,
refresh_token: newRefreshToken,
scope: storedRefreshToken.scope || undefined, scope: storedRefreshToken.scope || undefined,
}; };
} }

View File

@ -99,8 +99,9 @@ export const integrationWebhookTask = task({
const targets: WebhookTarget[] = oauthClients const targets: WebhookTarget[] = oauthClients
.filter((client) => client.oauthClient?.webhookUrl) .filter((client) => client.oauthClient?.webhookUrl)
.map((client) => ({ .map((client) => ({
url: `${client.oauthClient?.webhookUrl}/${payload.userId}`, url: `${client.oauthClient?.webhookUrl}`,
secret: client.oauthClient?.webhookSecret, secret: client.oauthClient?.webhookSecret,
accountId: client.id,
})); }));
// Use common delivery function // Use common delivery function

View File

@ -14,6 +14,7 @@ export interface WebhookTarget {
url: string; url: string;
secret?: string | null; secret?: string | null;
headers?: Record<string, string>; headers?: Record<string, string>;
accountId?: string;
} }
// Delivery result // Delivery result
@ -51,7 +52,10 @@ export async function deliverWebhook(params: WebhookDeliveryParams): Promise<{
userAgent = "Core-Webhooks/1.0", userAgent = "Core-Webhooks/1.0",
eventType, eventType,
} = params; } = params;
const payloadString = JSON.stringify(payload); const payloadString = JSON.stringify({
...payload,
accountId: payload.accountId,
});
const deliveryResults: DeliveryResult[] = []; const deliveryResults: DeliveryResult[] = [];
logger.log(`Delivering ${eventType} webhook to ${targets.length} targets`); logger.log(`Delivering ${eventType} webhook to ${targets.length} targets`);
@ -154,10 +158,11 @@ export async function deliverWebhook(params: WebhookDeliveryParams): Promise<{
* Helper function to prepare webhook targets from basic URL/secret pairs * Helper function to prepare webhook targets from basic URL/secret pairs
*/ */
export function prepareWebhookTargets( export function prepareWebhookTargets(
webhooks: Array<{ url: string; secret?: string | null }>, webhooks: Array<{ url: string; secret?: string | null; id: string }>,
): WebhookTarget[] { ): WebhookTarget[] {
return webhooks.map((webhook) => ({ return webhooks.map((webhook) => ({
url: webhook.url, url: webhook.url,
secret: webhook.secret, secret: webhook.secret,
accountId: webhook.id,
})); }));
} }