diff --git a/src/data/api/v1/gists.rs b/src/data/api/v1/gists.rs index ccd0192..17b1ff9 100644 --- a/src/data/api/v1/gists.rs +++ b/src/data/api/v1/gists.rs @@ -14,10 +14,11 @@ * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ -use std::path::Path; +use std::path::{Path, PathBuf}; use db_core::prelude::*; use git2::*; +use serde::{Deserialize, Serialize}; use tokio::fs; use super::*; @@ -30,14 +31,24 @@ pub struct Gist { pub repository: git2::Repository, } -pub struct CreateGist<'a>{ +pub struct CreateGist<'a> { pub owner: &'a str, pub description: Option<&'a str>, pub visibility: &'a GistVisibility, - } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct File { + pub filename: String, + pub content: String, +} impl Data { - pub async fn new_gist(&self, db: &T, msg: &CreateGist<'_>) -> ServiceResult { + pub async fn new_gist( + &self, + db: &T, + msg: &CreateGist<'_>, + ) -> ServiceResult { loop { let gist_id = get_random(32); @@ -45,7 +56,7 @@ impl Data { continue; } - let gist_path = Path::new(&self.settings.repository.root).join(&gist_id); + let gist_path = self.get_repository_path(&gist_id); if gist_path.exists() { if Repository::open(&gist_path).is_ok() { @@ -70,6 +81,84 @@ impl Data { }); } } + + pub(crate) fn get_repository_path(&self, gist_id: &str) -> PathBuf { + Path::new(&self.settings.repository.root).join(gist_id) + } + + pub async fn write_file( + &self, + _db: &T, + gist_id: &str, + files: &[File], + ) -> ServiceResult<()> { + // TODO change updated in DB + + let repo = git2::Repository::open(self.get_repository_path(gist_id)).unwrap(); + let mut tree_builder = repo.treebuilder(None).unwrap(); + let odb = repo.odb().unwrap(); + + for file in files.iter() { + let escaped_filename = escape_spaces(&file.filename); + + let obj = odb + .write(ObjectType::Blob, file.content.as_bytes()) + .unwrap(); + tree_builder + .insert(&escaped_filename, obj, 0o100644) + .unwrap(); + } + + let tree_hash = tree_builder.write().unwrap(); + let author = Signature::now("gists", "admin@gists.batsense.net").unwrap(); + let committer = Signature::now("gists", "admin@gists.batsense.net").unwrap(); + + let commit_tree = repo.find_tree(tree_hash).unwrap(); + let msg = ""; + if let Err(e) = repo.head() { + if e.code() == ErrorCode::UnbornBranch && e.class() == ErrorClass::Reference { + // fisrt commit ever; set parent commit(s) to empty array + repo.commit(Some("HEAD"), &author, &committer, msg, &commit_tree, &[]) + .unwrap(); + } else { + panic!("{:?}", e); + } + } else { + let head_ref = repo.head().unwrap(); + let head_commit = head_ref.peel_to_commit().unwrap(); + repo.commit( + Some("HEAD"), + &author, + &committer, + msg, + &commit_tree, + &[&head_commit], + ) + .unwrap(); + }; + + Ok(()) + } + + /// Please note that this method expects path to not contain any spaces + /// Use [escape_spaces] before calling this method + /// + /// For example, a read request for "foo bar.md" will fail even if that file is present + /// in the repository. However, it will succeed if the output of [escape_spaces] is + /// used in the request. + pub async fn read_file( + &self, + _db: &T, + gist_id: &str, + path: &str, + ) -> ServiceResult> { + let repo = git2::Repository::open(self.get_repository_path(gist_id)).unwrap(); + let head = repo.head().unwrap(); + let tree = head.peel_to_tree().unwrap(); + let entry = tree.get_path(Path::new(path)).unwrap(); + let blob = repo.find_blob(entry.id()).unwrap(); + Ok(blob.content().to_vec()) + } } #[cfg(test)] @@ -85,30 +174,51 @@ mod tests { ]; for (db, data) in config.iter() { - const NAME: &str = "gisttestuser"; const EMAIL: &str = "gisttestuser@sss.com"; const PASSWORD: &str = "longpassword2"; - let _ = futures::join!( - data.delete_user(db, NAME, PASSWORD), - ); + let _ = futures::join!(data.delete_user(db, NAME, PASSWORD),); let _ = data.register_and_signin(db, NAME, EMAIL, PASSWORD).await; - let create_gist_msg = CreateGist { owner: NAME, description: None, visibility: &GistVisibility::Public, }; let gist = data.new_gist(db, &create_gist_msg).await.unwrap(); - let path = Path::new(&data.settings.repository.root).join(&gist.id); + let path = data.get_repository_path(&gist.id); assert!(path.exists()); assert!(db.gist_exists(&gist.id).await.unwrap()); let repo = Repository::open(&path).unwrap(); assert!(repo.is_bare()); assert!(repo.is_empty().unwrap()); + + // save files + let files = [ + File { + filename: "foo".into(), + content: "foobar".into(), + }, + File { + filename: "bar".into(), + content: "foobar".into(), + }, + File { + filename: "foo bar".into(), + content: "foobar".into(), + }, + ]; + + data.write_file(db, &gist.id, &files).await.unwrap(); + for file in files.iter() { + let content = data + .read_file(db, &gist.id, &escape_spaces(&file.filename)) + .await + .unwrap(); + assert_eq!(String::from_utf8_lossy(&content), file.content); + } } } } diff --git a/src/utils.rs b/src/utils.rs index e290a7f..328ae28 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -14,6 +14,12 @@ * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ +use std::path::Path; + +use tokio::fs; + +use crate::errors::*; + /// Get random string of specific length pub(crate) fn get_random(len: usize) -> String { use rand::{distributions::Alphanumeric, rngs::ThreadRng, thread_rng, Rng}; @@ -27,3 +33,29 @@ pub(crate) fn get_random(len: usize) -> String { .take(len) .collect::() } + +pub async fn create_dir_all_if_not_exists(path: &Path) -> ServiceResult<()> { + if !path.exists() { + fs::create_dir_all(&path).await?; + } + Ok(()) +} + +pub fn escape_spaces(name: &str) -> String { + if name.contains(' ') { + name.replace(' ', "\\ ") + } else { + name.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn space_escape() { + let space = "do re mi"; + assert_eq!(&escape_spaces(space), ("do\\ re\\ mi")); + } +}