diff --git a/src/cli/ssh-cache/DiskCache.zig b/src/cli/ssh-cache/DiskCache.zig index 62620ecb0..6214d0429 100644 --- a/src/cli/ssh-cache/DiskCache.zig +++ b/src/cli/ssh-cache/DiskCache.zig @@ -9,7 +9,6 @@ const assert = @import("../../quirks.zig").inlineAssert; const Allocator = std.mem.Allocator; const internal_os = @import("../../os/main.zig"); const xdg = internal_os.xdg; -const TempDir = internal_os.TempDir; const Entry = @import("Entry.zig"); // 512KB - sufficient for approximately 10k entries @@ -125,7 +124,7 @@ pub fn add( break :update .updated; }; - try self.writeCacheFile(alloc, entries, null); + try self.writeCacheFile(entries, null); return result; } @@ -166,7 +165,7 @@ pub fn remove( alloc.free(kv.value.terminfo_version); } - try self.writeCacheFile(alloc, entries, null); + try self.writeCacheFile(entries, null); } /// Check if a hostname exists in the cache. @@ -209,32 +208,30 @@ fn fixupPermissions(file: std.fs.File) (std.fs.File.StatError || std.fs.File.Chm fn writeCacheFile( self: DiskCache, - alloc: Allocator, entries: std.StringHashMap(Entry), expire_days: ?u32, ) !void { - var td: TempDir = try .init(); - defer td.deinit(); + const cache_dir = std.fs.path.dirname(self.path) orelse return error.InvalidCachePath; + const cache_basename = std.fs.path.basename(self.path); - const tmp_file = try td.dir.createFile("ssh-cache", .{ .mode = 0o600 }); - defer tmp_file.close(); - const tmp_path = try td.dir.realpathAlloc(alloc, "ssh-cache"); - defer alloc.free(tmp_path); + var dir = try std.fs.cwd().openDir(cache_dir, .{}); + defer dir.close(); var buf: [1024]u8 = undefined; - var writer = tmp_file.writer(&buf); + var atomic_file = try dir.atomicFile(cache_basename, .{ + .mode = 0o600, + .write_buffer = &buf, + }); + defer atomic_file.deinit(); + var iter = entries.iterator(); while (iter.next()) |kv| { // Only write non-expired entries if (kv.value_ptr.isExpired(expire_days)) continue; - try kv.value_ptr.format(&writer.interface); + try kv.value_ptr.format(&atomic_file.file_writer.interface); } - // Don't forget to flush!! - try writer.interface.flush(); - - // Atomic replace - try std.fs.renameAbsolute(tmp_path, self.path); + try atomic_file.finish(); } /// List all entries in the cache. @@ -382,16 +379,16 @@ test "disk cache clear" { const alloc = testing.allocator; // Create our path - var td: TempDir = try .init(); - defer td.deinit(); + var tmp = testing.tmpDir(.{}); + defer tmp.cleanup(); var buf: [4096]u8 = undefined; { - var file = try td.dir.createFile("cache", .{}); + var file = try tmp.dir.createFile("cache", .{}); defer file.close(); var file_writer = file.writer(&buf); try file_writer.interface.writeAll("HELLO!"); } - const path = try td.dir.realpathAlloc(alloc, "cache"); + const path = try tmp.dir.realpathAlloc(alloc, "cache"); defer alloc.free(path); // Setup our cache @@ -401,7 +398,7 @@ test "disk cache clear" { // Verify the file is gone try testing.expectError( error.FileNotFound, - td.dir.openFile("cache", .{}), + tmp.dir.openFile("cache", .{}), ); } @@ -410,18 +407,18 @@ test "disk cache operations" { const alloc = testing.allocator; // Create our path - var td: TempDir = try .init(); - defer td.deinit(); + var tmp = testing.tmpDir(.{}); + defer tmp.cleanup(); var buf: [4096]u8 = undefined; { - var file = try td.dir.createFile("cache", .{}); + var file = try tmp.dir.createFile("cache", .{}); defer file.close(); var file_writer = file.writer(&buf); const writer = &file_writer.interface; try writer.writeAll("HELLO!"); try writer.flush(); } - const path = try td.dir.realpathAlloc(alloc, "cache"); + const path = try tmp.dir.realpathAlloc(alloc, "cache"); defer alloc.free(path); // Setup our cache @@ -453,6 +450,32 @@ test "disk cache operations" { ); } +test "disk cache cleans up temp files" { + const testing = std.testing; + const alloc = testing.allocator; + + var tmp = testing.tmpDir(.{ .iterate = true }); + defer tmp.cleanup(); + + const tmp_path = try tmp.dir.realpathAlloc(alloc, "."); + defer alloc.free(tmp_path); + const cache_path = try std.fs.path.join(alloc, &.{ tmp_path, "cache" }); + defer alloc.free(cache_path); + + const cache: DiskCache = .{ .path = cache_path }; + try testing.expectEqual(AddResult.added, try cache.add(alloc, "example.com")); + try testing.expectEqual(AddResult.added, try cache.add(alloc, "example.org")); + + // Verify only the cache file exists and no temp files left behind + var count: usize = 0; + var iter = tmp.dir.iterate(); + while (try iter.next()) |entry| { + count += 1; + try testing.expectEqualStrings("cache", entry.name); + } + try testing.expectEqual(1, count); +} + test isValidHost { const testing = std.testing;