Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pkg/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// Bypass cross-origin protection: this server uses bearer tokens (not
// cookies), so Sec-Fetch-Site CSRF checks are unnecessary. See PR #2359.
crossOriginProtection := http.NewCrossOriginProtection()
crossOriginProtection.AddInsecureBypassPattern("/")

mcpHandler := mcp.NewStreamableHTTPHandler(func(_ *http.Request) *mcp.Server {
return ghServer
}, &mcp.StreamableHTTPOptions{
Stateless: true,
Stateless: true,
CrossOriginProtection: crossOriginProtection,
})

mcpHandler.ServeHTTP(w, r)
Expand Down
70 changes: 70 additions & 0 deletions pkg/http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http/httptest"
"slices"
"sort"
"strings"
"testing"

ghcontext "github.com/github/github-mcp-server/pkg/context"
Expand Down Expand Up @@ -660,3 +661,72 @@ func buildStaticInventoryFromTools(cfg *ServerConfig, tools []inventory.ServerTo
ctx := context.Background()
return inv.AvailableTools(ctx), inv.AvailableResourceTemplates(ctx), inv.AvailablePrompts(ctx)
}

func TestCrossOriginProtection(t *testing.T) {
jsonRPCBody := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1"}}}`

apiHost, err := utils.NewAPIHost("https://api.githubcopilot.com")
require.NoError(t, err)

handler := NewHTTPMcpHandler(
context.Background(),
&ServerConfig{
Version: "test",
},
nil,
translations.NullTranslationHelper,
slog.Default(),
apiHost,
WithInventoryFactory(func(_ *http.Request) (*inventory.Inventory, error) {
return inventory.NewBuilder().Build()
}),
WithGitHubMCPServerFactory(func(_ *http.Request, _ github.ToolDependencies, _ *inventory.Inventory, _ *github.MCPServerConfig) (*mcp.Server, error) {
return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil
}),
WithScopeFetcher(allScopesFetcher{}),
)

r := chi.NewRouter()
handler.RegisterMiddleware(r)
handler.RegisterRoutes(r)

tests := []struct {
name string
secFetchSite string
origin string
}{
{
name: "cross-site request with bearer token succeeds",
secFetchSite: "cross-site",
origin: "https://example.com",
},
{
name: "same-origin request succeeds",
secFetchSite: "same-origin",
},
{
name: "native client without Sec-Fetch-Site succeeds",
secFetchSite: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonRPCBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
req.Header.Set(headers.AuthorizationHeader, "Bearer github_pat_xyz")
if tt.secFetchSite != "" {
req.Header.Set("Sec-Fetch-Site", tt.secFetchSite)
}
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}

rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code, "unexpected status code; body: %s", rr.Body.String())
})
}
}
43 changes: 43 additions & 0 deletions pkg/http/middleware/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package middleware

import (
"net/http"
"strings"

"github.com/github/github-mcp-server/pkg/http/headers"
)

// SetCorsHeaders is middleware that sets CORS headers to allow browser-based
// MCP clients to connect from any origin. This is safe because the server
// authenticates via bearer tokens (not cookies), so cross-origin requests
// cannot exploit ambient credentials.
func SetCorsHeaders(h http.Handler) http.Handler {
allowHeaders := strings.Join([]string{
"Content-Type",
"Mcp-Session-Id",
"Mcp-Protocol-Version",
"Last-Event-ID",
headers.AuthorizationHeader,
headers.MCPReadOnlyHeader,
headers.MCPToolsetsHeader,
headers.MCPToolsHeader,
headers.MCPExcludeToolsHeader,
headers.MCPFeaturesHeader,
headers.MCPLockdownHeader,
headers.MCPInsidersHeader,
}, ", ")

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
w.Header().Set("Access-Control-Max-Age", "86400")
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, WWW-Authenticate")
w.Header().Set("Access-Control-Allow-Headers", allowHeaders)

if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
h.ServeHTTP(w, r)
})
}
45 changes: 45 additions & 0 deletions pkg/http/middleware/cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package middleware_test

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/github/github-mcp-server/pkg/http/middleware"
"github.com/stretchr/testify/assert"
)

func TestSetCorsHeaders(t *testing.T) {
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := middleware.SetCorsHeaders(inner)

t.Run("OPTIONS preflight returns 200 with CORS headers", func(t *testing.T) {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
req.Header.Set("Origin", "http://localhost:6274")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "POST")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Authorization")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Content-Type")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Mcp-Session-Id")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Lockdown")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Insiders")
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "Mcp-Session-Id")
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "WWW-Authenticate")
})

t.Run("POST request includes CORS headers", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Origin", "http://localhost:6274")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
})
}
3 changes: 3 additions & 0 deletions pkg/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

ghcontext "github.com/github/github-mcp-server/pkg/context"
"github.com/github/github-mcp-server/pkg/github"
"github.com/github/github-mcp-server/pkg/http/middleware"
"github.com/github/github-mcp-server/pkg/http/oauth"
"github.com/github/github-mcp-server/pkg/inventory"
"github.com/github/github-mcp-server/pkg/lockdown"
Expand Down Expand Up @@ -167,6 +168,8 @@ func RunHTTPServer(cfg ServerConfig) error {
}

r.Group(func(r chi.Router) {
r.Use(middleware.SetCorsHeaders)

// Register Middleware First, needs to be before route registration
handler.RegisterMiddleware(r)

Expand Down
Loading