diff --git a/models/perm/access/access.go b/models/perm/access/access.go index 6433c4675c..acc34c434e 100644 --- a/models/perm/access/access.go +++ b/models/perm/access/access.go @@ -91,43 +91,80 @@ func updateUserAccess(accessMap map[int64]*userAccess, user *user_model.User, mo } } -// FIXME: do cross-comparison so reduce deletions and additions to the minimum? +// refreshAccesses updates the repository's access records in the database by comparing the provided accessMap +// with existing records. It minimizes DB operations by performing selective inserts, updates, and deletes +// instead of removing all existing records and re-adding them. func refreshAccesses(ctx context.Context, repo *repo_model.Repository, accessMap map[int64]*userAccess) (err error) { - minMode := perm.AccessModeRead + minModeToKeep := perm.AccessModeRead if err := repo.LoadOwner(ctx); err != nil { return fmt.Errorf("LoadOwner: %w", err) } - // If the repo isn't private and isn't owned by a organization, + // If the repo isn't private and isn't owned by an organization, // increase the minMode to Write. if !repo.IsPrivate && !repo.Owner.IsOrganization() { - minMode = perm.AccessModeWrite + minModeToKeep = perm.AccessModeWrite } - newAccesses := make([]Access, 0, len(accessMap)) + // Query existing accesses for cross-comparison + var existingAccesses []Access + if err := db.GetEngine(ctx).Where(builder.Eq{"repo_id": repo.ID}).Find(&existingAccesses); err != nil { + return fmt.Errorf("find existing accesses: %w", err) + } + existingMap := make(map[int64]perm.AccessMode, len(existingAccesses)) + for _, a := range existingAccesses { + existingMap[a.UserID] = a.Mode + } + + var toDelete []int64 + var toInsert, toUpdate []Access + + // Determine changes for userID, ua := range accessMap { - if ua.Mode < minMode && !ua.User.IsRestricted { - continue + if ua.Mode < minModeToKeep && !ua.User.IsRestricted { + // No explicit access record needed (handled by default permissions, e.g., public repo access) + if _, exists := existingMap[userID]; exists { + toDelete = append(toDelete, userID) + } + } else { + desiredMode := ua.Mode + if existingMode, exists := existingMap[userID]; exists { + if existingMode != desiredMode { + toUpdate = append(toUpdate, Access{UserID: userID, RepoID: repo.ID, Mode: desiredMode}) + } + } else { + toInsert = append(toInsert, Access{UserID: userID, RepoID: repo.ID, Mode: desiredMode}) + } } - - newAccesses = append(newAccesses, Access{ - UserID: userID, - RepoID: repo.ID, - Mode: ua.Mode, - }) + delete(existingMap, userID) } - // Delete old accesses and insert new ones for repository. - if _, err = db.DeleteByBean(ctx, &Access{RepoID: repo.ID}); err != nil { - return fmt.Errorf("delete old accesses: %w", err) - } - if len(newAccesses) == 0 { - return nil + // Remaining in existingMap should be deleted + for userID := range existingMap { + toDelete = append(toDelete, userID) } - if err = db.Insert(ctx, newAccesses); err != nil { - return fmt.Errorf("insert new accesses: %w", err) + // Execute deletions + if len(toDelete) > 0 { + if _, err = db.GetEngine(ctx).In("user_id", toDelete).And("repo_id = ?", repo.ID).Delete(&Access{}); err != nil { + return fmt.Errorf("delete accesses: %w", err) + } } + + // Execute updates + for _, u := range toUpdate { + if _, err = db.GetEngine(ctx).Where("user_id = ? AND repo_id = ?", u.UserID, repo.ID).Cols("mode").Update(&Access{Mode: u.Mode}); err != nil { + return fmt.Errorf("update access for user %d: %w", u.UserID, err) + } + } + + // Execute insertions + if len(toInsert) > 0 { + if err = db.Insert(ctx, toInsert); err != nil { + return fmt.Errorf("insert new accesses: %w", err) + } + } + return nil } diff --git a/models/perm/access/access_test.go b/models/perm/access/access_test.go index 15d18b368c..148a02efa3 100644 --- a/models/perm/access/access_test.go +++ b/models/perm/access/access_test.go @@ -15,11 +15,20 @@ import ( "code.gitea.io/gitea/modules/setting" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestAccessLevel(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) +func TestAccess(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + t.Run("AccessLevel", testAccessLevel) + t.Run("HasAccess", testHasAccess) + t.Run("RecalculateAccesses", testRecalculateAccesses) + t.Run("RecalculateAccesses2", testRecalculateAccesses2) + t.Run("RecalculateAccessesUpdateMode", testRecalculateAccessesUpdateMode) + t.Run("RecalculateAccessesRemoveAccess", testRecalculateAccessesRemoveAccess) +} +func testAccessLevel(t *testing.T) { user2 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) user5 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 5}) user29 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 29}) @@ -75,9 +84,7 @@ func TestAccessLevel(t *testing.T) { assert.Equal(t, perm_model.AccessModeRead, level) } -func TestHasAccess(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) - +func testHasAccess(t *testing.T) { user1 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) user2 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 5}) // A public repository owned by User 2 @@ -101,9 +108,8 @@ func TestHasAccess(t *testing.T) { assert.NoError(t, err) } -func TestRepository_RecalculateAccesses(t *testing.T) { +func testRecalculateAccesses(t *testing.T) { // test with organization repo - assert.NoError(t, unittest.PrepareTestDatabase()) repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 3}) assert.NoError(t, repo1.LoadOwner(t.Context())) @@ -118,9 +124,8 @@ func TestRepository_RecalculateAccesses(t *testing.T) { assert.Equal(t, perm_model.AccessModeOwner, access.Mode) } -func TestRepository_RecalculateAccesses2(t *testing.T) { +func testRecalculateAccesses2(t *testing.T) { // test with non-organization repo - assert.NoError(t, unittest.PrepareTestDatabase()) repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 4}) assert.NoError(t, repo1.LoadOwner(t.Context())) @@ -132,3 +137,67 @@ func TestRepository_RecalculateAccesses2(t *testing.T) { assert.NoError(t, err) assert.False(t, has) } + +func testRecalculateAccessesUpdateMode(t *testing.T) { + // Test the update path in refreshAccesses optimization + // Scenario: User's access mode changes from Read to Write + repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 4}) + assert.NoError(t, repo.LoadOwner(t.Context())) + + // Verify initial access mode + _ = db.Insert(t.Context(), &repo_model.Collaboration{UserID: 4, RepoID: 4, Mode: perm_model.AccessModeWrite}) + _ = db.Insert(t.Context(), &access_model.Access{UserID: 4, RepoID: 4, Mode: perm_model.AccessModeWrite}) + initialAccess := &access_model.Access{UserID: 4, RepoID: 4} + has, err := db.GetEngine(t.Context()).Get(initialAccess) + assert.NoError(t, err) + assert.True(t, has) + initialMode := initialAccess.Mode + + // Change collaboration mode to trigger update path + newMode := perm_model.AccessModeAdmin + assert.NotEqual(t, initialMode, newMode, "New mode should differ from initial mode") + + _, err = db.GetEngine(t.Context()). + Where("user_id = ? AND repo_id = ?", 4, 4). + Cols("mode"). + Update(&repo_model.Collaboration{Mode: newMode}) + assert.NoError(t, err) + + // Recalculate accesses - should UPDATE existing access, not delete+insert + assert.NoError(t, access_model.RecalculateAccesses(t.Context(), repo)) + + // Verify access was updated, not deleted and re-inserted + updatedAccess := &access_model.Access{UserID: 4, RepoID: 4} + has, err = db.GetEngine(t.Context()).Get(updatedAccess) + assert.NoError(t, err) + assert.True(t, has, "Access should still exist") + assert.Equal(t, newMode, updatedAccess.Mode, "Access mode should be updated to new collaboration mode") +} + +func testRecalculateAccessesRemoveAccess(t *testing.T) { + // Test the delete path in refreshAccesses optimization + // Scenario: Remove a user's collaboration, access should be deleted + repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 4}) + assert.NoError(t, repo.LoadOwner(t.Context())) + + // Verify initial access exists + initialAccess := &access_model.Access{UserID: 4, RepoID: 4} + has, err := db.GetEngine(t.Context()).Get(initialAccess) + assert.NoError(t, err) + assert.True(t, has, "Access should exist initially") + + // Remove the collaboration to trigger delete path + _, err = db.GetEngine(t.Context()). + Where("user_id = ? AND repo_id = ?", 4, 4). + Delete(&repo_model.Collaboration{}) + assert.NoError(t, err) + + // Recalculate accesses - should DELETE the access record + assert.NoError(t, access_model.RecalculateAccesses(t.Context(), repo)) + + // Verify access was deleted + removedAccess := &access_model.Access{UserID: 4, RepoID: 4} + has, err = db.GetEngine(t.Context()).Get(removedAccess) + assert.NoError(t, err) + assert.False(t, has, "Access should be deleted after removing collaboration") +}