diff --git a/src/data/api/v1/gists.rs b/src/data/api/v1/gists.rs
index ccd0192..17b1ff9 100644
--- a/src/data/api/v1/gists.rs
+++ b/src/data/api/v1/gists.rs
@@ -14,10 +14,11 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see .
*/
-use std::path::Path;
+use std::path::{Path, PathBuf};
use db_core::prelude::*;
use git2::*;
+use serde::{Deserialize, Serialize};
use tokio::fs;
use super::*;
@@ -30,14 +31,24 @@ pub struct Gist {
pub repository: git2::Repository,
}
-pub struct CreateGist<'a>{
+pub struct CreateGist<'a> {
pub owner: &'a str,
pub description: Option<&'a str>,
pub visibility: &'a GistVisibility,
- }
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct File {
+ pub filename: String,
+ pub content: String,
+}
impl Data {
- pub async fn new_gist(&self, db: &T, msg: &CreateGist<'_>) -> ServiceResult {
+ pub async fn new_gist(
+ &self,
+ db: &T,
+ msg: &CreateGist<'_>,
+ ) -> ServiceResult {
loop {
let gist_id = get_random(32);
@@ -45,7 +56,7 @@ impl Data {
continue;
}
- let gist_path = Path::new(&self.settings.repository.root).join(&gist_id);
+ let gist_path = self.get_repository_path(&gist_id);
if gist_path.exists() {
if Repository::open(&gist_path).is_ok() {
@@ -70,6 +81,84 @@ impl Data {
});
}
}
+
+ pub(crate) fn get_repository_path(&self, gist_id: &str) -> PathBuf {
+ Path::new(&self.settings.repository.root).join(gist_id)
+ }
+
+ pub async fn write_file(
+ &self,
+ _db: &T,
+ gist_id: &str,
+ files: &[File],
+ ) -> ServiceResult<()> {
+ // TODO change updated in DB
+
+ let repo = git2::Repository::open(self.get_repository_path(gist_id)).unwrap();
+ let mut tree_builder = repo.treebuilder(None).unwrap();
+ let odb = repo.odb().unwrap();
+
+ for file in files.iter() {
+ let escaped_filename = escape_spaces(&file.filename);
+
+ let obj = odb
+ .write(ObjectType::Blob, file.content.as_bytes())
+ .unwrap();
+ tree_builder
+ .insert(&escaped_filename, obj, 0o100644)
+ .unwrap();
+ }
+
+ let tree_hash = tree_builder.write().unwrap();
+ let author = Signature::now("gists", "admin@gists.batsense.net").unwrap();
+ let committer = Signature::now("gists", "admin@gists.batsense.net").unwrap();
+
+ let commit_tree = repo.find_tree(tree_hash).unwrap();
+ let msg = "";
+ if let Err(e) = repo.head() {
+ if e.code() == ErrorCode::UnbornBranch && e.class() == ErrorClass::Reference {
+ // fisrt commit ever; set parent commit(s) to empty array
+ repo.commit(Some("HEAD"), &author, &committer, msg, &commit_tree, &[])
+ .unwrap();
+ } else {
+ panic!("{:?}", e);
+ }
+ } else {
+ let head_ref = repo.head().unwrap();
+ let head_commit = head_ref.peel_to_commit().unwrap();
+ repo.commit(
+ Some("HEAD"),
+ &author,
+ &committer,
+ msg,
+ &commit_tree,
+ &[&head_commit],
+ )
+ .unwrap();
+ };
+
+ Ok(())
+ }
+
+ /// Please note that this method expects path to not contain any spaces
+ /// Use [escape_spaces] before calling this method
+ ///
+ /// For example, a read request for "foo bar.md" will fail even if that file is present
+ /// in the repository. However, it will succeed if the output of [escape_spaces] is
+ /// used in the request.
+ pub async fn read_file(
+ &self,
+ _db: &T,
+ gist_id: &str,
+ path: &str,
+ ) -> ServiceResult> {
+ let repo = git2::Repository::open(self.get_repository_path(gist_id)).unwrap();
+ let head = repo.head().unwrap();
+ let tree = head.peel_to_tree().unwrap();
+ let entry = tree.get_path(Path::new(path)).unwrap();
+ let blob = repo.find_blob(entry.id()).unwrap();
+ Ok(blob.content().to_vec())
+ }
}
#[cfg(test)]
@@ -85,30 +174,51 @@ mod tests {
];
for (db, data) in config.iter() {
-
const NAME: &str = "gisttestuser";
const EMAIL: &str = "gisttestuser@sss.com";
const PASSWORD: &str = "longpassword2";
- let _ = futures::join!(
- data.delete_user(db, NAME, PASSWORD),
- );
+ let _ = futures::join!(data.delete_user(db, NAME, PASSWORD),);
let _ = data.register_and_signin(db, NAME, EMAIL, PASSWORD).await;
-
let create_gist_msg = CreateGist {
owner: NAME,
description: None,
visibility: &GistVisibility::Public,
};
let gist = data.new_gist(db, &create_gist_msg).await.unwrap();
- let path = Path::new(&data.settings.repository.root).join(&gist.id);
+ let path = data.get_repository_path(&gist.id);
assert!(path.exists());
assert!(db.gist_exists(&gist.id).await.unwrap());
let repo = Repository::open(&path).unwrap();
assert!(repo.is_bare());
assert!(repo.is_empty().unwrap());
+
+ // save files
+ let files = [
+ File {
+ filename: "foo".into(),
+ content: "foobar".into(),
+ },
+ File {
+ filename: "bar".into(),
+ content: "foobar".into(),
+ },
+ File {
+ filename: "foo bar".into(),
+ content: "foobar".into(),
+ },
+ ];
+
+ data.write_file(db, &gist.id, &files).await.unwrap();
+ for file in files.iter() {
+ let content = data
+ .read_file(db, &gist.id, &escape_spaces(&file.filename))
+ .await
+ .unwrap();
+ assert_eq!(String::from_utf8_lossy(&content), file.content);
+ }
}
}
}
diff --git a/src/utils.rs b/src/utils.rs
index e290a7f..328ae28 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -14,6 +14,12 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see .
*/
+use std::path::Path;
+
+use tokio::fs;
+
+use crate::errors::*;
+
/// Get random string of specific length
pub(crate) fn get_random(len: usize) -> String {
use rand::{distributions::Alphanumeric, rngs::ThreadRng, thread_rng, Rng};
@@ -27,3 +33,29 @@ pub(crate) fn get_random(len: usize) -> String {
.take(len)
.collect::()
}
+
+pub async fn create_dir_all_if_not_exists(path: &Path) -> ServiceResult<()> {
+ if !path.exists() {
+ fs::create_dir_all(&path).await?;
+ }
+ Ok(())
+}
+
+pub fn escape_spaces(name: &str) -> String {
+ if name.contains(' ') {
+ name.replace(' ', "\\ ")
+ } else {
+ name.to_string()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn space_escape() {
+ let space = "do re mi";
+ assert_eq!(&escape_spaces(space), ("do\\ re\\ mi"));
+ }
+}