From 151e926f611d3f51ef3ba440228218da8c5924b4 Mon Sep 17 00:00:00 2001 From: edmondfrank Date: Thu, 24 Apr 2025 09:53:43 +0800 Subject: [PATCH] feat: Wrap option handler to pass extra options for Gitee client feat: Gitee client supports cookies authorization feat: Gitee client supports customized API base Refactor: remove wrapoptionhandler in addtool function Signed-off-by: edmondfrank --- Makefile | 2 +- main.go | 28 ++- operations/enterprises/list_enterprises.go | 7 +- operations/groups/list_group.go | 6 +- .../issue_states/list_issue_type_states.go | 6 +- operations/issue_types/list_issue_types.go | 6 +- operations/issues/comment_issue.go | 6 +- operations/issues/create_issue.go | 5 +- operations/issues/get_issue_detail.go | 6 +- operations/issues/list_issue_comments.go | 6 +- operations/issues/list_issues.go | 7 +- operations/issues/update_issue.go | 6 +- operations/labels/list_labels.go | 6 +- operations/members/list_ent_members.go | 6 +- operations/programs/list_programs.go | 6 +- operations/pulls/comment_pull.go | 6 +- operations/pulls/create_pull.go | 6 +- operations/pulls/get_pull_detail.go | 6 +- operations/pulls/get_pull_diff.go | 6 +- operations/pulls/list_pull_comments.go | 6 +- operations/pulls/list_pulls.go | 6 +- operations/pulls/merge_pull.go | 6 +- operations/pulls/update_pull.go | 6 +- operations/repository/create_release.go | 6 +- operations/repository/create_repository.go | 10 +- operations/repository/list_releases.go | 6 +- operations/repository/list_repository.go | 5 +- .../scrum_sprints/create_scrum_sprint.go | 8 +- .../scrum_sprints/list_scrum_sprints.go | 6 +- .../scrum_versions/list_scrum_versions.go | 6 +- operations/user/get_user_info.go | 5 +- utils/common.go | 10 +- utils/constants.go | 2 +- utils/gitee_client.go | 165 +++++++++++++----- 34 files changed, 272 insertions(+), 114 deletions(-) diff --git a/Makefile b/Makefile index c9e0f18..f2084a3 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # Makefile for cross-platform build BINARY_NAME = mcp-gitee-ent -NPM_VERSION = 0.1.3 +NPM_VERSION = 0.1.4 GO = go OSES = darwin linux windows ARCHS = amd64 arm64 diff --git a/main.go b/main.go index a872fa2..b21add7 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,16 @@ var ( enabledToolsetsFlag string ) + +// wrapOptionHandler creates a standard ToolHandlerFunc from an OptionHandlerFunc, +// allowing predefined options to be passed during registration. +func wrapOptionHandler(handler utils.OptionHandlerFunc, opts ...utils.Option) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Call the original handler, passing the captured options. + return handler(ctx, request, opts...) + } +} + func newMCPServer() *server.MCPServer { return server.NewMCPServer( "mcp-gitee-ent", @@ -40,10 +50,10 @@ func newMCPServer() *server.MCPServer { ) } -func addTool(s *server.MCPServer, tool mcp.Tool, handleFunc server.ToolHandlerFunc) { +func addTool(s *server.MCPServer, tool mcp.Tool, handleFunc utils.OptionHandlerFunc) { enabledToolsets := getEnabledToolsets() if len(enabledToolsets) == 0 { - s.AddTool(tool, handleFunc) + s.AddTool(tool, wrapOptionHandler(handleFunc)) return } @@ -53,7 +63,7 @@ func addTool(s *server.MCPServer, tool mcp.Tool, handleFunc server.ToolHandlerFu for _, keepTool := range enabledToolsets { if tool.Name == keepTool { - s.AddTool(tool, handleFunc) + s.AddTool(tool, wrapOptionHandler(handleFunc)) return } } @@ -74,7 +84,7 @@ func disableTools(s *server.MCPServer) { } func addTools(s *server.MCPServer) { - //Issues + // Issues addTool(s, issues.ListIssuesTool, issues.ListIssuesHandleFunc) addTool(s, issues.CreateIssueTool, issues.CreateIssueHandleFunc) addTool(s, issues.GetIssueDetailTool, issues.GetIssueDetailHandleFunc) @@ -104,23 +114,23 @@ func addTools(s *server.MCPServer) { // Labels addTool(s, labels.ListEnterpriseLabelsTool, labels.ListEnterpriseLabelsHandleFunc) - // IssueTypes + // Issue Types addTool(s, issue_types.ListIssueTypesTool, issue_types.ListIssueTypesHandleFunc) - // IssueStates + // Issue States addTool(s, issue_states.ListIssueTypeStatesTool, issue_states.ListIssueTypeStatesHandleFunc) - //Users + // Users addTool(s, user.GetUserInfoTool, user.GetUserInfoHandleFunc) // Programs addTool(s, programs.ListProgramsTool, programs.ListProgramsHandleFunc) - // ScrumSprints + // Scrum Sprints addTool(s, scrum_sprints.CreateScrumSprintTool, scrum_sprints.CreateScrumSprintHandleFunc) addTool(s, scrum_sprints.ListScrumSprintsTool, scrum_sprints.ListScrumSprintsHandleFunc) - // ScrumVersions + // Scrum Versions addTool(s, scrum_versions.ListScrumVersionsTool, scrum_versions.ListScrumVersionsHandleFunc) // Members diff --git a/operations/enterprises/list_enterprises.go b/operations/enterprises/list_enterprises.go index 076ddbb..e04660f 100644 --- a/operations/enterprises/list_enterprises.go +++ b/operations/enterprises/list_enterprises.go @@ -16,11 +16,14 @@ var ListEnterprisesTool = mcp.NewTool(ListEnterprises, mcp.WithDescription("List user's enterprises"), ) -func ListEnterprisesHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListEnterprisesHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { apiUrl := "/list" - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithPayload(request.Params.Arguments)) + + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.BasicEnterprise]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/groups/list_group.go b/operations/groups/list_group.go index f870bf0..06655a3 100644 --- a/operations/groups/list_group.go +++ b/operations/groups/list_group.go @@ -49,7 +49,7 @@ var ListEntGroupsTool = mcp.NewTool(ListEntGroup, ), ) -func ListEntGroupsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListEntGroupsHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments); err != nil { return checkResult, err @@ -60,9 +60,11 @@ func ListEntGroupsHandleFunc(ctx context.Context, request mcp.CallToolRequest) ( return mcp.NewToolResultError(err.Error()), err } apiUrl := fmt.Sprintf("/%d/groups", enterpriseID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) // Handle response data := types.PagedResponse[types.Group]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/issue_states/list_issue_type_states.go b/operations/issue_states/list_issue_type_states.go index 27e91ad..efcf644 100644 --- a/operations/issue_states/list_issue_type_states.go +++ b/operations/issue_states/list_issue_type_states.go @@ -46,7 +46,7 @@ var ListIssueTypeStatesTool = mcp.NewTool(ListIssueTypeStates, ), ) -func ListIssueTypeStatesHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListIssueTypeStatesHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { if checkResult, err := utils.CheckRequired(request.Params.Arguments, "issue_type_id"); err != nil { return checkResult, err } @@ -63,7 +63,9 @@ func ListIssueTypeStatesHandleFunc(ctx context.Context, request mcp.CallToolRequ } apiUrl := fmt.Sprintf("/%d/issue_types/%d/issue_states", enterpriseID, issueTypeID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.IssueState]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/issue_types/list_issue_types.go b/operations/issue_types/list_issue_types.go index 7cac8ff..bc752d1 100644 --- a/operations/issue_types/list_issue_types.go +++ b/operations/issue_types/list_issue_types.go @@ -59,7 +59,7 @@ var ListIssueTypesTool = mcp.NewTool(ListIssueTypes, ), ) -func ListIssueTypesHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListIssueTypesHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments); err != nil { return checkResult, err @@ -70,8 +70,10 @@ func ListIssueTypesHandleFunc(ctx context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError(err.Error()), err } apiUrl := fmt.Sprintf("/%d/issue_types/enterprise_issue_types", enterpriseID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.IssueType]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/issues/comment_issue.go b/operations/issues/comment_issue.go index 6947ab7..0feb040 100644 --- a/operations/issues/comment_issue.go +++ b/operations/issues/comment_issue.go @@ -32,7 +32,7 @@ var CommentIssueTool = mcp.NewTool(CommentIssue, ), ) -func CommentIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func CommentIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "issue_id", "body"); err != nil { return checkResult, err @@ -47,8 +47,10 @@ func CommentIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest) (* request.Params.Arguments["qt"] = "ident" apiUrl := fmt.Sprintf("/%d/issues/%s/notes", enterpriseID, issueID) - giteeClient := utils.NewGiteeClient("POST", apiUrl, utils.WithPayload(request.Params.Arguments)) + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("POST", apiUrl, opts...) data := types.IssueComment{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/issues/create_issue.go b/operations/issues/create_issue.go index a65c35f..2ca015c 100644 --- a/operations/issues/create_issue.go +++ b/operations/issues/create_issue.go @@ -116,7 +116,7 @@ var CreateIssueTool = mcp.NewTool(CreateIssue, ), ) -func CreateIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func CreateIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "title"); err != nil { return checkResult, err @@ -128,7 +128,8 @@ func CreateIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*m } apiUrl := fmt.Sprintf("/%d/issues", enterpriseID) - giteeClient := utils.NewGiteeClient("POST", apiUrl, utils.WithPayload(request.Params.Arguments)) + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("POST", apiUrl, opts...) data := types.BasicIssue{} return giteeClient.HandleMCPResult(&data) diff --git a/operations/issues/get_issue_detail.go b/operations/issues/get_issue_detail.go index 86b5654..23bd394 100644 --- a/operations/issues/get_issue_detail.go +++ b/operations/issues/get_issue_detail.go @@ -27,7 +27,7 @@ var GetIssueDetailTool = mcp.NewTool(GetIssueDetail, ), ) -func GetIssueDetailHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func GetIssueDetailHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "issue_id"); err != nil { return checkResult, err @@ -41,9 +41,11 @@ func GetIssueDetailHandleFunc(ctx context.Context, request mcp.CallToolRequest) request.Params.Arguments["qt"] = "ident" apiUrl := fmt.Sprintf("/%d/issues/%s", enterpriseID, issueID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.IssueDetail{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/issues/list_issue_comments.go b/operations/issues/list_issue_comments.go index 749cf34..c02a1cc 100644 --- a/operations/issues/list_issue_comments.go +++ b/operations/issues/list_issue_comments.go @@ -47,7 +47,7 @@ var ListIssueCommentsTool = mcp.NewTool(ListIssueComments, ), ) -func ListIssueCommentsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListIssueCommentsHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "issue_id"); err != nil { return checkResult, err @@ -62,8 +62,10 @@ func ListIssueCommentsHandleFunc(ctx context.Context, request mcp.CallToolReques request.Params.Arguments["qt"] = "ident" apiUrl := fmt.Sprintf("/%d/issues/%s/notes", enterpriseID, issueID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.IssueComment]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/issues/list_issues.go b/operations/issues/list_issues.go index de3550f..0de92c8 100644 --- a/operations/issues/list_issues.go +++ b/operations/issues/list_issues.go @@ -126,7 +126,7 @@ var ListIssuesTool = mcp.NewTool(ListIssues, ), ) -func ListIssuesHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListIssuesHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { if checkResult, err := utils.CheckRequired(request.Params.Arguments); err != nil { return checkResult, err } @@ -138,7 +138,10 @@ func ListIssuesHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mc request.Params.Arguments["show_scrum_sprints"] = true apiUrl := fmt.Sprintf("/%d/issues", enterpriseID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.BasicIssue]{} diff --git a/operations/issues/update_issue.go b/operations/issues/update_issue.go index dad2417..1c2c314 100644 --- a/operations/issues/update_issue.go +++ b/operations/issues/update_issue.go @@ -103,7 +103,7 @@ var UpdateIssueTool = mcp.NewTool(UpdateIssue, ), ) -func UpdateIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func UpdateIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "issue_id"); err != nil { return checkResult, err @@ -117,8 +117,10 @@ func UpdateIssueHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*m request.Params.Arguments["qt"] = "ident" apiUrl := fmt.Sprintf("/%d/issues/%s", enterpriseID, issueIDArg) - giteeClient := utils.NewGiteeClient("PUT", apiUrl, utils.WithPayload(request.Params.Arguments)) + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("PUT", apiUrl, opts...) data := types.BasicIssue{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/labels/list_labels.go b/operations/labels/list_labels.go index d163f52..82a9211 100644 --- a/operations/labels/list_labels.go +++ b/operations/labels/list_labels.go @@ -45,7 +45,7 @@ var ListEnterpriseLabelsTool = mcp.NewTool(ListEnterpriseLabels, ), ) -func ListEnterpriseLabelsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListEnterpriseLabelsHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments); err != nil { return checkResult, err @@ -57,8 +57,10 @@ func ListEnterpriseLabelsHandleFunc(ctx context.Context, request mcp.CallToolReq } apiUrl := fmt.Sprintf("/%d/labels", enterpriseID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.Label]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/members/list_ent_members.go b/operations/members/list_ent_members.go index d3533e8..d898de1 100644 --- a/operations/members/list_ent_members.go +++ b/operations/members/list_ent_members.go @@ -67,7 +67,7 @@ var ListEntMembersTool = mcp.NewTool(ListEntMembers, ), ) -func ListEntMembersHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListEntMembersHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments); err != nil { return checkResult, err @@ -79,9 +79,11 @@ func ListEntMembersHandleFunc(ctx context.Context, request mcp.CallToolRequest) } apiUrl := fmt.Sprintf("/%d/members", enterpriseID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) // Handle response data := types.PagedResponse[types.EnterpriseMember]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/programs/list_programs.go b/operations/programs/list_programs.go index c5b7b13..d6b1099 100644 --- a/operations/programs/list_programs.go +++ b/operations/programs/list_programs.go @@ -66,7 +66,7 @@ var ListProgramsTool = mcp.NewTool(ListPrograms, ), ) -func ListProgramsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListProgramsHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments); err != nil { return checkResult, err @@ -78,8 +78,10 @@ func ListProgramsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (* } apiUrl := fmt.Sprintf("/%d/programs", enterpriseID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.Program]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/pulls/comment_pull.go b/operations/pulls/comment_pull.go index 42a3bf9..656688f 100644 --- a/operations/pulls/comment_pull.go +++ b/operations/pulls/comment_pull.go @@ -41,7 +41,7 @@ var CommentPullTool = mcp.NewTool(CommentEntPull, ), ) -func CommentPullHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func CommentPullHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id", "pull_request_id", "body"); err != nil { return checkResult, err @@ -64,8 +64,10 @@ func CommentPullHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*m if !utils.IsAllDigits(projectIDArg.(string)) { request.Params.Arguments["qt"] = "path" } - giteeClient := utils.NewGiteeClient("POST", apiUrl, utils.WithPayload(request.Params.Arguments)) + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("POST", apiUrl, opts...) data := types.PullComment{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/pulls/create_pull.go b/operations/pulls/create_pull.go index 0aa864d..55603a4 100644 --- a/operations/pulls/create_pull.go +++ b/operations/pulls/create_pull.go @@ -62,7 +62,7 @@ var CreateEntPullTool = mcp.NewTool(CreateEntPull, ), ) -func CreateEntPullHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func CreateEntPullHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id", "source_branch", "target_branch", "title"); err != nil { return checkResult, err @@ -79,8 +79,10 @@ func CreateEntPullHandleFunc(ctx context.Context, request mcp.CallToolRequest) ( if !utils.IsAllDigits(projectIDArg.(string)) { request.Params.Arguments["qt"] = "path" } - giteeClient := utils.NewGiteeClient("POST", apiUrl, utils.WithPayload(request.Params.Arguments)) + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("POST", apiUrl, opts...) data := types.PullDetail{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/pulls/get_pull_detail.go b/operations/pulls/get_pull_detail.go index 6892974..8a48fd5 100644 --- a/operations/pulls/get_pull_detail.go +++ b/operations/pulls/get_pull_detail.go @@ -32,7 +32,7 @@ var GetEntPullDetailTool = mcp.NewTool(GetEntPullDetail, ), ) -func GetPullDetailHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func GetPullDetailHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id", "pull_request_id"); err != nil { return checkResult, err @@ -54,8 +54,10 @@ func GetPullDetailHandleFunc(ctx context.Context, request mcp.CallToolRequest) ( request.Params.Arguments["pr_qt"] = "iid" apiUrl := fmt.Sprintf("/%d/projects/%s/pull_requests/%d", enterpriseID, url.QueryEscape(projectIDArg.(string)), pullRequestID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PullDetail{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/pulls/get_pull_diff.go b/operations/pulls/get_pull_diff.go index 21fa608..7334dd3 100644 --- a/operations/pulls/get_pull_diff.go +++ b/operations/pulls/get_pull_diff.go @@ -32,7 +32,7 @@ var GetPullDiffTool = mcp.NewTool(GetEntPullDiff, ), ) -func GetPullDiffHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func GetPullDiffHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id", "pull_request_id"); err != nil { return checkResult, err @@ -55,8 +55,10 @@ func GetPullDiffHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*m request.Params.Arguments["pr_qt"] = "iid" apiUrl := fmt.Sprintf("/%d/projects/%s/pull_requests/%d/diff", enterpriseID, url.QueryEscape(projectIDArg.(string)), pullRequestID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PullDiff{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/pulls/list_pull_comments.go b/operations/pulls/list_pull_comments.go index c7c0512..ab5b40c 100644 --- a/operations/pulls/list_pull_comments.go +++ b/operations/pulls/list_pull_comments.go @@ -42,7 +42,7 @@ var ListPullCommentsTool = mcp.NewTool(ListEntPullComments, ), ) -func ListPullCommentsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListPullCommentsHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id", "pull_request_id"); err != nil { return checkResult, err @@ -65,8 +65,10 @@ func ListPullCommentsHandleFunc(ctx context.Context, request mcp.CallToolRequest if !utils.IsAllDigits(projectIDArg.(string)) { request.Params.Arguments["qt"] = "path" } - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.PullComment]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/pulls/list_pulls.go b/operations/pulls/list_pulls.go index e5df2e5..1d02a7d 100644 --- a/operations/pulls/list_pulls.go +++ b/operations/pulls/list_pulls.go @@ -104,7 +104,7 @@ var ListEntPullsTool = mcp.NewTool(ListEntPulls, ), ) -func ListEntPullsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListEntPullsHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments); err != nil { return checkResult, err @@ -117,9 +117,11 @@ func ListEntPullsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (* apiUrl := fmt.Sprintf("/%d/pull_requests", enterpriseID) request.Params.Arguments["pr_qt"] = "iid" - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) // Handle response data := types.PagedResponse[types.BasicPull]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/pulls/merge_pull.go b/operations/pulls/merge_pull.go index 10042dc..4579c43 100644 --- a/operations/pulls/merge_pull.go +++ b/operations/pulls/merge_pull.go @@ -45,7 +45,7 @@ var MergePullTool = mcp.NewTool(MergeEntPull, ), ) -func MergePullHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func MergePullHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id", "pull_request_id"); err != nil { return checkResult, err @@ -68,6 +68,8 @@ func MergePullHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp if !utils.IsAllDigits(projectIDArg.(string)) { request.Params.Arguments["qt"] = "path" } - giteeClient := utils.NewGiteeClient("POST", apiUrl, utils.WithPayload(request.Params.Arguments)) + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("POST", apiUrl, opts...) return giteeClient.HandleMCPResult(nil) } + diff --git a/operations/pulls/update_pull.go b/operations/pulls/update_pull.go index bb5059e..057b540 100644 --- a/operations/pulls/update_pull.go +++ b/operations/pulls/update_pull.go @@ -49,7 +49,7 @@ var UpdatePullTool = mcp.NewTool(UpdateEntPull, ), ) -func UpdatePullHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func UpdatePullHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id", "pull_request_id"); err != nil { return checkResult, err @@ -71,8 +71,10 @@ func UpdatePullHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mc } apiUrl := fmt.Sprintf("/%d/projects/%s/pull_requests/%d", enterpriseID, url.QueryEscape(projectIDArg.(string)), pullRequestID) - giteeClient := utils.NewGiteeClient("PUT", apiUrl, utils.WithPayload(request.Params.Arguments)) + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("PUT", apiUrl, opts...) data := types.PullDetail{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/repository/create_release.go b/operations/repository/create_release.go index 65b5bad..e596303 100644 --- a/operations/repository/create_release.go +++ b/operations/repository/create_release.go @@ -52,7 +52,7 @@ var CreateReleaseTool = mcp.NewTool(CreateRelease, ), ) -func CreateReleaseHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func CreateReleaseHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id", "release_tag_version", "release_title", "release_description"); err != nil { return checkResult, err } @@ -70,8 +70,10 @@ func CreateReleaseHandleFunc(ctx context.Context, request mcp.CallToolRequest) ( apiUrl := fmt.Sprintf("/%d/projects/%s/releases", enterpriseID, url.QueryEscape(projectIDArg.(string))) payload := utils.ConvertToHash(request.Params.Arguments, "release", "tag_version", "title", "ref", "description", "release_type") - giteeClient := utils.NewGiteeClient("POST", apiUrl, utils.WithPayload(payload)) + opts = append(opts, utils.WithPayload(payload)) + giteeClient := utils.NewGiteeClient("POST", apiUrl, opts...) data := types.ReleaseDetail{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/repository/create_repository.go b/operations/repository/create_repository.go index 1625175..8a40895 100644 --- a/operations/repository/create_repository.go +++ b/operations/repository/create_repository.go @@ -42,11 +42,11 @@ var CreateRepositoryTool = mcp.NewTool(CreateEnterpriseProject, ), mcp.WithNumber( "project_public", - mcp.Description("Public visibility (0: Private, 1: Public)"), + mcp.Description("Whether public: 0: Private, 1: Public, 2: Internal Open"), ), mcp.WithNumber( "project_outsourced", - mcp.Description("Outsourced status (0: No, 1: Yes)"), + mcp.Description("Whether outsourced: 0: No, 1: Yes"), ), mcp.WithString( "project_program_ids", @@ -74,7 +74,7 @@ var CreateRepositoryTool = mcp.NewTool(CreateEnterpriseProject, ), ) -func CreateRepositoryHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func CreateRepositoryHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_name", "project_namespace_path", "project_path"); err != nil { return checkResult, err } @@ -87,8 +87,10 @@ func CreateRepositoryHandleFunc(ctx context.Context, request mcp.CallToolRequest apiUrl := fmt.Sprintf("/%d/projects", enterpriseID) payload := utils.ConvertToHash(request.Params.Arguments, "project", "name", "namespace_path", "path", "description", "public", "outsourced", "program_ids", "member_ids") - giteeClient := utils.NewGiteeClient("POST", apiUrl, utils.WithPayload(payload)) + opts = append(opts, utils.WithPayload(payload)) + giteeClient := utils.NewGiteeClient("POST", apiUrl, opts...) data := types.Repository{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/repository/list_releases.go b/operations/repository/list_releases.go index c8958d2..d8094a1 100644 --- a/operations/repository/list_releases.go +++ b/operations/repository/list_releases.go @@ -37,7 +37,7 @@ var ListReleasesTool = mcp.NewTool(ListReleases, ), ) -func ListReleasesHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListReleasesHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "project_id"); err != nil { return checkResult, err @@ -55,8 +55,10 @@ func ListReleasesHandleFunc(ctx context.Context, request mcp.CallToolRequest) (* apiUrl := fmt.Sprintf("/%d/projects/%s/releases", enterpriseID, url.QueryEscape(projectIDArg.(string))) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.ReleaseDetail]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/repository/list_repository.go b/operations/repository/list_repository.go index 18d04a7..c04c8b2 100644 --- a/operations/repository/list_repository.go +++ b/operations/repository/list_repository.go @@ -86,7 +86,7 @@ var ListRepositoriesTool = mcp.NewTool(ListEnterpriseProjects, ), ) -func ListRepositoriesHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListRepositoriesHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { if checkResult, err := utils.CheckRequired(request.Params.Arguments); err != nil { return checkResult, err } @@ -98,7 +98,8 @@ func ListRepositoriesHandleFunc(ctx context.Context, request mcp.CallToolRequest apiUrl := fmt.Sprintf("/%d/projects", enterpriseID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.Repository]{} diff --git a/operations/scrum_sprints/create_scrum_sprint.go b/operations/scrum_sprints/create_scrum_sprint.go index fd760b8..682b4a5 100644 --- a/operations/scrum_sprints/create_scrum_sprint.go +++ b/operations/scrum_sprints/create_scrum_sprint.go @@ -54,9 +54,9 @@ var CreateScrumSprintTool = mcp.NewTool(CreateScrumSprint, ), ) -func CreateScrumSprintHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func CreateScrumSprintHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters - if checkResult, err := utils.CheckRequired(request.Params.Arguments, "program_id", "title", "started_at", "finished_at"); err != nil { + if checkResult, err := utils.CheckRequired(request.Params.Arguments, "program_id", "title", "assignee_id", "started_at", "finished_at"); err != nil { return checkResult, err } enterpriseIDArg := request.Params.Arguments["enterprise_id"] @@ -71,8 +71,10 @@ func CreateScrumSprintHandleFunc(ctx context.Context, request mcp.CallToolReques } apiUrl := fmt.Sprintf("/%d/programs/%d/scrum_sprints", enterpriseID, programID) - giteeClient := utils.NewGiteeClient("POST", apiUrl, utils.WithPayload(request.Params.Arguments)) + opts = append(opts, utils.WithPayload(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("POST", apiUrl, opts...) data := types.ScrumSprint{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/scrum_sprints/list_scrum_sprints.go b/operations/scrum_sprints/list_scrum_sprints.go index d16b487..84e07a4 100644 --- a/operations/scrum_sprints/list_scrum_sprints.go +++ b/operations/scrum_sprints/list_scrum_sprints.go @@ -49,7 +49,7 @@ var ListScrumSprintsTool = mcp.NewTool(ListScrumSprints, ), ) -func ListScrumSprintsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListScrumSprintsHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "program_id"); err != nil { return checkResult, err @@ -66,8 +66,10 @@ func ListScrumSprintsHandleFunc(ctx context.Context, request mcp.CallToolRequest } apiUrl := fmt.Sprintf("/%d/programs/%d/scrum_sprints", enterpriseID, programID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.ScrumSprint]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/scrum_versions/list_scrum_versions.go b/operations/scrum_versions/list_scrum_versions.go index afbc6c9..08d3ec5 100644 --- a/operations/scrum_versions/list_scrum_versions.go +++ b/operations/scrum_versions/list_scrum_versions.go @@ -45,7 +45,7 @@ var ListScrumVersionsTool = mcp.NewTool(ListScrumVersions, ), ) -func ListScrumVersionsHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func ListScrumVersionsHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { // Validate required parameters if checkResult, err := utils.CheckRequired(request.Params.Arguments, "program_id"); err != nil { return checkResult, err @@ -63,8 +63,10 @@ func ListScrumVersionsHandleFunc(ctx context.Context, request mcp.CallToolReques } apiUrl := fmt.Sprintf("/%d/programs/%d/scrum_versions", enterpriseID, programID) - giteeClient := utils.NewGiteeClient("GET", apiUrl, utils.WithQuery(request.Params.Arguments)) + opts = append(opts, utils.WithQuery(request.Params.Arguments)) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.PagedResponse[types.ScrumVersion]{} return giteeClient.HandleMCPResult(&data) } + diff --git a/operations/user/get_user_info.go b/operations/user/get_user_info.go index 7a91538..979c8d6 100644 --- a/operations/user/get_user_info.go +++ b/operations/user/get_user_info.go @@ -13,11 +13,12 @@ const ( var GetUserInfoTool = mcp.NewTool(GetUserInfo, mcp.WithDescription("Get user info")) -func GetUserInfoHandleFunc(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func GetUserInfoHandleFunc(ctx context.Context, request mcp.CallToolRequest, opts ...utils.Option) (*mcp.CallToolResult, error) { apiUrl := "/users" - giteeClient := utils.NewGiteeClient("GET", apiUrl) + giteeClient := utils.NewGiteeClient("GET", apiUrl, opts...) data := types.UserInfo{} return giteeClient.HandleMCPResult(&data) } + diff --git a/utils/common.go b/utils/common.go index 3bb848a..265087b 100644 --- a/utils/common.go +++ b/utils/common.go @@ -1,6 +1,14 @@ package utils -import "regexp" +import ( + "context" + "regexp" + + "github.com/mark3labs/mcp-go/mcp" +) + +// OptionHandlerFunc defines the signature for handlers that accept utils.Option. +type OptionHandlerFunc func(context.Context, mcp.CallToolRequest, ...Option) (*mcp.CallToolResult, error) func IsAllDigits(s string) bool { pattern := regexp.MustCompile(`^[0-9]+$`) diff --git a/utils/constants.go b/utils/constants.go index d3b38b1..70d7c2c 100644 --- a/utils/constants.go +++ b/utils/constants.go @@ -2,5 +2,5 @@ package utils var ( // Version gitee mcp ent server version - Version = "0.1.3" + Version = "0.1.4" ) diff --git a/utils/gitee_client.go b/utils/gitee_client.go index d58d395..52e7511 100644 --- a/utils/gitee_client.go +++ b/utils/gitee_client.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" "gitee.com/oschina/mcp-gitee-ent/operations/types" - "io/ioutil" + "io" "net/http" "net/url" "os" @@ -51,34 +51,83 @@ func GetApiBase() string { return DefaultApiBase } -type GiteeClient struct { - Url string - Method string - Payload interface{} - Headers map[string]string - Response *http.Response - parsedUrl *url.URL - Query map[string]string +type Authorizer interface { + Authorize(req *http.Request) error } -type Option func(client *GiteeClient) +type BearerTokenAuthorizer struct{} -func NewGiteeClient(method, urlString string, opts ...Option) *GiteeClient { - urlString = GetApiBase() + urlString - parsedUrl, err := url.Parse(urlString) - if err != nil { - panic(err) +func (b *BearerTokenAuthorizer) Authorize(req *http.Request) error { + accessToken := GetGiteeAccessToken() + if accessToken == "" { + return NewAuthError() } + req.Header.Set("Authorization", "Bearer "+accessToken) + return nil +} + +type CookieAuthorizer struct { + Cookie string +} + +func (c *CookieAuthorizer) Authorize(req *http.Request) error { + if c.Cookie == "" { + return NewAuthError() + } + req.Header.Set("Cookie", c.Cookie) + return nil +} + +type GiteeClient struct { + Url string + Method string + Payload interface{} + Headers map[string]string + Response *http.Response + parsedUrl *url.URL + Query map[string]string + httpClient *http.Client + authorizer Authorizer + apiBase string // Added apiBase field +} +type Option func(client *GiteeClient) + +func NewGiteeClient(method, urlPath string, opts ...Option) *GiteeClient { + // Initialize client with defaults, including apiBase from GetApiBase() client := &GiteeClient{ - Method: method, - Url: parsedUrl.String(), - parsedUrl: parsedUrl, + Method: method, + httpClient: http.DefaultClient, + authorizer: &BearerTokenAuthorizer{}, + apiBase: GetApiBase(), } + // Apply options. This allows WithApiBase to override the default apiBase, + // and WithQuery to populate client.Query. for _, opt := range opts { opt(client) } + + // Construct the full URL using the client's potentially updated apiBase + fullURL := client.apiBase + urlPath + parsedUrl, err := url.Parse(fullURL) + if err != nil { + panic(fmt.Errorf("failed to parse URL '%s': %w", fullURL, err)) + } + client.parsedUrl = parsedUrl // Store parsed URL object + + // Apply query parameters from client.Query (populated by WithQuery option) + if client.Query != nil { + queryParams := client.parsedUrl.Query() + for k, v := range client.Query { + queryParams.Set(k, v) + } + client.parsedUrl.RawQuery = queryParams.Encode() + } + + // Set the final URL string including any query parameters + client.Url = client.parsedUrl.String() + return client } @@ -87,27 +136,23 @@ func WithQuery(query map[string]interface{}) Option { return func(client *GiteeClient) { parsedQuery := make(map[string]string) if query != nil { - queryParams := client.parsedUrl.Query() for k, v := range query { parsedValue := "" - switch v.(type) { + switch val := v.(type) { case string: - parsedValue = v.(string) + parsedValue = val case int: - parsedValue = strconv.Itoa(v.(int)) + parsedValue = strconv.Itoa(val) case float32, float64: - parsedValue = fmt.Sprintf("%.f", v) + parsedValue = fmt.Sprintf("%v", val) case bool: - parsedValue = strconv.FormatBool(v.(bool)) + parsedValue = strconv.FormatBool(val) } if parsedValue != "" { - queryParams.Set(k, parsedValue) parsedQuery[k] = parsedValue } } - client.parsedUrl.RawQuery = queryParams.Encode() } - client.Url = client.parsedUrl.String() client.Query = parsedQuery } } @@ -124,6 +169,31 @@ func WithHeaders(headers map[string]string) Option { } } +func WithHTTPClient(httpClient *http.Client) Option { + return func(client *GiteeClient) { + if httpClient != nil { + client.httpClient = httpClient + } + } +} + +func WithAuthorizer(authorizer Authorizer) Option { + return func(client *GiteeClient) { + if authorizer != nil { + client.authorizer = authorizer + } + } +} + +// WithApiBase now modifies the client's internal apiBase field +func WithApiBase(url string) Option { + return func(client *GiteeClient) { + if url != "" { + client.apiBase = url + } + } +} + func (g *GiteeClient) SetHeaders(headers map[string]string) *GiteeClient { g.Headers = headers return g @@ -131,28 +201,41 @@ func (g *GiteeClient) SetHeaders(headers map[string]string) *GiteeClient { func (g *GiteeClient) Do() (*GiteeClient, error) { g.Response = nil - _payload, _ := json.Marshal(g.Payload) - req, err := http.NewRequest(g.Method, g.Url, bytes.NewReader(_payload)) + var requestBody []byte + var err error + + if g.Payload != nil { + requestBody, err = json.Marshal(g.Payload) + if err != nil { + return nil, NewInternalError(fmt.Errorf("failed to marshal payload: %w", err)) + } + } + + req, err := http.NewRequest(g.Method, g.Url, bytes.NewReader(requestBody)) if err != nil { - return nil, NewInternalError(err) + return nil, NewInternalError(fmt.Errorf("failed to create request: %w", err)) } req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", "mcp-gitee-ent "+Version+" Go/"+runtime.GOOS+"/"+runtime.GOARCH+"/"+runtime.Version()) - accessToken := GetGiteeAccessToken() - if accessToken == "" { - return nil, NewAuthError() - } - - req.Header.Set("Authorization", "Bearer "+accessToken) - for key, value := range g.Headers { req.Header.Set(key, value) } - client := &http.Client{} - resp, err := client.Do(req) + // Apply authorization + if g.authorizer != nil { + if err := g.authorizer.Authorize(req); err != nil { + if IsAuthError(err) { + return nil, err + } + return nil, NewInternalError(fmt.Errorf("authorization failed: %w", err)) + } + } else { + return nil, NewAuthError() + } + + resp, err := g.httpClient.Do(req) if err != nil { return g, NewNetworkError(err) } @@ -161,7 +244,7 @@ func (g *GiteeClient) Do() (*GiteeClient, error) { // 检查响应状态码 if !g.IsSuccess() { - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) return g, NewAPIError(resp.StatusCode, body) } @@ -192,7 +275,7 @@ func (g *GiteeClient) IsFail() bool { } func (g *GiteeClient) GetRespBody() ([]byte, error) { - return ioutil.ReadAll(g.Response.Body) + return io.ReadAll(g.Response.Body) } func (g *GiteeClient) HandleMCPResult(object any) (*mcp.CallToolResult, error) { -- Gitee